From 41189a5b139fe1c2d9e44660c71ed39a34a4aa88 Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Wed, 14 Jan 2026 11:58:59 +0100 Subject: [PATCH 01/68] init --- pyproject.toml | 6 ++++ src/squidpy/__init__.py | 3 +- src/squidpy/_utils.py | 36 ++++++++++++++++++++++ src/squidpy/settings/__init__.py | 7 +++++ src/squidpy/settings/_settings.py | 51 +++++++++++++++++++++++++++++++ 5 files changed, 102 insertions(+), 1 deletion(-) create mode 100644 src/squidpy/settings/__init__.py create mode 100644 src/squidpy/settings/_settings.py diff --git a/pyproject.toml b/pyproject.toml index d682d786a..03c21c0a9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -106,6 +106,12 @@ optional-dependencies.leiden = [ "leidenalg", "spatialleiden>=0.4", ] +optional-dependencies.gpu-cuda12 = [ + "rapids-singlecell[rapids12]", +] +optional-dependencies.gpu-cuda11 = [ + "rapids-singlecell[rapids11]", +] optional-dependencies.test = [ "coverage[toml]>=7", "pytest>=7", diff --git a/src/squidpy/__init__.py b/src/squidpy/__init__.py index 85b250d82..7313b6f30 100644 --- a/src/squidpy/__init__.py +++ b/src/squidpy/__init__.py @@ -4,6 +4,7 @@ from importlib.metadata import PackageMetadata from squidpy import datasets, experimental, gr, im, pl, read, tl +from squidpy.settings import settings try: md: PackageMetadata = metadata.metadata(__name__) @@ -15,4 +16,4 @@ del metadata, md -__all__ = ["datasets", "experimental", "gr", "im", "pl", "read", "tl"] +__all__ = ["datasets", "experimental", "gr", "im", "pl", "read", "tl", "settings"] diff --git a/src/squidpy/_utils.py b/src/squidpy/_utils.py index 736c88172..0ffb659f0 100644 --- a/src/squidpy/_utils.py +++ b/src/squidpy/_utils.py @@ -18,6 +18,7 @@ import numpy as np import xarray as xr from spatialdata.models import Image2DModel, Labels2DModel +from squidpy.settings import DeviceType __all__ = ["singledispatchmethod", "Signal", "SigQueue", "NDArray", "NDArrayA"] @@ -387,3 +388,38 @@ def _ensure_dim_order(img_da: xr.DataArray, order: 
Literal["cyx", "yxc"] = "yxc" img_da = img_da.expand_dims({"c": [0]}) # After possible expand, just transpose to target return img_da.transpose(*tuple(order)) + + +def resolve_device_arg(device: DeviceType | None) -> Literal["cpu", "gpu"]: + """ + Resolve per-call device argument to actual backend. + + Parameters + ---------- + device + Per-call device setting. None uses ``settings.device``. + + Returns + ------- + Literal["cpu", "gpu"] + The resolved backend to use. + + Raises + ------ + RuntimeError + If GPU is requested but rapids-singlecell is not installed. + """ + from squidpy.settings import settings + + if device is None: + device = settings.device + if device == "cpu": + return "cpu" + if device == "gpu": + if not settings.gpu_available(): + raise RuntimeError( + "GPU unavailable. Install with: pip install squidpy[gpu-cuda12] or squidpy[gpu-cuda11]" + ) + return "gpu" + # if device == "auto" + return "gpu" if settings.gpu_available() else "cpu" diff --git a/src/squidpy/settings/__init__.py b/src/squidpy/settings/__init__.py new file mode 100644 index 000000000..3491e01fd --- /dev/null +++ b/src/squidpy/settings/__init__.py @@ -0,0 +1,7 @@ +"""Squidpy settings and configuration.""" + +from __future__ import annotations + +from squidpy.settings._settings import DeviceType, settings + +__all__ = ["settings", "DeviceType"] diff --git a/src/squidpy/settings/_settings.py b/src/squidpy/settings/_settings.py new file mode 100644 index 000000000..d807f0714 --- /dev/null +++ b/src/squidpy/settings/_settings.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +from contextvars import ContextVar +from typing import Literal, get_args + +__all__ = ["settings", "DeviceType"] + +DeviceType = Literal["auto", "cpu", "gpu"] + +_device_var: ContextVar[DeviceType] = ContextVar("device", default="auto") + + +class _SqSettings: + """Global settings for squidpy.""" + + @property + def device(self) -> DeviceType: + """Current compute device setting.""" + return 
_device_var.get() + + @device.setter + def device(self, value: DeviceType) -> None: + valid = get_args(DeviceType) + if value not in valid: + raise ValueError(f"Invalid device {value!r}. Must be one of: {valid}") + if value == "gpu" and not self.gpu_available(): + raise RuntimeError( + "Cannot set device='gpu': rapids-singlecell not installed. " + "Install with: pip install squidpy[gpu-cuda12] or squidpy[gpu-cuda11]" + ) + _device_var.set(value) + + @staticmethod + def gpu_available() -> bool: + """ + Check if GPU acceleration is available. + + Returns + ------- + bool + True if rapids-singlecell is installed and importable. + """ + try: + import rapids_singlecell # noqa: F401 + + return True + except ImportError: + return False + + +settings = _SqSettings() From e058886307599dbee745a94ee9de7268a71c9d1c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 14 Jan 2026 11:07:01 +0000 Subject: [PATCH 02/68] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pyproject.toml | 10 +++++----- src/squidpy/_utils.py | 5 ++--- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 03c21c0a9..13288e9e9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -102,15 +102,15 @@ optional-dependencies.docs = [ "sphinxcontrib-bibtex>=2.3", "sphinxcontrib-spelling>=7.6.2", ] -optional-dependencies.leiden = [ - "leidenalg", - "spatialleiden>=0.4", +optional-dependencies.gpu-cuda11 = [ + "rapids-singlecell[rapids11]", ] optional-dependencies.gpu-cuda12 = [ "rapids-singlecell[rapids12]", ] -optional-dependencies.gpu-cuda11 = [ - "rapids-singlecell[rapids11]", +optional-dependencies.leiden = [ + "leidenalg", + "spatialleiden>=0.4", ] optional-dependencies.test = [ "coverage[toml]>=7", diff --git a/src/squidpy/_utils.py b/src/squidpy/_utils.py index 0ffb659f0..e00a01a15 100644 --- a/src/squidpy/_utils.py +++ 
b/src/squidpy/_utils.py @@ -18,6 +18,7 @@ import numpy as np import xarray as xr from spatialdata.models import Image2DModel, Labels2DModel + from squidpy.settings import DeviceType __all__ = ["singledispatchmethod", "Signal", "SigQueue", "NDArray", "NDArrayA"] @@ -417,9 +418,7 @@ def resolve_device_arg(device: DeviceType | None) -> Literal["cpu", "gpu"]: return "cpu" if device == "gpu": if not settings.gpu_available(): - raise RuntimeError( - "GPU unavailable. Install with: pip install squidpy[gpu-cuda12] or squidpy[gpu-cuda11]" - ) + raise RuntimeError("GPU unavailable. Install with: pip install squidpy[gpu-cuda12] or squidpy[gpu-cuda11]") return "gpu" # if device == "auto" return "gpu" if settings.gpu_available() else "cpu" From a0bfcd479a8fb7d37e0f6731f25acb74c230309a Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Wed, 14 Jan 2026 13:00:45 +0100 Subject: [PATCH 03/68] add cooccurance demo --- src/squidpy/gr/_ppatterns.py | 22 +++++++++++++ tests/test_gpu.py | 64 ++++++++++++++++++++++++++++++++++++ tests/test_settings.py | 35 ++++++++++++++++++++ 3 files changed, 121 insertions(+) create mode 100644 tests/test_gpu.py create mode 100644 tests/test_settings.py diff --git a/src/squidpy/gr/_ppatterns.py b/src/squidpy/gr/_ppatterns.py index 292c75994..957d31b12 100644 --- a/src/squidpy/gr/_ppatterns.py +++ b/src/squidpy/gr/_ppatterns.py @@ -352,6 +352,7 @@ def co_occurrence( n_jobs: int | None = None, backend: str = "loky", show_progress_bar: bool = True, + device: Literal["cpu", "gpu"] | None = None, ) -> tuple[NDArrayA, NDArrayA] | None: """ Compute co-occurrence probability of clusters. @@ -369,6 +370,9 @@ def co_occurrence( Number of splits in which to divide the spatial coordinates in :attr:`anndata.AnnData.obsm` ``['{spatial_key}']``. %(parallelize)s + device + Device to use for computation. If ``None``, uses :attr:`squidpy.settings.device`. + Set to ``"gpu"`` to use rapids-singlecell GPU acceleration. 
Returns ------- @@ -381,6 +385,24 @@ def co_occurrence( - :attr:`anndata.AnnData.uns` ``['{cluster_key}_co_occurrence']['interval']`` - the distance thresholds computed at ``interval``. """ + from squidpy._utils import resolve_device_arg + + effective_device = resolve_device_arg(device) + + if effective_device == "gpu": + from rapids_singlecell.squidpy import co_occurrence as rsc_co_occurrence + + return rsc_co_occurrence( + adata, + cluster_key=cluster_key, + spatial_key=spatial_key, + interval=interval, + copy=copy, + n_splits=n_splits, + n_jobs=n_jobs, + backend=backend, + show_progress_bar=show_progress_bar, + ) if isinstance(adata, SpatialData): adata = adata.table diff --git a/tests/test_gpu.py b/tests/test_gpu.py new file mode 100644 index 000000000..906e6eb87 --- /dev/null +++ b/tests/test_gpu.py @@ -0,0 +1,64 @@ +"""Tests for GPU functionality (skipped in CI without GPU).""" + +from __future__ import annotations + +import pytest + +from squidpy.settings import settings + + +# Skip all tests in this module if GPU is not available +pytestmark = pytest.mark.skipif( + not settings.gpu_available(), + reason="GPU tests require rapids-singlecell to be installed", +) + + +class TestGPUCoOccurrence: + """Test GPU-accelerated co_occurrence function.""" + + def test_co_occurrence_gpu(self, adata): + """Test co_occurrence with GPU device.""" + import squidpy as sq + + # Run with explicit GPU device + result = sq.gr.co_occurrence( + adata, + cluster_key="leiden", + copy=True, + device="gpu", + ) + + assert result is not None + arr, interval = result + assert arr.ndim == 3 + assert arr.shape[1] == arr.shape[0] == adata.obs["leiden"].unique().shape[0] + + def test_co_occurrence_gpu_vs_cpu(self, adata): + """Test that GPU and CPU results are approximately equal.""" + import numpy as np + + import squidpy as sq + + # Run on CPU + cpu_result = sq.gr.co_occurrence( + adata, + cluster_key="leiden", + copy=True, + device="cpu", + ) + + # Run on GPU + gpu_result = 
sq.gr.co_occurrence( + adata, + cluster_key="leiden", + copy=True, + device="gpu", + ) + + cpu_arr, cpu_interval = cpu_result + gpu_arr, gpu_interval = gpu_result + + # Results should be close (allow for floating point differences) + np.testing.assert_allclose(cpu_interval, gpu_interval, rtol=1e-5) + np.testing.assert_allclose(cpu_arr, gpu_arr, rtol=1e-5) diff --git a/tests/test_settings.py b/tests/test_settings.py new file mode 100644 index 000000000..4d20a7cc9 --- /dev/null +++ b/tests/test_settings.py @@ -0,0 +1,35 @@ +"""Tests for squidpy.settings module.""" + +from __future__ import annotations + +import pytest + +from squidpy.settings import DeviceType, settings + + +class TestSettings: + """Test the settings module.""" + + def test_default_device(self): + """Test that default device is 'auto'.""" + # Reset to default + settings.device = "auto" + assert settings.device == "auto" + + def test_set_device_cpu(self): + """Test setting device to 'cpu'.""" + settings.device = "cpu" + assert settings.device == "cpu" + settings.device = "auto" # reset + + def test_set_device_invalid(self): + """Test that invalid device raises ValueError.""" + with pytest.raises(ValueError, match="Invalid device"): + settings.device = "invalid" + + def test_set_device_gpu_without_rsc(self): + """Test that setting device to 'gpu' without rapids-singlecell raises RuntimeError.""" + # This will fail if rapids-singlecell is not installed + if not settings.gpu_available(): + with pytest.raises(RuntimeError, match="rapids-singlecell not installed"): + settings.device = "gpu" From 5c67afd3ec85e3114c9ae9eba7dc2ebcf4a5bea3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 14 Jan 2026 12:01:39 +0000 Subject: [PATCH 04/68] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_gpu.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git 
a/tests/test_gpu.py b/tests/test_gpu.py index 906e6eb87..a8b3d2451 100644 --- a/tests/test_gpu.py +++ b/tests/test_gpu.py @@ -6,7 +6,6 @@ from squidpy.settings import settings - # Skip all tests in this module if GPU is not available pytestmark = pytest.mark.skipif( not settings.gpu_available(), @@ -28,7 +27,7 @@ def test_co_occurrence_gpu(self, adata): copy=True, device="gpu", ) - + assert result is not None arr, interval = result assert arr.ndim == 3 @@ -47,7 +46,7 @@ def test_co_occurrence_gpu_vs_cpu(self, adata): copy=True, device="cpu", ) - + # Run on GPU gpu_result = sq.gr.co_occurrence( adata, @@ -55,10 +54,10 @@ def test_co_occurrence_gpu_vs_cpu(self, adata): copy=True, device="gpu", ) - + cpu_arr, cpu_interval = cpu_result gpu_arr, gpu_interval = gpu_result - + # Results should be close (allow for floating point differences) np.testing.assert_allclose(cpu_interval, gpu_interval, rtol=1e-5) np.testing.assert_allclose(cpu_arr, gpu_arr, rtol=1e-5) From 73c1b15d1745b2e17b7286d62dc9cb12d1279862 Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Wed, 14 Jan 2026 13:04:25 +0100 Subject: [PATCH 05/68] move import location --- src/squidpy/gr/_ppatterns.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/squidpy/gr/_ppatterns.py b/src/squidpy/gr/_ppatterns.py index 957d31b12..3c9b025bc 100644 --- a/src/squidpy/gr/_ppatterns.py +++ b/src/squidpy/gr/_ppatterns.py @@ -23,7 +23,7 @@ from squidpy._constants._constants import SpatialAutocorr from squidpy._constants._pkg_constants import Key from squidpy._docs import d, inject_docs -from squidpy._utils import NDArrayA, Signal, SigQueue, _get_n_cores, parallelize +from squidpy._utils import NDArrayA, Signal, SigQueue, _get_n_cores, parallelize, resolve_device_arg from squidpy.gr._utils import ( _assert_categorical_obs, _assert_connectivity_key, @@ -385,8 +385,6 @@ def co_occurrence( - :attr:`anndata.AnnData.uns` ``['{cluster_key}_co_occurrence']['interval']`` - the distance thresholds computed at 
``interval``. """ - from squidpy._utils import resolve_device_arg - effective_device = resolve_device_arg(device) if effective_device == "gpu": From 2474ac576759863430ec8fa11d703c3abb3b1aa6 Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Wed, 14 Jan 2026 13:06:47 +0100 Subject: [PATCH 06/68] remove unused DeviceType --- tests/test_settings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_settings.py b/tests/test_settings.py index 4d20a7cc9..db1cfd9f9 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -4,7 +4,7 @@ import pytest -from squidpy.settings import DeviceType, settings +from squidpy.settings import settings class TestSettings: From 464ae3c7bd652b8b8d0deeb2343cfe32cc7e6f97 Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Mon, 26 Jan 2026 10:05:40 +0100 Subject: [PATCH 07/68] save wip to test in the clusters --- src/squidpy/_utils.py | 77 ++++++++++++++++++++++++++++++ src/squidpy/gr/_ppatterns.py | 20 +------- tests/test_settings.py | 92 ++++++++++++++++++++++++++++++++++++ 3 files changed, 171 insertions(+), 18 deletions(-) diff --git a/src/squidpy/_utils.py b/src/squidpy/_utils.py index e00a01a15..c9747484d 100644 --- a/src/squidpy/_utils.py +++ b/src/squidpy/_utils.py @@ -422,3 +422,80 @@ def resolve_device_arg(device: DeviceType | None) -> Literal["cpu", "gpu"]: return "gpu" # if device == "auto" return "gpu" if settings.gpu_available() else "cpu" + + +from typing import TypeVar + +F = TypeVar("F", bound=Callable[..., Any]) + + +def gpu_dispatch(rapids_module: str, rapids_func_name: str | None = None) -> Callable[[F], F]: + """ + Decorator to dispatch to rapids-singlecell GPU implementation when device="gpu". + + Automatically: + 1. Resolves effective device (from arg or settings) + 2. If GPU: imports rapids function, filters kwargs to match its signature, calls it + 3. If CPU: proceeds with the original squidpy implementation + + Parameters + ---------- + rapids_module + Module path, e.g. 
"rapids_singlecell.squidpy" or "rapids_singlecell.gr" + rapids_func_name + Function name in the rapids module. If None, uses the decorated function's name. + + Returns + ------- + Callable + Decorated function that dispatches to GPU or CPU implementation. + + Examples + -------- + >>> @gpu_dispatch("rapids_singlecell.squidpy") + ... def co_occurrence(adata, cluster_key, *, device=None, n_jobs=None, ...): + ... # CPU implementation + ... ... + """ + + def decorator(func: F) -> F: + @functools.wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + import importlib + + # Bind arguments to get a complete view of all parameters + sig = inspect.signature(func) + try: + bound = sig.bind(*args, **kwargs) + except TypeError: + # If binding fails, fall back to original function to get proper error + return func(*args, **kwargs) + + bound.apply_defaults() + all_args = dict(bound.arguments) + + # Extract and resolve device + device = all_args.pop("device", None) + effective_device = resolve_device_arg(device) + + if effective_device == "gpu": + # Import rapids function + module = importlib.import_module(rapids_module) + func_name = rapids_func_name if rapids_func_name is not None else func.__name__ + rapids_func = getattr(module, func_name) + + # Get rapids function's accepted parameters + rapids_sig = inspect.signature(rapids_func) + rapids_params = set(rapids_sig.parameters.keys()) + + # Filter to only parameters accepted by rapids function + filtered_args = {k: v for k, v in all_args.items() if k in rapids_params} + + return rapids_func(**filtered_args) + + # CPU path: call original function without device + return func(**all_args) + + return wrapper # type: ignore[return-value] + + return decorator diff --git a/src/squidpy/gr/_ppatterns.py b/src/squidpy/gr/_ppatterns.py index 3c9b025bc..0a040e047 100644 --- a/src/squidpy/gr/_ppatterns.py +++ b/src/squidpy/gr/_ppatterns.py @@ -23,7 +23,7 @@ from squidpy._constants._constants import SpatialAutocorr from 
squidpy._constants._pkg_constants import Key from squidpy._docs import d, inject_docs -from squidpy._utils import NDArrayA, Signal, SigQueue, _get_n_cores, parallelize, resolve_device_arg +from squidpy._utils import NDArrayA, Signal, SigQueue, _get_n_cores, gpu_dispatch, parallelize from squidpy.gr._utils import ( _assert_categorical_obs, _assert_connectivity_key, @@ -342,6 +342,7 @@ def _co_occurrence_helper(v_x: NDArrayA, v_y: NDArrayA, v_radium: NDArrayA, labs @d.dedent +@gpu_dispatch("rapids_singlecell.squidpy") def co_occurrence( adata: AnnData | SpatialData, cluster_key: str, @@ -385,23 +386,6 @@ def co_occurrence( - :attr:`anndata.AnnData.uns` ``['{cluster_key}_co_occurrence']['interval']`` - the distance thresholds computed at ``interval``. """ - effective_device = resolve_device_arg(device) - - if effective_device == "gpu": - from rapids_singlecell.squidpy import co_occurrence as rsc_co_occurrence - - return rsc_co_occurrence( - adata, - cluster_key=cluster_key, - spatial_key=spatial_key, - interval=interval, - copy=copy, - n_splits=n_splits, - n_jobs=n_jobs, - backend=backend, - show_progress_bar=show_progress_bar, - ) - if isinstance(adata, SpatialData): adata = adata.table _assert_categorical_obs(adata, key=cluster_key) diff --git a/tests/test_settings.py b/tests/test_settings.py index db1cfd9f9..f0b2e2657 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -2,8 +2,11 @@ from __future__ import annotations +from unittest.mock import MagicMock, patch + import pytest +from squidpy._utils import gpu_dispatch from squidpy.settings import settings @@ -33,3 +36,92 @@ def test_set_device_gpu_without_rsc(self): if not settings.gpu_available(): with pytest.raises(RuntimeError, match="rapids-singlecell not installed"): settings.device = "gpu" + + +class TestGpuDispatch: + """Test the gpu_dispatch decorator.""" + + def test_cpu_path_calls_original(self): + """Test that CPU device calls the original function.""" + original_called = [] + + 
@gpu_dispatch("fake_rapids_module") + def my_func(x, y, *, n_jobs=1, device=None): + original_called.append((x, y, n_jobs)) + return x + y + + result = my_func(1, 2, device="cpu") + assert result == 3 + assert original_called == [(1, 2, 1)] + + def test_cpu_path_with_auto_device_no_gpu(self): + """Test that auto device falls back to CPU when GPU unavailable.""" + original_called = [] + + @gpu_dispatch("fake_rapids_module") + def my_func(x, device=None): + original_called.append(x) + return x * 2 + + # With auto and no GPU available, should call original + if not settings.gpu_available(): + result = my_func(5, device="auto") + assert result == 10 + assert original_called == [5] + + def test_gpu_path_filters_parameters(self): + """Test that GPU dispatch filters out parameters not in rapids signature.""" + mock_rapids_func = MagicMock(return_value="gpu_result") + + # Create a mock module + mock_module = MagicMock() + mock_module.my_func = mock_rapids_func + + @gpu_dispatch("mock_rapids") + def my_func(adata, cluster_key, *, n_jobs=1, backend="loky", device=None): + return "cpu_result" + + with patch("importlib.import_module", return_value=mock_module): + with patch("squidpy._utils.resolve_device_arg", return_value="gpu"): + # Mock the rapids function signature to only accept adata and cluster_key + import inspect + + mock_sig = inspect.signature(lambda adata, cluster_key: None) + with patch("inspect.signature", side_effect=lambda f: mock_sig if f == mock_rapids_func else inspect.signature(f)): + result = my_func("adata_obj", "leiden", n_jobs=4, backend="threading", device="gpu") + + assert result == "gpu_result" + # Should only be called with adata and cluster_key, not n_jobs or backend + mock_rapids_func.assert_called_once_with(adata="adata_obj", cluster_key="leiden") + + def test_preserves_function_metadata(self): + """Test that the decorator preserves function name and docstring.""" + + @gpu_dispatch("fake_module") + def documented_func(x, device=None): + """This 
is the docstring.""" + return x + + assert documented_func.__name__ == "documented_func" + assert documented_func.__doc__ == """This is the docstring.""" + + def test_custom_rapids_func_name(self): + """Test using a custom rapids function name.""" + mock_rapids_func = MagicMock(return_value="rapids_result") + mock_module = MagicMock() + mock_module.different_name = mock_rapids_func + + @gpu_dispatch("mock_rapids", rapids_func_name="different_name") + def my_func(x, device=None): + return "cpu_result" + + with patch("importlib.import_module", return_value=mock_module): + with patch("squidpy._utils.resolve_device_arg", return_value="gpu"): + import inspect + + mock_sig = inspect.signature(lambda x: None) + with patch("inspect.signature", side_effect=lambda f: mock_sig if f == mock_rapids_func else inspect.signature(f)): + result = my_func(42, device="gpu") + + assert result == "rapids_result" + mock_rapids_func.assert_called_once_with(x=42) From 037b6a8c3271775152b228ed6bfa3785265fdf57 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 26 Jan 2026 09:05:55 +0000 Subject: [PATCH 08/68] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_settings.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/test_settings.py b/tests/test_settings.py index f0b2e2657..fe006236f 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -87,7 +87,10 @@ def my_func(adata, cluster_key, *, n_jobs=1, backend="loky", device=None): import inspect mock_sig = inspect.signature(lambda adata, cluster_key: None) - with patch("inspect.signature", side_effect=lambda f: mock_sig if f == mock_rapids_func else inspect.signature(f)): + with patch( + "inspect.signature", + side_effect=lambda f: mock_sig if f == mock_rapids_func else inspect.signature(f), + ): result = my_func("adata_obj", "leiden", n_jobs=4, backend="threading", 
device="gpu") assert result == "gpu_result" @@ -120,7 +123,10 @@ def my_func(x, device=None): import inspect mock_sig = inspect.signature(lambda x: None) - with patch("inspect.signature", side_effect=lambda f: mock_sig if f == mock_rapids_func else inspect.signature(f)): + with patch( + "inspect.signature", + side_effect=lambda f: mock_sig if f == mock_rapids_func else inspect.signature(f), + ): result = my_func(42, device="gpu") assert result == "rapids_result" From 0820a0e578141bf82ab38420b94d9babc9d439be Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Mon, 26 Jan 2026 11:33:31 +0100 Subject: [PATCH 09/68] adjust deps but need prerelease of spatialdata for unpining dask --- pyproject.toml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 04410c3aa..73e51344c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,8 +46,8 @@ dependencies = [ "aiohttp>=3.8.1", "anndata>=0.9", "cycler>=0.11", - "dask[array]>=2021.2,<=2024.11.2", - "dask-image>=0.5", + "dask[array]>=2021.2", + "dask-image>=2024.5", "docrep>=0.3.1", "fast-array-utils", "fsspec>=2021.11", @@ -104,10 +104,10 @@ optional-dependencies.docs = [ "sphinxcontrib-spelling>=7.6.2", ] optional-dependencies.gpu-cuda11 = [ - "rapids-singlecell[rapids11]", + "rapids-singlecell[rapids11]>=0.13.5", ] optional-dependencies.gpu-cuda12 = [ - "rapids-singlecell[rapids12]", + "rapids-singlecell[rapids12]>=0.13.5", ] optional-dependencies.leiden = [ "leidenalg", From 70af5c70d45efe6c52680fcb6a330822858d02c2 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Mon, 26 Jan 2026 11:35:06 +0100 Subject: [PATCH 10/68] fix the set_default_colors_for_categorical_obs --- src/squidpy/_compat.py | 2 +- src/squidpy/_utils.py | 4 ++-- src/squidpy/gr/_ppatterns.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/squidpy/_compat.py b/src/squidpy/_compat.py index 66fa5e605..e56040310 100644 --- a/src/squidpy/_compat.py +++ b/src/squidpy/_compat.py @@ 
-20,7 +20,7 @@ ] # See https://github.com/scverse/squidpy/issues/1061 for more details -_SET_DEFAULT_COLORS_FOR_CATEGORICAL_OBS_CHANGED = Version(version("scanpy")) >= Version("0.12.0rc1") +_SET_DEFAULT_COLORS_FOR_CATEGORICAL_OBS_CHANGED = Version("0.12.0rc1") <= Version(version("scanpy")) < Version("0.12.0") if _SET_DEFAULT_COLORS_FOR_CATEGORICAL_OBS_CHANGED: from scanpy.plotting._utils import _set_default_colors_for_categorical_obs as set_default_colors_for_categorical_obs diff --git a/src/squidpy/_utils.py b/src/squidpy/_utils.py index c9747484d..80c57aa3f 100644 --- a/src/squidpy/_utils.py +++ b/src/squidpy/_utils.py @@ -441,7 +441,7 @@ def gpu_dispatch(rapids_module: str, rapids_func_name: str | None = None) -> Cal Parameters ---------- rapids_module - Module path, e.g. "rapids_singlecell.squidpy" or "rapids_singlecell.gr" + Module path, e.g. "rapids_singlecell.squidpy_gpu" or "rapids_singlecell.gr" rapids_func_name Function name in the rapids module. If None, uses the decorated function's name. @@ -452,7 +452,7 @@ def gpu_dispatch(rapids_module: str, rapids_func_name: str | None = None) -> Cal Examples -------- - >>> @gpu_dispatch("rapids_singlecell.squidpy") + >>> @gpu_dispatch("rapids_singlecell.squidpy_gpu") ... def co_occurrence(adata, cluster_key, *, device=None, n_jobs=None, ...): ... # CPU implementation ... ... 
diff --git a/src/squidpy/gr/_ppatterns.py b/src/squidpy/gr/_ppatterns.py index 0a040e047..c9e4b052c 100644 --- a/src/squidpy/gr/_ppatterns.py +++ b/src/squidpy/gr/_ppatterns.py @@ -342,7 +342,7 @@ def _co_occurrence_helper(v_x: NDArrayA, v_y: NDArrayA, v_radium: NDArrayA, labs @d.dedent -@gpu_dispatch("rapids_singlecell.squidpy") +@gpu_dispatch("rapids_singlecell.squidpy_gpu") def co_occurrence( adata: AnnData | SpatialData, cluster_key: str, From 74c70d9bb1f1502167fb636bd89c3e06f498bf4a Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Mon, 26 Jan 2026 11:43:25 +0100 Subject: [PATCH 11/68] add ligrec and cooccurence --- src/squidpy/gr/_ligrec.py | 7 +- src/squidpy/gr/_ppatterns.py | 5 ++ tests/test_gpu.py | 142 +++++++++++++++++++++++++++++++++++ 3 files changed, 153 insertions(+), 1 deletion(-) diff --git a/src/squidpy/gr/_ligrec.py b/src/squidpy/gr/_ligrec.py index f369759b1..22b71ab19 100644 --- a/src/squidpy/gr/_ligrec.py +++ b/src/squidpy/gr/_ligrec.py @@ -20,7 +20,7 @@ from squidpy._constants._constants import ComplexPolicy, CorrAxis from squidpy._constants._pkg_constants import Key from squidpy._docs import d, inject_docs -from squidpy._utils import NDArrayA, Signal, SigQueue, _get_n_cores, parallelize +from squidpy._utils import NDArrayA, Signal, SigQueue, _get_n_cores, gpu_dispatch, parallelize from squidpy.gr._utils import ( _assert_categorical_obs, _assert_positive, @@ -633,6 +633,7 @@ def prepare( @d.dedent +@gpu_dispatch("rapids_singlecell.squidpy_gpu") def ligrec( adata: AnnData | SpatialData, cluster_key: str, @@ -645,6 +646,7 @@ def ligrec( copy: bool = False, key_added: str | None = None, gene_symbols: str | None = None, + device: Literal["cpu", "gpu"] | None = None, **kwargs: Any, ) -> Mapping[str, pd.DataFrame] | None: """ @@ -657,6 +659,9 @@ def ligrec( %(PT_test.parameters)s gene_symbols Key in :attr:`anndata.AnnData.var` to use instead of :attr:`anndata.AnnData.var_names`. + device + Device to use for computation. 
If ``None``, uses :attr:`squidpy.settings.device`. + Set to ``"gpu"`` to use rapids-singlecell GPU acceleration. Returns ------- diff --git a/src/squidpy/gr/_ppatterns.py b/src/squidpy/gr/_ppatterns.py index c9e4b052c..603997b43 100644 --- a/src/squidpy/gr/_ppatterns.py +++ b/src/squidpy/gr/_ppatterns.py @@ -45,6 +45,7 @@ @d.dedent @inject_docs(key=Key.obsp.spatial_conn(), sp=SpatialAutocorr) +@gpu_dispatch("rapids_singlecell.squidpy_gpu") def spatial_autocorr( adata: AnnData | SpatialData, connectivity_key: str = Key.obsp.spatial_conn(), @@ -62,6 +63,7 @@ def spatial_autocorr( n_jobs: int | None = None, backend: str = "loky", show_progress_bar: bool = True, + device: Literal["cpu", "gpu"] | None = None, ) -> pd.DataFrame | None: """ Calculate Global Autocorrelation Statistic (Moran’s I or Geary's C). @@ -107,6 +109,9 @@ def spatial_autocorr( %(seed)s %(copy)s %(parallelize)s + device + Device to use for computation. If ``None``, uses :attr:`squidpy.settings.device`. + Set to ``"gpu"`` to use rapids-singlecell GPU acceleration. 
Returns ------- diff --git a/tests/test_gpu.py b/tests/test_gpu.py index a8b3d2451..19a113e8f 100644 --- a/tests/test_gpu.py +++ b/tests/test_gpu.py @@ -61,3 +61,145 @@ def test_co_occurrence_gpu_vs_cpu(self, adata): # Results should be close (allow for floating point differences) np.testing.assert_allclose(cpu_interval, gpu_interval, rtol=1e-5) np.testing.assert_allclose(cpu_arr, gpu_arr, rtol=1e-5) + + +class TestGPUSpatialAutocorr: + """Test GPU-accelerated spatial_autocorr function.""" + + def test_spatial_autocorr_gpu(self, adata): + """Test spatial_autocorr with GPU device.""" + import squidpy as sq + + # Ensure spatial neighbors are computed + sq.gr.spatial_neighbors(adata) + + # Run with explicit GPU device + result = sq.gr.spatial_autocorr( + adata, + mode="moran", + copy=True, + device="gpu", + ) + + assert result is not None + assert "I" in result.columns + assert "pval_norm" in result.columns + + def test_spatial_autocorr_gpu_vs_cpu(self, adata): + """Test that GPU and CPU results are approximately equal.""" + import numpy as np + + import squidpy as sq + + # Ensure spatial neighbors are computed + sq.gr.spatial_neighbors(adata) + + # Run on CPU + cpu_result = sq.gr.spatial_autocorr( + adata, + mode="moran", + copy=True, + device="cpu", + ) + + # Run on GPU + gpu_result = sq.gr.spatial_autocorr( + adata, + mode="moran", + copy=True, + device="gpu", + ) + + # Results should be close (allow for floating point differences) + # Use equal_nan=True since some genes may have NaN values, and relax rtol for float32/float64 differences + np.testing.assert_allclose(cpu_result["I"].values, gpu_result["I"].values, rtol=1e-3, equal_nan=True) + + +class TestGPULigrec: + """Test GPU-accelerated ligrec function.""" + + def test_ligrec_gpu(self, adata): + """Test ligrec with GPU device.""" + import squidpy as sq + + # Run with explicit GPU device + result = sq.gr.ligrec( + adata, + cluster_key="leiden", + copy=True, + device="gpu", + ) + + assert result is not None + 
assert "means" in result + assert "pvalues" in result + + +class TestGPUSettingsOptIn: + """Test settings-based GPU opt-in functionality.""" + + def test_settings_device_gpu(self, adata): + """Test that setting device='gpu' globally uses GPU for all functions.""" + import squidpy as sq + from squidpy.settings import settings + + # Save original setting + original_device = settings.device + + try: + # Opt-in to GPU globally + settings.device = "gpu" + + # Run without explicit device - should use GPU + result = sq.gr.co_occurrence( + adata, + cluster_key="leiden", + copy=True, + ) + + assert result is not None + finally: + # Restore original setting + settings.device = original_device + + def test_settings_device_auto(self, adata): + """Test that device='auto' uses GPU when available.""" + import squidpy as sq + from squidpy.settings import settings + + # auto is the default, and should use GPU when available + assert settings.device == "auto" + + # Run without explicit device - should automatically use GPU + result = sq.gr.co_occurrence( + adata, + cluster_key="leiden", + copy=True, + ) + + assert result is not None + + def test_explicit_device_overrides_settings(self, adata): + """Test that explicit device parameter overrides global settings.""" + import squidpy as sq + from squidpy.settings import settings + + # Save original setting + original_device = settings.device + + try: + # Set global to CPU + settings.device = "cpu" + + # But explicitly request GPU - should use GPU + result = sq.gr.co_occurrence( + adata, + cluster_key="leiden", + copy=True, + device="gpu", + ) + + assert result is not None + finally: + # Restore original setting + settings.device = original_device From a1d2d1076a6c16755a24906c0f76cc7d1eb43680 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Mon, 26 Jan 2026 11:52:39 +0100 Subject: [PATCH 12/68] reduce bloat code --- src/squidpy/gr/_ligrec.py | 3 +- src/squidpy/gr/_ppatterns.py | 6 +- tests/test_gpu.py | 158 
+++-------------------------------- 3 files changed, 14 insertions(+), 153 deletions(-) diff --git a/src/squidpy/gr/_ligrec.py b/src/squidpy/gr/_ligrec.py index 22b71ab19..dabf8c5b1 100644 --- a/src/squidpy/gr/_ligrec.py +++ b/src/squidpy/gr/_ligrec.py @@ -660,8 +660,7 @@ def ligrec( gene_symbols Key in :attr:`anndata.AnnData.var` to use instead of :attr:`anndata.AnnData.var_names`. device - Device to use for computation. If ``None``, uses :attr:`squidpy.settings.device`. - Set to ``"gpu"`` to use rapids-singlecell GPU acceleration. + Device for computation: ``"cpu"``, ``"gpu"``, or ``None`` (use ``settings.device``). Returns ------- diff --git a/src/squidpy/gr/_ppatterns.py b/src/squidpy/gr/_ppatterns.py index 603997b43..e0d408921 100644 --- a/src/squidpy/gr/_ppatterns.py +++ b/src/squidpy/gr/_ppatterns.py @@ -110,8 +110,7 @@ def spatial_autocorr( %(copy)s %(parallelize)s device - Device to use for computation. If ``None``, uses :attr:`squidpy.settings.device`. - Set to ``"gpu"`` to use rapids-singlecell GPU acceleration. + Device for computation: ``"cpu"``, ``"gpu"``, or ``None`` (use ``settings.device``). Returns ------- @@ -377,8 +376,7 @@ def co_occurrence( :attr:`anndata.AnnData.obsm` ``['{spatial_key}']``. %(parallelize)s device - Device to use for computation. If ``None``, uses :attr:`squidpy.settings.device`. - Set to ``"gpu"`` to use rapids-singlecell GPU acceleration. + Device for computation: ``"cpu"``, ``"gpu"``, or ``None`` (use ``settings.device``). 
Returns ------- diff --git a/tests/test_gpu.py b/tests/test_gpu.py index 19a113e8f..7afc866aa 100644 --- a/tests/test_gpu.py +++ b/tests/test_gpu.py @@ -2,8 +2,10 @@ from __future__ import annotations +import numpy as np import pytest +import squidpy as sq from squidpy.settings import settings # Skip all tests in this module if GPU is not available @@ -18,47 +20,18 @@ class TestGPUCoOccurrence: def test_co_occurrence_gpu(self, adata): """Test co_occurrence with GPU device.""" - import squidpy as sq - - # Run with explicit GPU device - result = sq.gr.co_occurrence( - adata, - cluster_key="leiden", - copy=True, - device="gpu", - ) + result = sq.gr.co_occurrence(adata, cluster_key="leiden", copy=True, device="gpu") assert result is not None arr, interval = result assert arr.ndim == 3 - assert arr.shape[1] == arr.shape[0] == adata.obs["leiden"].unique().shape[0] + assert arr.shape[1] == arr.shape[0] == adata.obs["leiden"].nunique() def test_co_occurrence_gpu_vs_cpu(self, adata): """Test that GPU and CPU results are approximately equal.""" - import numpy as np - - import squidpy as sq - - # Run on CPU - cpu_result = sq.gr.co_occurrence( - adata, - cluster_key="leiden", - copy=True, - device="cpu", - ) - - # Run on GPU - gpu_result = sq.gr.co_occurrence( - adata, - cluster_key="leiden", - copy=True, - device="gpu", - ) - - cpu_arr, cpu_interval = cpu_result - gpu_arr, gpu_interval = gpu_result - - # Results should be close (allow for floating point differences) + cpu_arr, cpu_interval = sq.gr.co_occurrence(adata, cluster_key="leiden", copy=True, device="cpu") + gpu_arr, gpu_interval = sq.gr.co_occurrence(adata, cluster_key="leiden", copy=True, device="gpu") + np.testing.assert_allclose(cpu_interval, gpu_interval, rtol=1e-5) np.testing.assert_allclose(cpu_arr, gpu_arr, rtol=1e-5) @@ -68,18 +41,8 @@ class TestGPUSpatialAutocorr: def test_spatial_autocorr_gpu(self, adata): """Test spatial_autocorr with GPU device.""" - import squidpy as sq - - # Ensure spatial neighbors are 
computed sq.gr.spatial_neighbors(adata) - - # Run with explicit GPU device - result = sq.gr.spatial_autocorr( - adata, - mode="moran", - copy=True, - device="gpu", - ) + result = sq.gr.spatial_autocorr(adata, mode="moran", copy=True, device="gpu") assert result is not None assert "I" in result.columns @@ -87,31 +50,10 @@ def test_spatial_autocorr_gpu(self, adata): def test_spatial_autocorr_gpu_vs_cpu(self, adata): """Test that GPU and CPU results are approximately equal.""" - import numpy as np - - import squidpy as sq - - # Ensure spatial neighbors are computed sq.gr.spatial_neighbors(adata) + cpu_result = sq.gr.spatial_autocorr(adata, mode="moran", copy=True, device="cpu") + gpu_result = sq.gr.spatial_autocorr(adata, mode="moran", copy=True, device="gpu") - # Run on CPU - cpu_result = sq.gr.spatial_autocorr( - adata, - mode="moran", - copy=True, - device="cpu", - ) - - # Run on GPU - gpu_result = sq.gr.spatial_autocorr( - adata, - mode="moran", - copy=True, - device="gpu", - ) - - # Results should be close (allow for floating point differences) - # Use equal_nan=True since some genes may have NaN values, and relax rtol for float32/float64 differences np.testing.assert_allclose(cpu_result["I"].values, gpu_result["I"].values, rtol=1e-3, equal_nan=True) @@ -120,86 +62,8 @@ class TestGPULigrec: def test_ligrec_gpu(self, adata): """Test ligrec with GPU device.""" - import squidpy as sq - - # Run with explicit GPU device - result = sq.gr.ligrec( - adata, - cluster_key="leiden", - copy=True, - device="gpu", - ) + result = sq.gr.ligrec(adata, cluster_key="leiden", copy=True, device="gpu") assert result is not None assert "means" in result assert "pvalues" in result - - -class TestGPUSettingsOptIn: - """Test settings-based GPU opt-in functionality.""" - - def test_settings_device_gpu(self, adata): - """Test that setting device='gpu' globally uses GPU for all functions.""" - import squidpy as sq - from squidpy.settings import settings - - # Save original setting - 
original_device = settings.device - - try: - # Opt-in to GPU globally - settings.device = "gpu" - - # Run without explicit device - should use GPU - result = sq.gr.co_occurrence( - adata, - cluster_key="leiden", - copy=True, - ) - - assert result is not None - finally: - # Restore original setting - settings.device = original_device - - def test_settings_device_auto(self, adata): - """Test that device='auto' uses GPU when available.""" - import squidpy as sq - from squidpy.settings import settings - - # auto is the default, and should use GPU when available - assert settings.device == "auto" - - # Run without explicit device - should automatically use GPU - result = sq.gr.co_occurrence( - adata, - cluster_key="leiden", - copy=True, - ) - - assert result is not None - - def test_explicit_device_overrides_settings(self, adata): - """Test that explicit device parameter overrides global settings.""" - import squidpy as sq - from squidpy.settings import settings - - # Save original setting - original_device = settings.device - - try: - # Set global to CPU - settings.device = "cpu" - - # But explicitly request GPU - should use GPU - result = sq.gr.co_occurrence( - adata, - cluster_key="leiden", - copy=True, - device="gpu", - ) - - assert result is not None - finally: - # Restore original setting - settings.device = original_device From a331134a889be90ade1a39c18ef489b42eeed181 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Mon, 26 Jan 2026 15:28:05 +0100 Subject: [PATCH 13/68] updates for the new dispatch approach --- src/squidpy/_docs.py | 17 ++ src/squidpy/_utils.py | 110 ------------- src/squidpy/gr/_gpu.py | 250 ++++++++++++++++++++++++++++++ src/squidpy/gr/_ligrec.py | 8 +- src/squidpy/gr/_ppatterns.py | 22 +-- src/squidpy/settings/__init__.py | 7 +- src/squidpy/settings/_dispatch.py | 62 ++++++++ src/squidpy/settings/_settings.py | 32 ++-- tests/test_settings.py | 73 ++++----- 9 files changed, 384 insertions(+), 197 deletions(-) create mode 100644 
src/squidpy/gr/_gpu.py create mode 100644 src/squidpy/settings/_dispatch.py diff --git a/src/squidpy/_docs.py b/src/squidpy/_docs.py index a4596df75..d931b6e19 100644 --- a/src/squidpy/_docs.py +++ b/src/squidpy/_docs.py @@ -119,6 +119,20 @@ def decorator2(obj: Any) -> Any: Parallelization backend to use. See :class:`joblib.Parallel` for available options. show_progress_bar Whether to show the progress bar or not.""" +_parallelize_device = """\ +n_jobs + Number of parallel jobs. Ignored when ``device='gpu'``. +backend + Parallelization backend. Ignored when ``device='gpu'``. +show_progress_bar + Whether to show the progress bar. Ignored when ``device='gpu'``.""" +_seed_device = """\ +seed + Random seed for reproducibility. Ignored when ``device='gpu'``.""" +_device = """\ +device + Device for computation: ``'cpu'``, ``'gpu'``, or ``None`` (use ``squidpy.settings.device``). + When ``'gpu'``, dispatches to :mod:`rapids_singlecell` for GPU-accelerated computation.""" _channels = """\ channels Channels for this feature is computed. 
If `None`, use all channels.""" @@ -379,6 +393,9 @@ def decorator2(obj: Any) -> Any: cat_plotting=_cat_plotting, plotting_returns=_plotting_returns, parallelize=_parallelize, + parallelize_device=_parallelize_device, + seed_device=_seed_device, + device=_device, channels=_channels, segment_kwargs=_segment_kwargs, ligrec_test_returns=_ligrec_test_returns, diff --git a/src/squidpy/_utils.py b/src/squidpy/_utils.py index 80c57aa3f..24d5765de 100644 --- a/src/squidpy/_utils.py +++ b/src/squidpy/_utils.py @@ -19,8 +19,6 @@ import xarray as xr from spatialdata.models import Image2DModel, Labels2DModel -from squidpy.settings import DeviceType - __all__ = ["singledispatchmethod", "Signal", "SigQueue", "NDArray", "NDArrayA"] @@ -391,111 +389,3 @@ def _ensure_dim_order(img_da: xr.DataArray, order: Literal["cyx", "yxc"] = "yxc" return img_da.transpose(*tuple(order)) -def resolve_device_arg(device: DeviceType | None) -> Literal["cpu", "gpu"]: - """ - Resolve per-call device argument to actual backend. - - Parameters - ---------- - device - Per-call device setting. None uses ``settings.device``. - - Returns - ------- - Literal["cpu", "gpu"] - The resolved backend to use. - - Raises - ------ - RuntimeError - If GPU is requested but rapids-singlecell is not installed. - """ - from squidpy.settings import settings - - if device is None: - device = settings.device - if device == "cpu": - return "cpu" - if device == "gpu": - if not settings.gpu_available(): - raise RuntimeError("GPU unavailable. Install with: pip install squidpy[gpu-cuda12] or squidpy[gpu-cuda11]") - return "gpu" - # if device == "auto" - return "gpu" if settings.gpu_available() else "cpu" - - -from typing import TypeVar - -F = TypeVar("F", bound=Callable[..., Any]) - - -def gpu_dispatch(rapids_module: str, rapids_func_name: str | None = None) -> Callable[[F], F]: - """ - Decorator to dispatch to rapids-singlecell GPU implementation when device="gpu". - - Automatically: - 1. 
Resolves effective device (from arg or settings) - 2. If GPU: imports rapids function, filters kwargs to match its signature, calls it - 3. If CPU: proceeds with the original squidpy implementation - - Parameters - ---------- - rapids_module - Module path, e.g. "rapids_singlecell.squidpy_gpu" or "rapids_singlecell.gr" - rapids_func_name - Function name in the rapids module. If None, uses the decorated function's name. - - Returns - ------- - Callable - Decorated function that dispatches to GPU or CPU implementation. - - Examples - -------- - >>> @gpu_dispatch("rapids_singlecell.squidpy_gpu") - ... def co_occurrence(adata, cluster_key, *, device=None, n_jobs=None, ...): - ... # CPU implementation - ... ... - """ - - def decorator(func: F) -> F: - @functools.wraps(func) - def wrapper(*args: Any, **kwargs: Any) -> Any: - import importlib - - # Bind arguments to get a complete view of all parameters - sig = inspect.signature(func) - try: - bound = sig.bind(*args, **kwargs) - except TypeError: - # If binding fails, fall back to original function to get proper error - return func(*args, **kwargs) - - bound.apply_defaults() - all_args = dict(bound.arguments) - - # Extract and resolve device - device = all_args.pop("device", None) - effective_device = resolve_device_arg(device) - - if effective_device == "gpu": - # Import rapids function - module = importlib.import_module(rapids_module) - func_name = rapids_func_name if rapids_func_name is not None else func.__name__ - rapids_func = getattr(module, func_name) - - # Get rapids function's accepted parameters - rapids_sig = inspect.signature(rapids_func) - rapids_params = set(rapids_sig.parameters.keys()) - - # Filter to only parameters accepted by rapids function - filtered_args = {k: v for k, v in all_args.items() if k in rapids_params} - - return rapids_func(**filtered_args) - - # CPU path: call original function without device - return func(**all_args) - - return wrapper # type: ignore[return-value] - - return decorator 
diff --git a/src/squidpy/gr/_gpu.py b/src/squidpy/gr/_gpu.py new file mode 100644 index 000000000..f1006d260 --- /dev/null +++ b/src/squidpy/gr/_gpu.py @@ -0,0 +1,250 @@ +"""GPU adapter functions for squidpy.gr functions. + +These stubs provide explicit parameter mapping between squidpy and rapids_singlecell, +ensuring compatibility and clear documentation of supported parameters. +""" + +from __future__ import annotations + +import warnings +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Literal, Sequence + +from squidpy._constants._pkg_constants import Key + +if TYPE_CHECKING: + from anndata import AnnData + from numpy.typing import NDArray + + +@dataclass +class GpuParamSpec: + """Specification for a parameter's GPU compatibility.""" + + default: Any + message: str | None = None + validator: Callable[[Any], str | None] | None = None + + +def _attr_validator(value: Any) -> str | None: + """Special validator for attr param - only warn if not 'X'.""" + if value == "X": + return None + return f"attr={value!r} is not supported on GPU, using attr='X'. Set device='cpu' to use other attributes." 
+ + +# Common CPU-only param specs (reusable) +_PARALLELIZE: dict[str, GpuParamSpec] = { + "n_jobs": GpuParamSpec(None), + "backend": GpuParamSpec("loky"), + "show_progress_bar": GpuParamSpec(True), +} +_SEED: dict[str, GpuParamSpec] = {"seed": GpuParamSpec(None)} + +# Registry: {func_name: {"cpu_only": {...}, "gpu_only": {...}}} +GPU_PARAM_REGISTRY: dict[str, dict[str, dict[str, GpuParamSpec]]] = { + "spatial_autocorr": { + "cpu_only": { + "attr": GpuParamSpec("X", validator=_attr_validator), + **_SEED, + **_PARALLELIZE, + }, + "gpu_only": { + "use_sparse": GpuParamSpec(True), + }, + }, + "co_occurrence": { + "cpu_only": { + "n_splits": GpuParamSpec(None), + **_PARALLELIZE, + }, + "gpu_only": {}, + }, + "ligrec": { + "cpu_only": { + "clusters": GpuParamSpec(None), + "numba_parallel": GpuParamSpec(None), + "transmitter_params": GpuParamSpec(None), + "receiver_params": GpuParamSpec(None), + "interactions_params": GpuParamSpec(None), + "alpha": GpuParamSpec(0.05), + **_SEED, + **_PARALLELIZE, + }, + "gpu_only": {}, + }, +} + + +@dataclass +class CheckResult: + """Result of parameter compatibility check.""" + + ignored: dict[str, Any] = field(default_factory=dict) + warnings: list[str] = field(default_factory=list) + gpu_defaults: dict[str, Any] = field(default_factory=dict) + + +def check_gpu_params(func_name: str, **cpu_only_values: Any) -> CheckResult: + """Check CPU-only params against registry, warn about non-defaults, return GPU defaults. + + Parameters + ---------- + func_name + Name of the function in GPU_PARAM_REGISTRY. + **cpu_only_values + CPU-only parameter values to check. 
+ """ + result = CheckResult() + registry = GPU_PARAM_REGISTRY.get(func_name, {"cpu_only": {}, "gpu_only": {}}) + + # Check CPU-only params + for name, spec in registry["cpu_only"].items(): + if name not in cpu_only_values: + continue + value = cpu_only_values[name] + + # Use custom validator if provided, else default behavior + if spec.validator: + msg = spec.validator(value) + elif value != spec.default: + msg = spec.message or f"{name}={value!r} is ignored on GPU." + else: + msg = None + + if msg: + msg = msg.format(name=name, value=value) + result.ignored[name] = value + result.warnings.append(msg) + warnings.warn(msg, UserWarning, stacklevel=3) + + # Collect GPU-only param defaults + for name, spec in registry["gpu_only"].items(): + result.gpu_defaults[name] = spec.default + + return result + + +def spatial_autocorr_gpu( + adata: AnnData, + connectivity_key: str = Key.obsp.spatial_conn(), + genes: str | int | Sequence[str] | Sequence[int] | None = None, + mode: Literal["moran", "geary"] = "moran", + transformation: bool = True, + n_perms: int | None = None, + two_tailed: bool = False, + corr_method: str | None = "fdr_bh", + layer: str | None = None, + use_raw: bool = False, + copy: bool = False, + # CPU-only params + attr: Literal["obs", "X", "obsm"] = "X", + seed: int | None = None, + n_jobs: int | None = None, + backend: str = "loky", + show_progress_bar: bool = True, +) -> Any: + """GPU adapter for spatial_autocorr via rapids_singlecell.""" + from rapids_singlecell.squidpy_gpu import spatial_autocorr as _spatial_autocorr_gpu + + check = check_gpu_params( + "spatial_autocorr", + attr=attr, seed=seed, n_jobs=n_jobs, backend=backend, show_progress_bar=show_progress_bar, + ) + + return _spatial_autocorr_gpu( + adata=adata, + connectivity_key=connectivity_key, + genes=genes, + mode=mode, + transformation=transformation, + n_perms=n_perms, + two_tailed=two_tailed, + corr_method=corr_method, + layer=layer, + use_raw=use_raw, + copy=copy, + **check.gpu_defaults, + 
) + + +def co_occurrence_gpu( + adata: AnnData, + cluster_key: str, + spatial_key: str = Key.obsm.spatial, + interval: int | NDArray[Any] = 50, + copy: bool = False, + # CPU-only params + n_splits: int | None = None, + n_jobs: int | None = None, + backend: str = "loky", + show_progress_bar: bool = True, +) -> Any: + """GPU adapter for co_occurrence via rapids_singlecell.""" + from rapids_singlecell.squidpy_gpu import co_occurrence as _co_occurrence_gpu + + check_gpu_params( + "co_occurrence", + n_splits=n_splits, n_jobs=n_jobs, backend=backend, show_progress_bar=show_progress_bar, + ) + + return _co_occurrence_gpu( + adata=adata, + cluster_key=cluster_key, + spatial_key=spatial_key, + interval=interval, + copy=copy, + ) + + +def ligrec_gpu( + adata: AnnData, + cluster_key: str, + interactions: Any = None, + complex_policy: Literal["min", "all"] = "min", + threshold: float = 0.01, + corr_method: str | None = None, + corr_axis: Literal["interactions", "clusters"] = "clusters", + use_raw: bool = True, + copy: bool = False, + key_added: str | None = None, + gene_symbols: str | None = None, + n_perms: int = 1000, + # CPU-only params + clusters: Any = None, + seed: int | None = None, + numba_parallel: bool | None = None, + n_jobs: int | None = None, + backend: str = "loky", + show_progress_bar: bool = True, + transmitter_params: dict[str, Any] | None = None, + receiver_params: dict[str, Any] | None = None, + interactions_params: dict[str, Any] | None = None, + alpha: float = 0.05, +) -> Any: + """GPU adapter for ligrec via rapids_singlecell.""" + from rapids_singlecell.squidpy_gpu import ligrec as _ligrec_gpu + + check_gpu_params( + "ligrec", + clusters=clusters, seed=seed, numba_parallel=numba_parallel, + n_jobs=n_jobs, backend=backend, show_progress_bar=show_progress_bar, + transmitter_params=transmitter_params, receiver_params=receiver_params, + interactions_params=interactions_params, alpha=alpha, + ) + + return _ligrec_gpu( + adata=adata, + cluster_key=cluster_key, 
+ interactions=interactions, + complex_policy=complex_policy, + threshold=threshold, + corr_method=corr_method, + corr_axis=corr_axis, + use_raw=use_raw, + copy=copy, + key_added=key_added, + gene_symbols=gene_symbols, + n_perms=n_perms, + ) diff --git a/src/squidpy/gr/_ligrec.py b/src/squidpy/gr/_ligrec.py index dabf8c5b1..ed610fa20 100644 --- a/src/squidpy/gr/_ligrec.py +++ b/src/squidpy/gr/_ligrec.py @@ -20,7 +20,8 @@ from squidpy._constants._constants import ComplexPolicy, CorrAxis from squidpy._constants._pkg_constants import Key from squidpy._docs import d, inject_docs -from squidpy._utils import NDArrayA, Signal, SigQueue, _get_n_cores, gpu_dispatch, parallelize +from squidpy._utils import NDArrayA, Signal, SigQueue, _get_n_cores, parallelize +from squidpy.settings import gpu_dispatch from squidpy.gr._utils import ( _assert_categorical_obs, _assert_positive, @@ -633,7 +634,7 @@ def prepare( @d.dedent -@gpu_dispatch("rapids_singlecell.squidpy_gpu") +@gpu_dispatch() def ligrec( adata: AnnData | SpatialData, cluster_key: str, @@ -659,8 +660,7 @@ def ligrec( %(PT_test.parameters)s gene_symbols Key in :attr:`anndata.AnnData.var` to use instead of :attr:`anndata.AnnData.var_names`. - device - Device for computation: ``"cpu"``, ``"gpu"``, or ``None`` (use ``settings.device``). 
+ %(device)s Returns ------- diff --git a/src/squidpy/gr/_ppatterns.py b/src/squidpy/gr/_ppatterns.py index e0d408921..b05298be3 100644 --- a/src/squidpy/gr/_ppatterns.py +++ b/src/squidpy/gr/_ppatterns.py @@ -23,7 +23,8 @@ from squidpy._constants._constants import SpatialAutocorr from squidpy._constants._pkg_constants import Key from squidpy._docs import d, inject_docs -from squidpy._utils import NDArrayA, Signal, SigQueue, _get_n_cores, gpu_dispatch, parallelize +from squidpy._utils import NDArrayA, Signal, SigQueue, _get_n_cores, parallelize +from squidpy.settings import gpu_dispatch from squidpy.gr._utils import ( _assert_categorical_obs, _assert_connectivity_key, @@ -45,7 +46,7 @@ @d.dedent @inject_docs(key=Key.obsp.spatial_conn(), sp=SpatialAutocorr) -@gpu_dispatch("rapids_singlecell.squidpy_gpu") +@gpu_dispatch() def spatial_autocorr( adata: AnnData | SpatialData, connectivity_key: str = Key.obsp.spatial_conn(), @@ -106,11 +107,11 @@ def spatial_autocorr( Layer in :attr:`anndata.AnnData.layers` to use. If `None`, use :attr:`anndata.AnnData.X`. attr Which attribute of :class:`~anndata.AnnData` to access. See ``genes`` parameter for more information. - %(seed)s + Ignored when ``device='gpu'``. + %(seed_device)s %(copy)s - %(parallelize)s - device - Device for computation: ``"cpu"``, ``"gpu"``, or ``None`` (use ``settings.device``). + %(parallelize_device)s + %(device)s Returns ------- @@ -346,7 +347,7 @@ def _co_occurrence_helper(v_x: NDArrayA, v_y: NDArrayA, v_radium: NDArrayA, labs @d.dedent -@gpu_dispatch("rapids_singlecell.squidpy_gpu") +@gpu_dispatch() def co_occurrence( adata: AnnData | SpatialData, cluster_key: str, @@ -373,10 +374,9 @@ def co_occurrence( %(copy)s n_splits Number of splits in which to divide the spatial coordinates in - :attr:`anndata.AnnData.obsm` ``['{spatial_key}']``. - %(parallelize)s - device - Device for computation: ``"cpu"``, ``"gpu"``, or ``None`` (use ``settings.device``). + :attr:`anndata.AnnData.obsm` ``['{spatial_key}']``. 
Ignored when ``device='gpu'``. + %(parallelize_device)s + %(device)s Returns ------- diff --git a/src/squidpy/settings/__init__.py b/src/squidpy/settings/__init__.py index 3491e01fd..0728d68ef 100644 --- a/src/squidpy/settings/__init__.py +++ b/src/squidpy/settings/__init__.py @@ -1,7 +1,6 @@ -"""Squidpy settings and configuration.""" - -from __future__ import annotations +"""Squidpy settings.""" +from squidpy.settings._dispatch import gpu_dispatch from squidpy.settings._settings import DeviceType, settings -__all__ = ["settings", "DeviceType"] +__all__ = ["settings", "DeviceType", "gpu_dispatch"] diff --git a/src/squidpy/settings/_dispatch.py b/src/squidpy/settings/_dispatch.py new file mode 100644 index 000000000..394e57077 --- /dev/null +++ b/src/squidpy/settings/_dispatch.py @@ -0,0 +1,62 @@ +"""GPU dispatch decorator for squidpy.""" + +from __future__ import annotations + +import functools +import inspect +from collections.abc import Callable +from typing import Any, Literal, TypeVar + +from squidpy.settings._settings import settings + +__all__ = ["gpu_dispatch"] + +F = TypeVar("F", bound=Callable[..., Any]) + + +def _resolve_device(device: Literal["auto", "cpu", "gpu"] | None) -> Literal["cpu", "gpu"]: + """Resolve device arg to 'cpu' or 'gpu'.""" + if device is None: + device = settings.device + if device == "cpu": + return "cpu" + if device == "gpu": + if not settings.gpu_available(): + raise RuntimeError("GPU unavailable. 
Install with: pip install squidpy[gpu-cuda12]") + return "gpu" + # auto + return "gpu" if settings.gpu_available() else "cpu" + + +def gpu_dispatch(gpu_func_name: str | None = None) -> Callable[[F], F]: + """Decorator to dispatch to GPU adapter when device='gpu'.""" + + def decorator(func: F) -> F: + @functools.wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + sig = inspect.signature(func) + try: + bound = sig.bind(*args, **kwargs) + except TypeError: + return func(*args, **kwargs) + + bound.apply_defaults() + all_args = dict(bound.arguments) + + device = all_args.pop("device", None) + + # Handle **kwargs: unpack instead of passing as kwargs=dict + extra_kwargs = all_args.pop("kwargs", {}) + + if _resolve_device(device) == "gpu": + from squidpy.gr import _gpu + + func_name = gpu_func_name if gpu_func_name is not None else f"{func.__name__}_gpu" + gpu_adapter = getattr(_gpu, func_name) + return gpu_adapter(**all_args, **extra_kwargs) + + return func(**all_args, **extra_kwargs) + + return wrapper # type: ignore[return-value] + + return decorator diff --git a/src/squidpy/settings/_settings.py b/src/squidpy/settings/_settings.py index d807f0714..01a00c739 100644 --- a/src/squidpy/settings/_settings.py +++ b/src/squidpy/settings/_settings.py @@ -1,3 +1,5 @@ +"""Squidpy global settings.""" + from __future__ import annotations from contextvars import ContextVar @@ -6,46 +8,34 @@ __all__ = ["settings", "DeviceType"] DeviceType = Literal["auto", "cpu", "gpu"] - _device_var: ContextVar[DeviceType] = ContextVar("device", default="auto") -class _SqSettings: - """Global settings for squidpy.""" +class SqSettings: + """Global configuration for squidpy.""" @property def device(self) -> DeviceType: - """Current compute device setting.""" + """Compute device: ``'auto'``, ``'cpu'``, or ``'gpu'``.""" return _device_var.get() @device.setter def device(self, value: DeviceType) -> None: - valid = get_args(DeviceType) - if value not in valid: - raise ValueError(f"Invalid 
device {value!r}. Must be one of: {valid}") + if value not in get_args(DeviceType): + raise ValueError(f"device must be one of {get_args(DeviceType)}, got {value!r}") if value == "gpu" and not self.gpu_available(): - raise RuntimeError( - "Cannot set device='gpu': rapids-singlecell not installed. " - "Install with: pip install squidpy[gpu-cuda12] or squidpy[gpu-cuda11]" - ) + raise RuntimeError("GPU unavailable. Install: pip install squidpy[gpu-cuda12]") _device_var.set(value) @staticmethod def gpu_available() -> bool: - """ - Check if GPU acceleration is available. - - Returns - ------- - bool - True if rapids-singlecell is installed and importable. - """ + """Check if GPU acceleration is available.""" try: import rapids_singlecell # noqa: F401 - return True except ImportError: return False -settings = _SqSettings() + +settings = SqSettings() diff --git a/tests/test_settings.py b/tests/test_settings.py index fe006236f..1055a7dca 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -6,8 +6,7 @@ import pytest -from squidpy._utils import gpu_dispatch -from squidpy.settings import settings +from squidpy.settings import gpu_dispatch, settings class TestSettings: @@ -27,17 +26,16 @@ def test_set_device_cpu(self): def test_set_device_invalid(self): """Test that invalid device raises ValueError.""" - with pytest.raises(ValueError, match="Invalid device"): + with pytest.raises(ValueError, match="device must be one of"): settings.device = "invalid" def test_set_device_gpu_without_rsc(self): """Test that setting device to 'gpu' without rapids-singlecell raises RuntimeError.""" # This will fail if rapids-singlecell is not installed if not settings.gpu_available(): - with pytest.raises(RuntimeError, match="rapids-singlecell not installed"): + with pytest.raises(RuntimeError, match="GPU unavailable"): settings.device = "gpu" - class TestGpuDispatch: """Test the gpu_dispatch decorator.""" @@ -45,7 +43,7 @@ def test_cpu_path_calls_original(self): """Test that CPU 
device calls the original function.""" original_called = [] - @gpu_dispatch("fake_rapids_module") + @gpu_dispatch() def my_func(x, y, *, n_jobs=1, device=None): original_called.append((x, y, n_jobs)) return x + y @@ -58,7 +56,7 @@ def test_cpu_path_with_auto_device_no_gpu(self): """Test that auto device falls back to CPU when GPU unavailable.""" original_called = [] - @gpu_dispatch("fake_rapids_module") + @gpu_dispatch() def my_func(x, device=None): original_called.append(x) return x * 2 @@ -69,38 +67,28 @@ def my_func(x, device=None): assert result == 10 assert original_called == [5] - def test_gpu_path_filters_parameters(self): - """Test that GPU dispatch filters out parameters not in rapids signature.""" - mock_rapids_func = MagicMock(return_value="gpu_result") - - # Create a mock module - mock_module = MagicMock() - mock_module.my_func = mock_rapids_func + def test_gpu_path_calls_adapter(self): + """Test that GPU dispatch calls the adapter function from _gpu module.""" + mock_adapter = MagicMock(return_value="gpu_result") - @gpu_dispatch("mock_rapids") + @gpu_dispatch() def my_func(adata, cluster_key, *, n_jobs=1, backend="loky", device=None): return "cpu_result" - with patch("importlib.import_module", return_value=mock_module): - with patch("squidpy._utils.resolve_device_arg", return_value="gpu"): - # Mock the rapids function signature to only accept adata and cluster_key - import inspect - - mock_sig = inspect.signature(lambda adata, cluster_key: None) - with patch( - "inspect.signature", - side_effect=lambda f: mock_sig if f == mock_rapids_func else inspect.signature(f), - ): - result = my_func("adata_obj", "leiden", n_jobs=4, backend="threading", device="gpu") + with patch("squidpy.settings._dispatch._resolve_device", return_value="gpu"): + with patch("squidpy.gr._gpu.my_func_gpu", mock_adapter, create=True): + result = my_func("adata_obj", "leiden", n_jobs=4, backend="threading", device="gpu") assert result == "gpu_result" - # Should only be called with 
adata and cluster_key, not n_jobs or backend - mock_rapids_func.assert_called_once_with(adata="adata_obj", cluster_key="leiden") + # Adapter receives all args except device + mock_adapter.assert_called_once_with( + adata="adata_obj", cluster_key="leiden", n_jobs=4, backend="threading" + ) def test_preserves_function_metadata(self): """Test that the decorator preserves function name and docstring.""" - @gpu_dispatch("fake_module") + @gpu_dispatch() def documented_func(x, device=None): """This is the docstring.""" return x @@ -108,26 +96,17 @@ def documented_func(x, device=None): assert documented_func.__name__ == "documented_func" assert documented_func.__doc__ == """This is the docstring.""" - def test_custom_rapids_func_name(self): - """Test using a custom rapids function name.""" - mock_rapids_func = MagicMock(return_value="rapids_result") - mock_module = MagicMock() - mock_module.different_name = mock_rapids_func + def test_custom_gpu_func_name(self): + """Test using a custom GPU adapter function name.""" + mock_adapter = MagicMock(return_value="gpu_result") - @gpu_dispatch("mock_rapids", rapids_func_name="different_name") + @gpu_dispatch("custom_adapter_name") def my_func(x, device=None): return "cpu_result" - with patch("importlib.import_module", return_value=mock_module): - with patch("squidpy._utils.resolve_device_arg", return_value="gpu"): - import inspect + with patch("squidpy.settings._dispatch._resolve_device", return_value="gpu"): + with patch("squidpy.gr._gpu.custom_adapter_name", mock_adapter, create=True): + result = my_func(42, device="gpu") - mock_sig = inspect.signature(lambda x: None) - with patch( - "inspect.signature", - side_effect=lambda f: mock_sig if f == mock_rapids_func else inspect.signature(f), - ): - result = my_func(42, device="gpu") - - assert result == "rapids_result" - mock_rapids_func.assert_called_once_with(x=42) + assert result == "gpu_result" + mock_adapter.assert_called_once_with(x=42) From 
0dffb8b40a3ef9ac7f9ed67fd54b68bdf128506d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 26 Jan 2026 14:28:19 +0000 Subject: [PATCH 14/68] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/squidpy/_utils.py | 2 -- src/squidpy/gr/_gpu.py | 29 +++++++++++++++++++++-------- src/squidpy/gr/_ligrec.py | 2 +- src/squidpy/gr/_ppatterns.py | 2 +- src/squidpy/settings/__init__.py | 2 ++ src/squidpy/settings/_settings.py | 2 +- tests/test_settings.py | 5 ++--- 7 files changed, 28 insertions(+), 16 deletions(-) diff --git a/src/squidpy/_utils.py b/src/squidpy/_utils.py index 24d5765de..736c88172 100644 --- a/src/squidpy/_utils.py +++ b/src/squidpy/_utils.py @@ -387,5 +387,3 @@ def _ensure_dim_order(img_da: xr.DataArray, order: Literal["cyx", "yxc"] = "yxc" img_da = img_da.expand_dims({"c": [0]}) # After possible expand, just transpose to target return img_da.transpose(*tuple(order)) - - diff --git a/src/squidpy/gr/_gpu.py b/src/squidpy/gr/_gpu.py index f1006d260..c36ca61bc 100644 --- a/src/squidpy/gr/_gpu.py +++ b/src/squidpy/gr/_gpu.py @@ -7,9 +7,9 @@ from __future__ import annotations import warnings -from collections.abc import Callable +from collections.abc import Callable, Sequence from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Literal, Sequence +from typing import TYPE_CHECKING, Any, Literal from squidpy._constants._pkg_constants import Key @@ -150,7 +150,11 @@ def spatial_autocorr_gpu( check = check_gpu_params( "spatial_autocorr", - attr=attr, seed=seed, n_jobs=n_jobs, backend=backend, show_progress_bar=show_progress_bar, + attr=attr, + seed=seed, + n_jobs=n_jobs, + backend=backend, + show_progress_bar=show_progress_bar, ) return _spatial_autocorr_gpu( @@ -186,7 +190,10 @@ def co_occurrence_gpu( check_gpu_params( "co_occurrence", - n_splits=n_splits, n_jobs=n_jobs, backend=backend, 
show_progress_bar=show_progress_bar, + n_splits=n_splits, + n_jobs=n_jobs, + backend=backend, + show_progress_bar=show_progress_bar, ) return _co_occurrence_gpu( @@ -228,10 +235,16 @@ def ligrec_gpu( check_gpu_params( "ligrec", - clusters=clusters, seed=seed, numba_parallel=numba_parallel, - n_jobs=n_jobs, backend=backend, show_progress_bar=show_progress_bar, - transmitter_params=transmitter_params, receiver_params=receiver_params, - interactions_params=interactions_params, alpha=alpha, + clusters=clusters, + seed=seed, + numba_parallel=numba_parallel, + n_jobs=n_jobs, + backend=backend, + show_progress_bar=show_progress_bar, + transmitter_params=transmitter_params, + receiver_params=receiver_params, + interactions_params=interactions_params, + alpha=alpha, ) return _ligrec_gpu( diff --git a/src/squidpy/gr/_ligrec.py b/src/squidpy/gr/_ligrec.py index ed610fa20..9f3dd0fdb 100644 --- a/src/squidpy/gr/_ligrec.py +++ b/src/squidpy/gr/_ligrec.py @@ -21,7 +21,6 @@ from squidpy._constants._pkg_constants import Key from squidpy._docs import d, inject_docs from squidpy._utils import NDArrayA, Signal, SigQueue, _get_n_cores, parallelize -from squidpy.settings import gpu_dispatch from squidpy.gr._utils import ( _assert_categorical_obs, _assert_positive, @@ -30,6 +29,7 @@ _genesymbols, _save_data, ) +from squidpy.settings import gpu_dispatch __all__ = ["ligrec", "PermutationTest"] diff --git a/src/squidpy/gr/_ppatterns.py b/src/squidpy/gr/_ppatterns.py index b05298be3..40e45ef40 100644 --- a/src/squidpy/gr/_ppatterns.py +++ b/src/squidpy/gr/_ppatterns.py @@ -24,7 +24,6 @@ from squidpy._constants._pkg_constants import Key from squidpy._docs import d, inject_docs from squidpy._utils import NDArrayA, Signal, SigQueue, _get_n_cores, parallelize -from squidpy.settings import gpu_dispatch from squidpy.gr._utils import ( _assert_categorical_obs, _assert_connectivity_key, @@ -32,6 +31,7 @@ _assert_spatial_basis, _save_data, ) +from squidpy.settings import gpu_dispatch __all__ = 
["spatial_autocorr", "co_occurrence"] diff --git a/src/squidpy/settings/__init__.py b/src/squidpy/settings/__init__.py index 0728d68ef..aa1bf6afe 100644 --- a/src/squidpy/settings/__init__.py +++ b/src/squidpy/settings/__init__.py @@ -1,5 +1,7 @@ """Squidpy settings.""" +from __future__ import annotations + from squidpy.settings._dispatch import gpu_dispatch from squidpy.settings._settings import DeviceType, settings diff --git a/src/squidpy/settings/_settings.py b/src/squidpy/settings/_settings.py index 01a00c739..fa584e6d3 100644 --- a/src/squidpy/settings/_settings.py +++ b/src/squidpy/settings/_settings.py @@ -32,10 +32,10 @@ def gpu_available() -> bool: """Check if GPU acceleration is available.""" try: import rapids_singlecell # noqa: F401 + return True except ImportError: return False - settings = SqSettings() diff --git a/tests/test_settings.py b/tests/test_settings.py index 1055a7dca..d06790efc 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -36,6 +36,7 @@ def test_set_device_gpu_without_rsc(self): with pytest.raises(RuntimeError, match="GPU unavailable"): settings.device = "gpu" + class TestGpuDispatch: """Test the gpu_dispatch decorator.""" @@ -81,9 +82,7 @@ def my_func(adata, cluster_key, *, n_jobs=1, backend="loky", device=None): assert result == "gpu_result" # Adapter receives all args except device - mock_adapter.assert_called_once_with( - adata="adata_obj", cluster_key="leiden", n_jobs=4, backend="threading" - ) + mock_adapter.assert_called_once_with(adata="adata_obj", cluster_key="leiden", n_jobs=4, backend="threading") def test_preserves_function_metadata(self): """Test that the decorator preserves function name and docstring.""" From de83264d756b672c920662c3f430e63671e86b1a Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Mon, 26 Jan 2026 15:38:34 +0100 Subject: [PATCH 15/68] use try except for compar --- src/squidpy/_compat.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git 
a/src/squidpy/_compat.py b/src/squidpy/_compat.py index e56040310..7a51563f0 100644 --- a/src/squidpy/_compat.py +++ b/src/squidpy/_compat.py @@ -20,11 +20,11 @@ ] # See https://github.com/scverse/squidpy/issues/1061 for more details -_SET_DEFAULT_COLORS_FOR_CATEGORICAL_OBS_CHANGED = Version("0.12.0rc1") <= Version(version("scanpy")) < Version("0.12.0") - -if _SET_DEFAULT_COLORS_FOR_CATEGORICAL_OBS_CHANGED: +# In scanpy 0.12.0rc1 through 1.11.x, the function is _set_default_colors_for_categorical_obs (with underscore) +# Try the underscore version first (current), fall back to non-underscore for older/future versions +try: from scanpy.plotting._utils import _set_default_colors_for_categorical_obs as set_default_colors_for_categorical_obs +except ImportError: from scanpy.plotting._utils import set_default_colors_for_categorical_obs From 1d7eb487a22241f34aa3bf2a50b323e2d92e65e4 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Mon, 26 Jan 2026 15:40:38 +0100 Subject: [PATCH 16/68] intersphinx mapping --- docs/conf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/conf.py b/docs/conf.py index 0f413dfac..85080bd23 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -66,6 +66,7 @@ napari=("https://napari.org", None), spatialdata=("https://spatialdata.scverse.org/en/latest", None), shapely=("https://shapely.readthedocs.io/en/stable", None), + rapids_singlecell=("https://rapids-singlecell.readthedocs.io/en/latest", None), ) # Add any paths that contain templates here, relative to this directory.
From 17d34bd8a26636c72324984b781fd140d89b43e6 Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Mon, 26 Jan 2026 16:43:10 +0100 Subject: [PATCH 17/68] rsc link refer workaround --- src/squidpy/_docs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/squidpy/_docs.py b/src/squidpy/_docs.py index d931b6e19..633be2124 100644 --- a/src/squidpy/_docs.py +++ b/src/squidpy/_docs.py @@ -132,7 +132,7 @@ def decorator2(obj: Any) -> Any: _device = """\ device Device for computation: ``'cpu'``, ``'gpu'``, or ``None`` (use ``squidpy.settings.device``). - When ``'gpu'``, dispatches to :mod:`rapids_singlecell` for GPU-accelerated computation.""" + When ``'gpu'``, dispatches to :doc:`rapids_singlecell ` for GPU-accelerated computation.""" _channels = """\ channels Channels for this feature is computed. If `None`, use all channels.""" From 89be7527697090cbc6f6658fc5225fa579ef1f50 Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Mon, 26 Jan 2026 17:16:09 +0100 Subject: [PATCH 18/68] add gpu notes --- src/squidpy/settings/_dispatch.py | 36 ++++++++++++++++++++++++++++--- 1 file changed, 33 insertions(+), 3 deletions(-) diff --git a/src/squidpy/settings/_dispatch.py b/src/squidpy/settings/_dispatch.py index 394e57077..8270b1d32 100644 --- a/src/squidpy/settings/_dispatch.py +++ b/src/squidpy/settings/_dispatch.py @@ -4,6 +4,7 @@ import functools import inspect +import re from collections.abc import Callable from typing import Any, Literal, TypeVar @@ -13,6 +14,12 @@ F = TypeVar("F", bound=Callable[..., Any]) +_GPU_NOTE_TEMPLATE = """ +.. note:: + This function supports GPU acceleration via :doc:`rapids_singlecell `. + See :func:`rapids_singlecell.squidpy_gpu.{func_name}` for the GPU implementation. 
+""" + def _resolve_device(device: Literal["auto", "cpu", "gpu"] | None) -> Literal["cpu", "gpu"]: """Resolve device arg to 'cpu' or 'gpu'.""" @@ -28,10 +35,33 @@ def _resolve_device(device: Literal["auto", "cpu", "gpu"] | None) -> Literal["cp return "gpu" if settings.gpu_available() else "cpu" +def _inject_gpu_note(doc: str | None, func_name: str) -> str | None: + """Inject GPU note into docstring after the first paragraph.""" + if doc is None: + return None + + gpu_note = _GPU_NOTE_TEMPLATE.format(func_name=func_name) + + # Find "Parameters\n----------" and insert note before it + match = re.search(r"(\n\s*Parameters\s*\n\s*-+)", doc) + if match: + insert_pos = match.start() + return doc[:insert_pos] + "\n" + gpu_note + doc[insert_pos:] + + # Fallback: append at the end + return doc + "\n" + gpu_note + + def gpu_dispatch(gpu_func_name: str | None = None) -> Callable[[F], F]: - """Decorator to dispatch to GPU adapter when device='gpu'.""" + """Decorator to dispatch to GPU adapter when device='gpu'. + + Also injects a GPU note into the function's docstring. 
+ """ def decorator(func: F) -> F: + # Inject GPU note into docstring + func.__doc__ = _inject_gpu_note(func.__doc__, func.__name__) + @functools.wraps(func) def wrapper(*args: Any, **kwargs: Any) -> Any: sig = inspect.signature(func) @@ -51,8 +81,8 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: if _resolve_device(device) == "gpu": from squidpy.gr import _gpu - func_name = gpu_func_name if gpu_func_name is not None else f"{func.__name__}_gpu" - gpu_adapter = getattr(_gpu, func_name) + adapter_name = gpu_func_name if gpu_func_name is not None else f"{func.__name__}_gpu" + gpu_adapter = getattr(_gpu, adapter_name) return gpu_adapter(**all_args, **extra_kwargs) return func(**all_args, **extra_kwargs) From 5a7cb83c9e0ffbb85b84f336978a5c8d260ae354 Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Mon, 26 Jan 2026 17:24:11 +0100 Subject: [PATCH 19/68] ok fix docs --- src/squidpy/settings/_dispatch.py | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/src/squidpy/settings/_dispatch.py b/src/squidpy/settings/_dispatch.py index 8270b1d32..fafc5bc5d 100644 --- a/src/squidpy/settings/_dispatch.py +++ b/src/squidpy/settings/_dispatch.py @@ -14,13 +14,6 @@ F = TypeVar("F", bound=Callable[..., Any]) -_GPU_NOTE_TEMPLATE = """ -.. note:: - This function supports GPU acceleration via :doc:`rapids_singlecell `. - See :func:`rapids_singlecell.squidpy_gpu.{func_name}` for the GPU implementation. -""" - - def _resolve_device(device: Literal["auto", "cpu", "gpu"] | None) -> Literal["cpu", "gpu"]: """Resolve device arg to 'cpu' or 'gpu'.""" if device is None: @@ -35,21 +28,30 @@ def _resolve_device(device: Literal["auto", "cpu", "gpu"] | None) -> Literal["cp return "gpu" if settings.gpu_available() else "cpu" +def _make_gpu_note(func_name: str, indent: str = "") -> str: + lines = [ + ".. 
note::", + " This function supports GPU acceleration via :doc:`rapids_singlecell `.", + f" See :func:`rapids_singlecell.gr.{func_name}` for the GPU implementation.", + ] + return "\n".join(indent + line for line in lines) + + def _inject_gpu_note(doc: str | None, func_name: str) -> str | None: - """Inject GPU note into docstring after the first paragraph.""" + """Inject GPU note into docstring before the Parameters section.""" if doc is None: return None - gpu_note = _GPU_NOTE_TEMPLATE.format(func_name=func_name) - - # Find "Parameters\n----------" and insert note before it - match = re.search(r"(\n\s*Parameters\s*\n\s*-+)", doc) + # Find "Parameters\n ----------" and capture the indentation (spaces only, not newline) + match = re.search(r"\n([ \t]*)Parameters\s*\n\s*-+", doc) if match: + indent = match.group(1) # Capture only the spaces/tabs before Parameters + gpu_note = _make_gpu_note(func_name, indent) insert_pos = match.start() - return doc[:insert_pos] + "\n" + gpu_note + doc[insert_pos:] + return doc[:insert_pos] + "\n\n" + gpu_note + "\n" + doc[insert_pos:] # Fallback: append at the end - return doc + "\n" + gpu_note + return doc + "\n\n" + _make_gpu_note(func_name) def gpu_dispatch(gpu_func_name: str | None = None) -> Callable[[F], F]: From 7478f62936d39c362cde79f92c124945696b6904 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 26 Jan 2026 16:26:02 +0000 Subject: [PATCH 20/68] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/squidpy/settings/_dispatch.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/squidpy/settings/_dispatch.py b/src/squidpy/settings/_dispatch.py index fafc5bc5d..7fc967716 100644 --- a/src/squidpy/settings/_dispatch.py +++ b/src/squidpy/settings/_dispatch.py @@ -14,6 +14,7 @@ F = TypeVar("F", bound=Callable[..., Any]) + def _resolve_device(device: Literal["auto", "cpu", "gpu"] | None) -> Literal["cpu", 
"gpu"]: """Resolve device arg to 'cpu' or 'gpu'.""" if device is None: From 74dc1ff3214b89c33776ea4195b53a503b172df9 Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Mon, 26 Jan 2026 17:27:04 +0100 Subject: [PATCH 21/68] update docs --- src/squidpy/_compat.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/squidpy/_compat.py b/src/squidpy/_compat.py index 7a51563f0..a342c3b91 100644 --- a/src/squidpy/_compat.py +++ b/src/squidpy/_compat.py @@ -20,8 +20,9 @@ ] # See https://github.com/scverse/squidpy/issues/1061 for more details -# In scanpy 0.12.0rc1 through 1.11.x, the function is _set_default_colors_for_categorical_obs (with underscore) -# Try the underscore version first (current), fall back to non-underscore for older/future versions +# scanpy around version 0.11.x- 0.12.x changed the function name from set_default_colors_for_categorical_obs to _set_default_colors_for_categorical_obs +# and then changed it back +# so to not track with versioning we use the underscore version first (current), fall back to non-underscore for older/future versions try: from scanpy.plotting._utils import _set_default_colors_for_categorical_obs as set_default_colors_for_categorical_obs except ImportError: From f1b506eb2317d0807a7916d7a2d48564111902ac Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Mon, 26 Jan 2026 17:55:04 +0100 Subject: [PATCH 22/68] refactor --- .gitignore | 1 + src/squidpy/gr/_gpu.py | 206 ++++++------------------------ src/squidpy/gr/_ppatterns.py | 3 + src/squidpy/settings/_dispatch.py | 49 ++++++- tests/test_settings.py | 146 +++++++++++++++------ 5 files changed, 189 insertions(+), 216 deletions(-) diff --git a/.gitignore b/.gitignore index b638264c7..6a7b3fbd4 100644 --- a/.gitignore +++ b/.gitignore @@ -144,3 +144,4 @@ data # pixi .pixi pixi.lock +_version.py diff --git a/src/squidpy/gr/_gpu.py b/src/squidpy/gr/_gpu.py index c36ca61bc..c747cb10a 100644 --- a/src/squidpy/gr/_gpu.py +++ b/src/squidpy/gr/_gpu.py @@ -1,21 
+1,16 @@ -"""GPU adapter functions for squidpy.gr functions. +"""GPU parameter registry for squidpy.gr functions. -These stubs provide explicit parameter mapping between squidpy and rapids_singlecell, -ensuring compatibility and clear documentation of supported parameters. +Defines which parameters are CPU-only (ignored on GPU) and GPU-only (ignored on CPU). +The gpu_dispatch decorator uses this registry to automatically handle parameter filtering. """ from __future__ import annotations -import warnings -from collections.abc import Callable, Sequence -from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Literal +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any -from squidpy._constants._pkg_constants import Key - -if TYPE_CHECKING: - from anndata import AnnData - from numpy.typing import NDArray +__all__ = ["GPU_PARAM_REGISTRY", "GpuParamSpec", "check_gpu_params", "check_cpu_params"] @dataclass @@ -43,6 +38,8 @@ def _attr_validator(value: Any) -> str | None: _SEED: dict[str, GpuParamSpec] = {"seed": GpuParamSpec(None)} # Registry: {func_name: {"cpu_only": {...}, "gpu_only": {...}}} +# - cpu_only: parameters ignored on GPU (warn if non-default, then filter out) +# - gpu_only: parameters ignored on CPU (error if non-default, pass through to GPU) GPU_PARAM_REGISTRY: dict[str, dict[str, dict[str, GpuParamSpec]]] = { "spatial_autocorr": { "cpu_only": { @@ -77,17 +74,8 @@ def _attr_validator(value: Any) -> str | None: } -@dataclass -class CheckResult: - """Result of parameter compatibility check.""" - - ignored: dict[str, Any] = field(default_factory=dict) - warnings: list[str] = field(default_factory=list) - gpu_defaults: dict[str, Any] = field(default_factory=dict) - - -def check_gpu_params(func_name: str, **cpu_only_values: Any) -> CheckResult: - """Check CPU-only params against registry, warn about non-defaults, return GPU defaults. 
+def check_gpu_params(func_name: str, **cpu_only_values: Any) -> None: + """Check CPU-only params on GPU, raise error if non-default. Parameters ---------- @@ -95,11 +83,14 @@ def check_gpu_params(func_name: str, **cpu_only_values: Any) -> CheckResult: Name of the function in GPU_PARAM_REGISTRY. **cpu_only_values CPU-only parameter values to check. + + Raises + ------ + ValueError + If a CPU-only parameter has a non-default value on GPU. """ - result = CheckResult() registry = GPU_PARAM_REGISTRY.get(func_name, {"cpu_only": {}, "gpu_only": {}}) - # Check CPU-only params for name, spec in registry["cpu_only"].items(): if name not in cpu_only_values: continue @@ -109,155 +100,36 @@ def check_gpu_params(func_name: str, **cpu_only_values: Any) -> CheckResult: if spec.validator: msg = spec.validator(value) elif value != spec.default: - msg = spec.message or f"{name}={value!r} is ignored on GPU." + msg = spec.message or f"{name}={value!r} is only supported on CPU. Use device='cpu' or remove this argument." 
else: msg = None if msg: - msg = msg.format(name=name, value=value) - result.ignored[name] = value - result.warnings.append(msg) - warnings.warn(msg, UserWarning, stacklevel=3) - - # Collect GPU-only param defaults - for name, spec in registry["gpu_only"].items(): - result.gpu_defaults[name] = spec.default - - return result - - -def spatial_autocorr_gpu( - adata: AnnData, - connectivity_key: str = Key.obsp.spatial_conn(), - genes: str | int | Sequence[str] | Sequence[int] | None = None, - mode: Literal["moran", "geary"] = "moran", - transformation: bool = True, - n_perms: int | None = None, - two_tailed: bool = False, - corr_method: str | None = "fdr_bh", - layer: str | None = None, - use_raw: bool = False, - copy: bool = False, - # CPU-only params - attr: Literal["obs", "X", "obsm"] = "X", - seed: int | None = None, - n_jobs: int | None = None, - backend: str = "loky", - show_progress_bar: bool = True, -) -> Any: - """GPU adapter for spatial_autocorr via rapids_singlecell.""" - from rapids_singlecell.squidpy_gpu import spatial_autocorr as _spatial_autocorr_gpu - - check = check_gpu_params( - "spatial_autocorr", - attr=attr, - seed=seed, - n_jobs=n_jobs, - backend=backend, - show_progress_bar=show_progress_bar, - ) + raise ValueError(msg.format(name=name, value=value)) - return _spatial_autocorr_gpu( - adata=adata, - connectivity_key=connectivity_key, - genes=genes, - mode=mode, - transformation=transformation, - n_perms=n_perms, - two_tailed=two_tailed, - corr_method=corr_method, - layer=layer, - use_raw=use_raw, - copy=copy, - **check.gpu_defaults, - ) +def check_cpu_params(func_name: str, **gpu_only_values: Any) -> None: + """Check GPU-only params on CPU, raise error if non-default. 
-def co_occurrence_gpu( - adata: AnnData, - cluster_key: str, - spatial_key: str = Key.obsm.spatial, - interval: int | NDArray[Any] = 50, - copy: bool = False, - # CPU-only params - n_splits: int | None = None, - n_jobs: int | None = None, - backend: str = "loky", - show_progress_bar: bool = True, -) -> Any: - """GPU adapter for co_occurrence via rapids_singlecell.""" - from rapids_singlecell.squidpy_gpu import co_occurrence as _co_occurrence_gpu - - check_gpu_params( - "co_occurrence", - n_splits=n_splits, - n_jobs=n_jobs, - backend=backend, - show_progress_bar=show_progress_bar, - ) - - return _co_occurrence_gpu( - adata=adata, - cluster_key=cluster_key, - spatial_key=spatial_key, - interval=interval, - copy=copy, - ) - + Parameters + ---------- + func_name + Name of the function in GPU_PARAM_REGISTRY. + **gpu_only_values + GPU-only parameter values to check. -def ligrec_gpu( - adata: AnnData, - cluster_key: str, - interactions: Any = None, - complex_policy: Literal["min", "all"] = "min", - threshold: float = 0.01, - corr_method: str | None = None, - corr_axis: Literal["interactions", "clusters"] = "clusters", - use_raw: bool = True, - copy: bool = False, - key_added: str | None = None, - gene_symbols: str | None = None, - n_perms: int = 1000, - # CPU-only params - clusters: Any = None, - seed: int | None = None, - numba_parallel: bool | None = None, - n_jobs: int | None = None, - backend: str = "loky", - show_progress_bar: bool = True, - transmitter_params: dict[str, Any] | None = None, - receiver_params: dict[str, Any] | None = None, - interactions_params: dict[str, Any] | None = None, - alpha: float = 0.05, -) -> Any: - """GPU adapter for ligrec via rapids_singlecell.""" - from rapids_singlecell.squidpy_gpu import ligrec as _ligrec_gpu + Raises + ------ + ValueError + If a GPU-only parameter has a non-default value on CPU. 
+ """ + registry = GPU_PARAM_REGISTRY.get(func_name, {"cpu_only": {}, "gpu_only": {}}) - check_gpu_params( - "ligrec", - clusters=clusters, - seed=seed, - numba_parallel=numba_parallel, - n_jobs=n_jobs, - backend=backend, - show_progress_bar=show_progress_bar, - transmitter_params=transmitter_params, - receiver_params=receiver_params, - interactions_params=interactions_params, - alpha=alpha, - ) + for name, spec in registry["gpu_only"].items(): + if name not in gpu_only_values: + continue + value = gpu_only_values[name] - return _ligrec_gpu( - adata=adata, - cluster_key=cluster_key, - interactions=interactions, - complex_policy=complex_policy, - threshold=threshold, - corr_method=corr_method, - corr_axis=corr_axis, - use_raw=use_raw, - copy=copy, - key_added=key_added, - gene_symbols=gene_symbols, - n_perms=n_perms, - ) + if value != spec.default: + msg = f"{name}={value!r} is only supported on GPU. Use device='gpu' or remove this argument." + raise ValueError(msg) diff --git a/src/squidpy/gr/_ppatterns.py b/src/squidpy/gr/_ppatterns.py index 40e45ef40..c9b718089 100644 --- a/src/squidpy/gr/_ppatterns.py +++ b/src/squidpy/gr/_ppatterns.py @@ -60,6 +60,7 @@ def spatial_autocorr( layer: str | None = None, seed: int | None = None, use_raw: bool = False, + use_sparse: bool = True, copy: bool = False, n_jobs: int | None = None, backend: str = "loky", @@ -108,6 +109,8 @@ def spatial_autocorr( attr Which attribute of :class:`~anndata.AnnData` to access. See ``genes`` parameter for more information. Ignored when ``device='gpu'``. + use_sparse + If `True`, use sparse matrix representation for the input matrix. Only used when ``device='gpu'``. 
%(seed_device)s %(copy)s %(parallelize_device)s diff --git a/src/squidpy/settings/_dispatch.py b/src/squidpy/settings/_dispatch.py index 7fc967716..f831594a0 100644 --- a/src/squidpy/settings/_dispatch.py +++ b/src/squidpy/settings/_dispatch.py @@ -3,11 +3,13 @@ from __future__ import annotations import functools +import importlib import inspect import re from collections.abc import Callable from typing import Any, Literal, TypeVar +from squidpy.gr._gpu import GPU_PARAM_REGISTRY, check_cpu_params, check_gpu_params from squidpy.settings._settings import settings __all__ = ["gpu_dispatch"] @@ -55,15 +57,31 @@ def _inject_gpu_note(doc: str | None, func_name: str) -> str | None: return doc + "\n\n" + _make_gpu_note(func_name) -def gpu_dispatch(gpu_func_name: str | None = None) -> Callable[[F], F]: - """Decorator to dispatch to GPU adapter when device='gpu'. +def gpu_dispatch( + gpu_module: str = "rapids_singlecell.gr", + gpu_func_name: str | None = None, +) -> Callable[[F], F]: + """Decorator to dispatch to GPU implementation when device='gpu'. + + Uses the GPU_PARAM_REGISTRY from squidpy.gr._gpu to: + - Warn about CPU-only parameters that differ from defaults, then filter them out + - Filter out GPU-only parameters on CPU (they only affect GPU) Also injects a GPU note into the function's docstring. + + Parameters + ---------- + gpu_module + Module path containing the GPU implementation. + gpu_func_name + Name of GPU function. Defaults to same name as decorated function. 
""" def decorator(func: F) -> F: + func_name = func.__name__ + # Inject GPU note into docstring - func.__doc__ = _inject_gpu_note(func.__doc__, func.__name__) + func.__doc__ = _inject_gpu_note(func.__doc__, func_name) @functools.wraps(func) def wrapper(*args: Any, **kwargs: Any) -> Any: @@ -81,12 +99,29 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: # Handle **kwargs: unpack instead of passing as kwargs=dict extra_kwargs = all_args.pop("kwargs", {}) + # Get registry for this function + registry = GPU_PARAM_REGISTRY.get(func_name, {"cpu_only": {}, "gpu_only": {}}) + if _resolve_device(device) == "gpu": - from squidpy.gr import _gpu + # Collect CPU-only param values and check them (warn if non-default) + cpu_only_values = {k: all_args.pop(k) for k in list(all_args) if k in registry["cpu_only"]} + cpu_only_values.update( + {k: extra_kwargs.pop(k) for k in list(extra_kwargs) if k in registry["cpu_only"]} + ) + + check_gpu_params(func_name, **cpu_only_values) + + # Import and call GPU function + module = importlib.import_module(gpu_module) + gpu_func = getattr(module, gpu_func_name or func_name) + + return gpu_func(**all_args, **extra_kwargs) + + # CPU path: check gpu_only params (error if non-default), then filter them out + gpu_only_values = {k: all_args.pop(k) for k in list(all_args) if k in registry["gpu_only"]} + gpu_only_values.update({k: extra_kwargs.pop(k) for k in list(extra_kwargs) if k in registry["gpu_only"]}) - adapter_name = gpu_func_name if gpu_func_name is not None else f"{func.__name__}_gpu" - gpu_adapter = getattr(_gpu, adapter_name) - return gpu_adapter(**all_args, **extra_kwargs) + check_cpu_params(func_name, **gpu_only_values) return func(**all_args, **extra_kwargs) diff --git a/tests/test_settings.py b/tests/test_settings.py index d06790efc..8fa8ccbfc 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -6,6 +6,7 @@ import pytest +from squidpy.gr._gpu import GpuParamSpec from squidpy.settings import gpu_dispatch, settings @@ -14,7 
+15,6 @@ class TestSettings: def test_default_device(self): """Test that default device is 'auto'.""" - # Reset to default settings.device = "auto" assert settings.device == "auto" @@ -31,7 +31,6 @@ def test_set_device_invalid(self): def test_set_device_gpu_without_rsc(self): """Test that setting device to 'gpu' without rapids-singlecell raises RuntimeError.""" - # This will fail if rapids-singlecell is not installed if not settings.gpu_available(): with pytest.raises(RuntimeError, match="GPU unavailable"): settings.device = "gpu" @@ -40,72 +39,135 @@ def test_set_device_gpu_without_rsc(self): class TestGpuDispatch: """Test the gpu_dispatch decorator.""" - def test_cpu_path_calls_original(self): - """Test that CPU device calls the original function.""" - original_called = [] + @pytest.fixture + def mock_gpu_module(self): + """Create a mock GPU module with adapter function.""" + mock_adapter = MagicMock(return_value="gpu_result") + mock_module = MagicMock() + mock_module.my_func = mock_adapter + return mock_module, mock_adapter + + def test_cpu_path(self): + """Test CPU device calls original function.""" + calls = [] @gpu_dispatch() def my_func(x, y, *, n_jobs=1, device=None): - original_called.append((x, y, n_jobs)) + calls.append((x, y, n_jobs)) return x + y - result = my_func(1, 2, device="cpu") - assert result == 3 - assert original_called == [(1, 2, 1)] + assert my_func(1, 2, device="cpu") == 3 + assert calls == [(1, 2, 1)] + + def test_auto_device_falls_back_to_cpu(self): + """Test auto device falls back to CPU when GPU unavailable.""" + if settings.gpu_available(): + pytest.skip("GPU is available") - def test_cpu_path_with_auto_device_no_gpu(self): - """Test that auto device falls back to CPU when GPU unavailable.""" - original_called = [] + calls = [] @gpu_dispatch() def my_func(x, device=None): - original_called.append(x) + calls.append(x) return x * 2 - # With auto and no GPU available, should call original - if not settings.gpu_available(): - result = 
my_func(5, device="auto") - assert result == 10 - assert original_called == [5] + assert my_func(5, device="auto") == 10 + assert calls == [5] - def test_gpu_path_calls_adapter(self): - """Test that GPU dispatch calls the adapter function from _gpu module.""" - mock_adapter = MagicMock(return_value="gpu_result") + def test_gpu_path(self, mock_gpu_module): + """Test GPU device dispatches to GPU module.""" + mock_module, mock_adapter = mock_gpu_module - @gpu_dispatch() - def my_func(adata, cluster_key, *, n_jobs=1, backend="loky", device=None): + @gpu_dispatch(gpu_module="test_module") + def my_func(x, device=None): + return "cpu_result" + + with ( + patch("squidpy.settings._dispatch._resolve_device", return_value="gpu"), + patch("importlib.import_module", return_value=mock_module), + patch("squidpy.gr._gpu.GPU_PARAM_REGISTRY", {"my_func": {"cpu_only": {}, "gpu_only": {}}}), + ): + assert my_func(42, device="gpu") == "gpu_result" + + mock_adapter.assert_called_once_with(x=42) + + def test_custom_gpu_func_name(self, mock_gpu_module): + """Test custom GPU function name.""" + mock_module, mock_adapter = mock_gpu_module + mock_module.custom_name = mock_adapter + + @gpu_dispatch(gpu_module="test_module", gpu_func_name="custom_name") + def my_func(x, device=None): return "cpu_result" - with patch("squidpy.settings._dispatch._resolve_device", return_value="gpu"): - with patch("squidpy.gr._gpu.my_func_gpu", mock_adapter, create=True): - result = my_func("adata_obj", "leiden", n_jobs=4, backend="threading", device="gpu") + with ( + patch("squidpy.settings._dispatch._resolve_device", return_value="gpu"), + patch("importlib.import_module", return_value=mock_module), + patch("squidpy.gr._gpu.GPU_PARAM_REGISTRY", {"my_func": {"cpu_only": {}, "gpu_only": {}}}), + ): + assert my_func(42, device="gpu") == "gpu_result" - assert result == "gpu_result" - # Adapter receives all args except device - mock_adapter.assert_called_once_with(adata="adata_obj", cluster_key="leiden", n_jobs=4, 
backend="threading") + mock_adapter.assert_called_once_with(x=42) def test_preserves_function_metadata(self): - """Test that the decorator preserves function name and docstring.""" + """Test decorator preserves function name and injects GPU note.""" @gpu_dispatch() def documented_func(x, device=None): - """This is the docstring.""" + """Original docstring. + + Parameters + ---------- + x + Input value. + """ return x assert documented_func.__name__ == "documented_func" - assert documented_func.__doc__ == """This is the docstring.""" + assert "Original docstring." in documented_func.__doc__ + assert "GPU acceleration" in documented_func.__doc__ + + def test_cpu_only_params_filtered_on_gpu(self, mock_gpu_module): + """Test CPU-only params are filtered out on GPU.""" + mock_module, mock_adapter = mock_gpu_module + registry = { + "my_func": { + "cpu_only": {"n_jobs": GpuParamSpec(None)}, + "gpu_only": {}, + } + } + + @gpu_dispatch(gpu_module="test_module") + def my_func(x, n_jobs=None, device=None): + return "cpu_result" - def test_custom_gpu_func_name(self): - """Test using a custom GPU adapter function name.""" - mock_adapter = MagicMock(return_value="gpu_result") + with ( + patch("squidpy.settings._dispatch._resolve_device", return_value="gpu"), + patch("importlib.import_module", return_value=mock_module), + patch("squidpy.gr._gpu.GPU_PARAM_REGISTRY", registry), + ): + my_func(42, n_jobs=4, device="gpu") - @gpu_dispatch("custom_adapter_name") - def my_func(x, device=None): + # n_jobs should be filtered out + mock_adapter.assert_called_once_with(x=42) + + def test_gpu_only_params_error_on_cpu_if_non_default(self): + """Test GPU-only params raise error on CPU if non-default.""" + registry = { + "my_func": { + "cpu_only": {}, + "gpu_only": {"use_sparse": GpuParamSpec(True)}, + } + } + + @gpu_dispatch(gpu_module="test_module") + def my_func(x, use_sparse=True, device=None): return "cpu_result" - with patch("squidpy.settings._dispatch._resolve_device", 
return_value="gpu"): - with patch("squidpy.gr._gpu.custom_adapter_name", mock_adapter, create=True): - result = my_func(42, device="gpu") + with patch("squidpy.gr._gpu.GPU_PARAM_REGISTRY", registry): + # Default value works + assert my_func(42, use_sparse=True, device="cpu") == "cpu_result" - assert result == "gpu_result" - mock_adapter.assert_called_once_with(x=42) + # Non-default raises error + with pytest.raises(ValueError, match="use_sparse.*only supported on GPU"): + my_func(42, use_sparse=False, device="cpu") From 403f08f216bc9ad2777acb606ac18075352d82cd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 26 Jan 2026 16:55:18 +0000 Subject: [PATCH 23/68] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/squidpy/gr/_gpu.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/squidpy/gr/_gpu.py b/src/squidpy/gr/_gpu.py index c747cb10a..830b58a45 100644 --- a/src/squidpy/gr/_gpu.py +++ b/src/squidpy/gr/_gpu.py @@ -100,7 +100,9 @@ def check_gpu_params(func_name: str, **cpu_only_values: Any) -> None: if spec.validator: msg = spec.validator(value) elif value != spec.default: - msg = spec.message or f"{name}={value!r} is only supported on CPU. Use device='cpu' or remove this argument." + msg = ( + spec.message or f"{name}={value!r} is only supported on CPU. Use device='cpu' or remove this argument." 
+ ) else: msg = None From 3aed99ea4576c8f7704952efb52bbe8132c36cce Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Tue, 27 Jan 2026 09:28:51 +0100 Subject: [PATCH 24/68] save wip --- src/squidpy/gr/_gpu.py | 62 ++++++++++++++++++++++--------- src/squidpy/gr/_ppatterns.py | 5 ++- src/squidpy/settings/_dispatch.py | 17 ++++++--- tests/test_settings.py | 47 +++++++++++++---------- 4 files changed, 85 insertions(+), 46 deletions(-) diff --git a/src/squidpy/gr/_gpu.py b/src/squidpy/gr/_gpu.py index 830b58a45..3f9558d6d 100644 --- a/src/squidpy/gr/_gpu.py +++ b/src/squidpy/gr/_gpu.py @@ -10,7 +10,7 @@ from dataclasses import dataclass from typing import Any -__all__ = ["GPU_PARAM_REGISTRY", "GpuParamSpec", "check_gpu_params", "check_cpu_params"] +__all__ = ["GPU_PARAM_REGISTRY", "GpuParamSpec", "check_gpu_params", "check_cpu_params", "apply_defaults"] @dataclass @@ -75,19 +75,19 @@ def _attr_validator(value: Any) -> str | None: def check_gpu_params(func_name: str, **cpu_only_values: Any) -> None: - """Check CPU-only params on GPU, raise error if non-default. + """Check CPU-only params on GPU, raise error if user provided a value. Parameters ---------- func_name Name of the function in GPU_PARAM_REGISTRY. **cpu_only_values - CPU-only parameter values to check. + CPU-only parameter values to check. None means not provided by user. Raises ------ ValueError - If a CPU-only parameter has a non-default value on GPU. + If a CPU-only parameter was explicitly provided (not None) on GPU. 
""" registry = GPU_PARAM_REGISTRY.get(func_name, {"cpu_only": {}, "gpu_only": {}}) @@ -96,34 +96,35 @@ def check_gpu_params(func_name: str, **cpu_only_values: Any) -> None: continue value = cpu_only_values[name] - # Use custom validator if provided, else default behavior + # None means user didn't provide a value - that's fine + if value is None: + continue + + # Use custom validator if provided if spec.validator: msg = spec.validator(value) - elif value != spec.default: - msg = ( - spec.message or f"{name}={value!r} is only supported on CPU. Use device='cpu' or remove this argument." - ) + if msg: + raise ValueError(msg.format(name=name, value=value)) else: - msg = None - - if msg: + # User explicitly provided a value for a CPU-only param on GPU + msg = spec.message or f"{name}={value!r} is only supported on CPU. Use device='cpu' or remove this argument." raise ValueError(msg.format(name=name, value=value)) def check_cpu_params(func_name: str, **gpu_only_values: Any) -> None: - """Check GPU-only params on CPU, raise error if non-default. + """Check GPU-only params on CPU, raise error if user provided a value. Parameters ---------- func_name Name of the function in GPU_PARAM_REGISTRY. **gpu_only_values - GPU-only parameter values to check. + GPU-only parameter values to check. None means not provided by user. Raises ------ ValueError - If a GPU-only parameter has a non-default value on CPU. + If a GPU-only parameter was explicitly provided (not None) on CPU. """ registry = GPU_PARAM_REGISTRY.get(func_name, {"cpu_only": {}, "gpu_only": {}}) @@ -132,6 +133,31 @@ def check_cpu_params(func_name: str, **gpu_only_values: Any) -> None: continue value = gpu_only_values[name] - if value != spec.default: - msg = f"{name}={value!r} is only supported on GPU. Use device='gpu' or remove this argument." 
- raise ValueError(msg) + # None means user didn't provide a value - that's fine + if value is None: + continue + + # User explicitly provided a value for a GPU-only param on CPU + msg = f"{name}={value!r} is only supported on GPU. Use device='gpu' or remove this argument." + raise ValueError(msg) + + +def apply_defaults(func_name: str, args: dict[str, Any], target: str) -> None: + """Apply registry defaults for params that are None. + + Parameters + ---------- + func_name + Name of the function in GPU_PARAM_REGISTRY. + args + Arguments dict to modify in place. + target + Either 'cpu' or 'gpu' - which defaults to apply. + """ + registry = GPU_PARAM_REGISTRY.get(func_name, {"cpu_only": {}, "gpu_only": {}}) + + # Apply defaults for the target's own params + param_key = "cpu_only" if target == "cpu" else "gpu_only" + for name, spec in registry[param_key].items(): + if name in args and args[name] is None: + args[name] = spec.default diff --git a/src/squidpy/gr/_ppatterns.py b/src/squidpy/gr/_ppatterns.py index c9b718089..8921c1674 100644 --- a/src/squidpy/gr/_ppatterns.py +++ b/src/squidpy/gr/_ppatterns.py @@ -60,7 +60,7 @@ def spatial_autocorr( layer: str | None = None, seed: int | None = None, use_raw: bool = False, - use_sparse: bool = True, + use_sparse: bool | None = None, copy: bool = False, n_jobs: int | None = None, backend: str = "loky", @@ -110,7 +110,8 @@ def spatial_autocorr( Which attribute of :class:`~anndata.AnnData` to access. See ``genes`` parameter for more information. Ignored when ``device='gpu'``. use_sparse - If `True`, use sparse matrix representation for the input matrix. Only used when ``device='gpu'``. + If `True`, use sparse matrix representation for the input matrix. + Only used when ``device='gpu'``. Defaults to `True` on GPU. 
%(seed_device)s %(copy)s %(parallelize_device)s diff --git a/src/squidpy/settings/_dispatch.py b/src/squidpy/settings/_dispatch.py index f831594a0..3ce68d376 100644 --- a/src/squidpy/settings/_dispatch.py +++ b/src/squidpy/settings/_dispatch.py @@ -9,7 +9,7 @@ from collections.abc import Callable from typing import Any, Literal, TypeVar -from squidpy.gr._gpu import GPU_PARAM_REGISTRY, check_cpu_params, check_gpu_params +from squidpy.gr._gpu import GPU_PARAM_REGISTRY, apply_defaults, check_cpu_params, check_gpu_params from squidpy.settings._settings import settings __all__ = ["gpu_dispatch"] @@ -99,30 +99,35 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: # Handle **kwargs: unpack instead of passing as kwargs=dict extra_kwargs = all_args.pop("kwargs", {}) + # Get registry for this function registry = GPU_PARAM_REGISTRY.get(func_name, {"cpu_only": {}, "gpu_only": {}}) if _resolve_device(device) == "gpu": - # Collect CPU-only param values and check them (warn if non-default) + # Collect CPU-only param values and check them (error if user provided) cpu_only_values = {k: all_args.pop(k) for k in list(all_args) if k in registry["cpu_only"]} - cpu_only_values.update( - {k: extra_kwargs.pop(k) for k in list(extra_kwargs) if k in registry["cpu_only"]} - ) + cpu_only_values.update({k: extra_kwargs.pop(k) for k in list(extra_kwargs) if k in registry["cpu_only"]}) check_gpu_params(func_name, **cpu_only_values) + # Apply defaults for GPU-only params that are None + apply_defaults(func_name, all_args, "gpu") + # Import and call GPU function module = importlib.import_module(gpu_module) gpu_func = getattr(module, gpu_func_name or func_name) return gpu_func(**all_args, **extra_kwargs) - # CPU path: check gpu_only params (error if non-default), then filter them out + # CPU path: check gpu_only params (error if user provided), then filter them out gpu_only_values = {k: all_args.pop(k) for k in list(all_args) if k in registry["gpu_only"]} gpu_only_values.update({k: 
extra_kwargs.pop(k) for k in list(extra_kwargs) if k in registry["gpu_only"]}) check_cpu_params(func_name, **gpu_only_values) + # Apply defaults for CPU-only params that are None + apply_defaults(func_name, all_args, "cpu") + return func(**all_args, **extra_kwargs) return wrapper # type: ignore[return-value] diff --git a/tests/test_settings.py b/tests/test_settings.py index 8fa8ccbfc..be96e76f4 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -127,12 +127,12 @@ def documented_func(x, device=None): assert "Original docstring." in documented_func.__doc__ assert "GPU acceleration" in documented_func.__doc__ - def test_cpu_only_params_filtered_on_gpu(self, mock_gpu_module): - """Test CPU-only params are filtered out on GPU.""" + def test_cpu_only_params_error_on_gpu_if_provided(self, mock_gpu_module): + """Test CPU-only params raise error on GPU if user provided a value.""" mock_module, mock_adapter = mock_gpu_module registry = { "my_func": { - "cpu_only": {"n_jobs": GpuParamSpec(None)}, + "cpu_only": {"n_jobs": GpuParamSpec(1)}, "gpu_only": {}, } } @@ -141,18 +141,25 @@ def test_cpu_only_params_filtered_on_gpu(self, mock_gpu_module): def my_func(x, n_jobs=None, device=None): return "cpu_result" - with ( - patch("squidpy.settings._dispatch._resolve_device", return_value="gpu"), - patch("importlib.import_module", return_value=mock_module), - patch("squidpy.gr._gpu.GPU_PARAM_REGISTRY", registry), - ): - my_func(42, n_jobs=4, device="gpu") - - # n_jobs should be filtered out - mock_adapter.assert_called_once_with(x=42) - - def test_gpu_only_params_error_on_cpu_if_non_default(self): - """Test GPU-only params raise error on CPU if non-default.""" + with patch("squidpy.gr._gpu.GPU_PARAM_REGISTRY", registry): + # Not provided (None) - should work + with ( + patch("squidpy.settings._dispatch._resolve_device", return_value="gpu"), + patch("importlib.import_module", return_value=mock_module), + ): + my_func(42, device="gpu") + 
mock_adapter.assert_called_once_with(x=42) + + # Provided a value - should error + with pytest.raises(ValueError, match="n_jobs.*only supported on CPU"): + with ( + patch("squidpy.settings._dispatch._resolve_device", return_value="gpu"), + patch("importlib.import_module", return_value=mock_module), + ): + my_func(42, n_jobs=4, device="gpu") + + def test_gpu_only_params_error_on_cpu_if_provided(self): + """Test GPU-only params raise error on CPU if user provided a value.""" registry = { "my_func": { "cpu_only": {}, @@ -161,13 +168,13 @@ def test_gpu_only_params_error_on_cpu_if_non_default(self): } @gpu_dispatch(gpu_module="test_module") - def my_func(x, use_sparse=True, device=None): + def my_func(x, use_sparse=None, device=None): return "cpu_result" with patch("squidpy.gr._gpu.GPU_PARAM_REGISTRY", registry): - # Default value works - assert my_func(42, use_sparse=True, device="cpu") == "cpu_result" + # Not provided (None) - should work + assert my_func(42, device="cpu") == "cpu_result" - # Non-default raises error + # Provided a value - should error with pytest.raises(ValueError, match="use_sparse.*only supported on GPU"): - my_func(42, use_sparse=False, device="cpu") + my_func(42, use_sparse=True, device="cpu") From ab5af393389c2d4a352beebb6666fc828bf7dcdf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 27 Jan 2026 08:29:06 +0000 Subject: [PATCH 25/68] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/squidpy/gr/_gpu.py | 4 +++- src/squidpy/settings/_dispatch.py | 5 +++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/squidpy/gr/_gpu.py b/src/squidpy/gr/_gpu.py index 3f9558d6d..53d1d3584 100644 --- a/src/squidpy/gr/_gpu.py +++ b/src/squidpy/gr/_gpu.py @@ -107,7 +107,9 @@ def check_gpu_params(func_name: str, **cpu_only_values: Any) -> None: raise ValueError(msg.format(name=name, value=value)) else: # User explicitly 
provided a value for a CPU-only param on GPU - msg = spec.message or f"{name}={value!r} is only supported on CPU. Use device='cpu' or remove this argument." + msg = ( + spec.message or f"{name}={value!r} is only supported on CPU. Use device='cpu' or remove this argument." + ) raise ValueError(msg.format(name=name, value=value)) diff --git a/src/squidpy/settings/_dispatch.py b/src/squidpy/settings/_dispatch.py index 3ce68d376..d2d4450c2 100644 --- a/src/squidpy/settings/_dispatch.py +++ b/src/squidpy/settings/_dispatch.py @@ -99,14 +99,15 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: # Handle **kwargs: unpack instead of passing as kwargs=dict extra_kwargs = all_args.pop("kwargs", {}) - # Get registry for this function registry = GPU_PARAM_REGISTRY.get(func_name, {"cpu_only": {}, "gpu_only": {}}) if _resolve_device(device) == "gpu": # Collect CPU-only param values and check them (error if user provided) cpu_only_values = {k: all_args.pop(k) for k in list(all_args) if k in registry["cpu_only"]} - cpu_only_values.update({k: extra_kwargs.pop(k) for k in list(extra_kwargs) if k in registry["cpu_only"]}) + cpu_only_values.update( + {k: extra_kwargs.pop(k) for k in list(extra_kwargs) if k in registry["cpu_only"]} + ) check_gpu_params(func_name, **cpu_only_values) From c0aa5daa7ce7a0828383c173b0a8414903559c52 Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Tue, 27 Jan 2026 11:40:11 +0100 Subject: [PATCH 26/68] save wip --- src/squidpy/gr/_gpu.py | 12 +-- src/squidpy/gr/_ppatterns.py | 18 +++-- src/squidpy/settings/_dispatch.py | 27 ++++--- tests/test_settings.py | 129 ++++++++++++++++++++++++++++-- 4 files changed, 151 insertions(+), 35 deletions(-) diff --git a/src/squidpy/gr/_gpu.py b/src/squidpy/gr/_gpu.py index 53d1d3584..0319161d9 100644 --- a/src/squidpy/gr/_gpu.py +++ b/src/squidpy/gr/_gpu.py @@ -1,7 +1,7 @@ """GPU parameter registry for squidpy.gr functions. -Defines which parameters are CPU-only (ignored on GPU) and GPU-only (ignored on CPU). 
-The gpu_dispatch decorator uses this registry to automatically handle parameter filtering. +Defines which parameters are CPU-only (error on GPU if provided) and GPU-only (error on CPU if provided). +The gpu_dispatch decorator uses this registry to automatically handle parameter validation and filtering. """ from __future__ import annotations @@ -23,10 +23,10 @@ class GpuParamSpec: def _attr_validator(value: Any) -> str | None: - """Special validator for attr param - only warn if not 'X'.""" + """Validator for attr param - error if not 'X' on GPU.""" if value == "X": return None - return f"attr={value!r} is not supported on GPU, using attr='X'. Set device='cpu' to use other attributes." + return f"attr={value!r} is not supported on GPU. Set device='cpu' to use other attributes." # Common CPU-only param specs (reusable) @@ -38,8 +38,8 @@ def _attr_validator(value: Any) -> str | None: _SEED: dict[str, GpuParamSpec] = {"seed": GpuParamSpec(None)} # Registry: {func_name: {"cpu_only": {...}, "gpu_only": {...}}} -# - cpu_only: parameters ignored on GPU (warn if non-default, then filter out) -# - gpu_only: parameters ignored on CPU (error if non-default, pass through to GPU) +# - cpu_only: parameters only supported on CPU (error on GPU if user provided a non-None value) +# - gpu_only: parameters only supported on GPU (error on CPU if user provided a non-None value) GPU_PARAM_REGISTRY: dict[str, dict[str, dict[str, GpuParamSpec]]] = { "spatial_autocorr": { "cpu_only": { diff --git a/src/squidpy/gr/_ppatterns.py b/src/squidpy/gr/_ppatterns.py index 8921c1674..f8d2e0944 100644 --- a/src/squidpy/gr/_ppatterns.py +++ b/src/squidpy/gr/_ppatterns.py @@ -56,15 +56,15 @@ def spatial_autocorr( n_perms: int | None = None, two_tailed: bool = False, corr_method: str | None = "fdr_bh", - attr: Literal["obs", "X", "obsm"] = "X", + attr: Literal["obs", "X", "obsm"] | None = None, layer: str | None = None, seed: int | None = None, use_raw: bool = False, use_sparse: bool | None = None, 
copy: bool = False, n_jobs: int | None = None, - backend: str = "loky", - show_progress_bar: bool = True, + backend: str | None = None, + show_progress_bar: bool | None = None, device: Literal["cpu", "gpu"] | None = None, ) -> pd.DataFrame | None: """ @@ -141,6 +141,10 @@ def spatial_autocorr( adata = adata.table _assert_connectivity_key(adata, connectivity_key) + # Apply defaults for CPU-only params + if attr is None: + attr = "X" + def extract_X(adata: AnnData, genes: str | Sequence[str] | None) -> tuple[NDArrayA | spmatrix, Sequence[Any]]: if genes is None: if "highly_variable" in adata.var: @@ -360,8 +364,8 @@ def co_occurrence( copy: bool = False, n_splits: int | None = None, n_jobs: int | None = None, - backend: str = "loky", - show_progress_bar: bool = True, + backend: str | None = None, + show_progress_bar: bool | None = None, device: Literal["cpu", "gpu"] | None = None, ) -> tuple[NDArrayA, NDArrayA] | None: """ @@ -416,10 +420,8 @@ def co_occurrence( spatial_y = spatial[:, 1] # Compute co-occurrence probabilities using the fast numba routine. + start = logg.info(f"Calculating co-occurrence probabilities for `{len(interval)}` intervals") out = _co_occurrence_helper(spatial_x, spatial_y, interval, labs) - start = logg.info( - f"Calculating co-occurrence probabilities for `{len(interval)}` intervals using `{n_jobs}` core(s) and `{n_splits}` splits" - ) if copy: logg.info("Finish", time=start) diff --git a/src/squidpy/settings/_dispatch.py b/src/squidpy/settings/_dispatch.py index d2d4450c2..84876776d 100644 --- a/src/squidpy/settings/_dispatch.py +++ b/src/squidpy/settings/_dispatch.py @@ -31,16 +31,16 @@ def _resolve_device(device: Literal["auto", "cpu", "gpu"] | None) -> Literal["cp return "gpu" if settings.gpu_available() else "cpu" -def _make_gpu_note(func_name: str, indent: str = "") -> str: +def _make_gpu_note(func_name: str, gpu_module: str, indent: str = "") -> str: lines = [ ".. 
note::", " This function supports GPU acceleration via :doc:`rapids_singlecell `.", - f" See :func:`rapids_singlecell.gr.{func_name}` for the GPU implementation.", + f" See :func:`{gpu_module}.{func_name}` for the GPU implementation.", ] return "\n".join(indent + line for line in lines) -def _inject_gpu_note(doc: str | None, func_name: str) -> str | None: +def _inject_gpu_note(doc: str | None, func_name: str, gpu_module: str) -> str | None: """Inject GPU note into docstring before the Parameters section.""" if doc is None: return None @@ -49,12 +49,12 @@ def _inject_gpu_note(doc: str | None, func_name: str) -> str | None: match = re.search(r"\n([ \t]*)Parameters\s*\n\s*-+", doc) if match: indent = match.group(1) # Capture only the spaces/tabs before Parameters - gpu_note = _make_gpu_note(func_name, indent) + gpu_note = _make_gpu_note(func_name, gpu_module, indent) insert_pos = match.start() return doc[:insert_pos] + "\n\n" + gpu_note + "\n" + doc[insert_pos:] # Fallback: append at the end - return doc + "\n\n" + _make_gpu_note(func_name) + return doc + "\n\n" + _make_gpu_note(func_name, gpu_module) def gpu_dispatch( @@ -79,9 +79,10 @@ def gpu_dispatch( def decorator(func: F) -> F: func_name = func.__name__ + _gpu_func_name = gpu_func_name or func_name # Inject GPU note into docstring - func.__doc__ = _inject_gpu_note(func.__doc__, func_name) + func.__doc__ = _inject_gpu_note(func.__doc__, _gpu_func_name, gpu_module) @functools.wraps(func) def wrapper(*args: Any, **kwargs: Any) -> Any: @@ -96,8 +97,8 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: device = all_args.pop("device", None) - # Handle **kwargs: unpack instead of passing as kwargs=dict - extra_kwargs = all_args.pop("kwargs", {}) + # Handle **kwargs from function signature: unpack instead of passing as kwargs=dict + variadic_kwargs = all_args.pop("kwargs", {}) # Get registry for this function registry = GPU_PARAM_REGISTRY.get(func_name, {"cpu_only": {}, "gpu_only": {}}) @@ -106,7 +107,7 @@ def 
wrapper(*args: Any, **kwargs: Any) -> Any: # Collect CPU-only param values and check them (error if user provided) cpu_only_values = {k: all_args.pop(k) for k in list(all_args) if k in registry["cpu_only"]} cpu_only_values.update( - {k: extra_kwargs.pop(k) for k in list(extra_kwargs) if k in registry["cpu_only"]} + {k: variadic_kwargs.pop(k) for k in list(variadic_kwargs) if k in registry["cpu_only"]} ) check_gpu_params(func_name, **cpu_only_values) @@ -116,20 +117,20 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: # Import and call GPU function module = importlib.import_module(gpu_module) - gpu_func = getattr(module, gpu_func_name or func_name) + gpu_func = getattr(module, _gpu_func_name) - return gpu_func(**all_args, **extra_kwargs) + return gpu_func(**all_args, **variadic_kwargs) # CPU path: check gpu_only params (error if user provided), then filter them out gpu_only_values = {k: all_args.pop(k) for k in list(all_args) if k in registry["gpu_only"]} - gpu_only_values.update({k: extra_kwargs.pop(k) for k in list(extra_kwargs) if k in registry["gpu_only"]}) + gpu_only_values.update({k: variadic_kwargs.pop(k) for k in list(variadic_kwargs) if k in registry["gpu_only"]}) check_cpu_params(func_name, **gpu_only_values) # Apply defaults for CPU-only params that are None apply_defaults(func_name, all_args, "cpu") - return func(**all_args, **extra_kwargs) + return func(**all_args, **variadic_kwargs) return wrapper # type: ignore[return-value] diff --git a/tests/test_settings.py b/tests/test_settings.py index be96e76f4..4d783313a 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -137,11 +137,16 @@ def test_cpu_only_params_error_on_gpu_if_provided(self, mock_gpu_module): } } - @gpu_dispatch(gpu_module="test_module") - def my_func(x, n_jobs=None, device=None): - return "cpu_result" + # Need to patch in both modules: dispatch for registry lookup, _gpu for check functions + with ( + patch("squidpy.settings._dispatch.GPU_PARAM_REGISTRY", registry), + 
patch("squidpy.gr._gpu.GPU_PARAM_REGISTRY", registry), + ): + + @gpu_dispatch(gpu_module="test_module") + def my_func(x, n_jobs=None, device=None): + return "cpu_result" - with patch("squidpy.gr._gpu.GPU_PARAM_REGISTRY", registry): # Not provided (None) - should work with ( patch("squidpy.settings._dispatch._resolve_device", return_value="gpu"), @@ -167,14 +172,122 @@ def test_gpu_only_params_error_on_cpu_if_provided(self): } } - @gpu_dispatch(gpu_module="test_module") - def my_func(x, use_sparse=None, device=None): - return "cpu_result" + # Need to patch in both modules: dispatch for registry lookup, _gpu for check functions + with ( + patch("squidpy.settings._dispatch.GPU_PARAM_REGISTRY", registry), + patch("squidpy.gr._gpu.GPU_PARAM_REGISTRY", registry), + ): + + @gpu_dispatch(gpu_module="test_module") + def my_func(x, use_sparse=None, device=None): + return "cpu_result" - with patch("squidpy.gr._gpu.GPU_PARAM_REGISTRY", registry): # Not provided (None) - should work assert my_func(42, device="cpu") == "cpu_result" # Provided a value - should error with pytest.raises(ValueError, match="use_sparse.*only supported on GPU"): my_func(42, use_sparse=True, device="cpu") + + def test_function_not_in_registry_works(self): + """Test that functions not in registry work transparently.""" + calls = [] + + @gpu_dispatch() + def unregistered_func(x, device=None): + calls.append(x) + return x * 3 + + # Should work on CPU without issues + assert unregistered_func(10, device="cpu") == 30 + assert calls == [10] + + def test_gpu_silent_fallback_when_unavailable(self): + """Test GPU silently falls back to CPU when unavailable.""" + if settings.gpu_available(): + pytest.skip("GPU is available") + + calls = [] + + @gpu_dispatch() + def my_func(x, device=None): + calls.append(x) + return x + 1 + + # Should silently fall back to CPU + result = my_func(5, device="gpu") + assert result == 6 + assert calls == [5] + + def test_custom_validator_error(self, mock_gpu_module): + """Test custom 
validator raises appropriate error.""" + mock_module, mock_adapter = mock_gpu_module + + def my_validator(value): + if value != "allowed": + return f"value={value!r} is not allowed on GPU" + return None + + registry = { + "my_func": { + "cpu_only": {"custom_param": GpuParamSpec("allowed", validator=my_validator)}, + "gpu_only": {}, + } + } + + # Need to patch in both modules: dispatch for registry lookup, _gpu for check functions + with ( + patch("squidpy.settings._dispatch.GPU_PARAM_REGISTRY", registry), + patch("squidpy.gr._gpu.GPU_PARAM_REGISTRY", registry), + ): + + @gpu_dispatch(gpu_module="test_module") + def my_func(x, custom_param=None, device=None): + return "cpu_result" + + # Allowed value - should work + with ( + patch("squidpy.settings._dispatch._resolve_device", return_value="gpu"), + patch("importlib.import_module", return_value=mock_module), + ): + my_func(42, custom_param="allowed", device="gpu") + + # Not allowed value - should error + with pytest.raises(ValueError, match="value='bad' is not allowed on GPU"): + with ( + patch("squidpy.settings._dispatch._resolve_device", return_value="gpu"), + patch("importlib.import_module", return_value=mock_module), + ): + my_func(42, custom_param="bad", device="gpu") + + def test_docstring_uses_custom_gpu_module(self): + """Test that docstring GPU note uses the specified gpu_module.""" + + @gpu_dispatch(gpu_module="custom.module.path") + def my_func(x, device=None): + """My function. + + Parameters + ---------- + x + Input. + """ + return x + + assert "custom.module.path.my_func" in my_func.__doc__ + + def test_docstring_uses_custom_gpu_func_name(self): + """Test that docstring GPU note uses the specified gpu_func_name.""" + + @gpu_dispatch(gpu_module="some.module", gpu_func_name="different_name") + def my_func(x, device=None): + """My function. + + Parameters + ---------- + x + Input. 
+ """ + return x + + assert "some.module.different_name" in my_func.__doc__ From 8d03b5bc250dcb914ecbaf4807d20acb9419a43a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 27 Jan 2026 10:40:25 +0000 Subject: [PATCH 27/68] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/squidpy/settings/_dispatch.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/squidpy/settings/_dispatch.py b/src/squidpy/settings/_dispatch.py index 84876776d..5d67e9afa 100644 --- a/src/squidpy/settings/_dispatch.py +++ b/src/squidpy/settings/_dispatch.py @@ -123,7 +123,9 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: # CPU path: check gpu_only params (error if user provided), then filter them out gpu_only_values = {k: all_args.pop(k) for k in list(all_args) if k in registry["gpu_only"]} - gpu_only_values.update({k: variadic_kwargs.pop(k) for k in list(variadic_kwargs) if k in registry["gpu_only"]}) + gpu_only_values.update( + {k: variadic_kwargs.pop(k) for k in list(variadic_kwargs) if k in registry["gpu_only"]} + ) check_cpu_params(func_name, **gpu_only_values) From 08f7af64b964c12baf482ddcfeca3557550ad0fb Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Tue, 27 Jan 2026 11:53:40 +0100 Subject: [PATCH 28/68] remove kwargs from ligrec --- src/squidpy/gr/_gpu.py | 2 +- src/squidpy/gr/_ligrec.py | 69 ++++++++++++++++++++++++++++++++------- tests/test_settings.py | 14 +++----- 3 files changed, 63 insertions(+), 22 deletions(-) diff --git a/src/squidpy/gr/_gpu.py b/src/squidpy/gr/_gpu.py index 0319161d9..9556ba96c 100644 --- a/src/squidpy/gr/_gpu.py +++ b/src/squidpy/gr/_gpu.py @@ -130,7 +130,7 @@ def check_cpu_params(func_name: str, **gpu_only_values: Any) -> None: """ registry = GPU_PARAM_REGISTRY.get(func_name, {"cpu_only": {}, "gpu_only": {}}) - for name, spec in registry["gpu_only"].items(): + for name, _spec in 
registry["gpu_only"].items(): if name not in gpu_only_values: continue value = gpu_only_values[name] diff --git a/src/squidpy/gr/_ligrec.py b/src/squidpy/gr/_ligrec.py index 9f3dd0fdb..1c88a4b42 100644 --- a/src/squidpy/gr/_ligrec.py +++ b/src/squidpy/gr/_ligrec.py @@ -329,7 +329,9 @@ def test( copy: bool = False, key_added: str | None = None, numba_parallel: bool | None = None, - **kwargs: Any, + n_jobs: int | None = None, + backend: str | None = None, + show_progress_bar: bool | None = None, ) -> Mapping[str, pd.DataFrame] | None: """ Perform the permutation test as described in :cite:`cellphonedb`. @@ -411,10 +413,10 @@ def test( # much faster than applymap (tested on 1M interactions) interactions_ = np.vectorize(lambda g: gene_mapper[g])(interactions.values) - n_jobs = _get_n_cores(kwargs.pop("n_jobs", None)) + n_jobs_ = _get_n_cores(n_jobs) start = logg.info( f"Running `{n_perms}` permutations on `{len(interactions)}` interactions " - f"and `{len(clusters)}` cluster combinations using `{n_jobs}` core(s)" + f"and `{len(clusters)}` cluster combinations using `{n_jobs_}` core(s)" ) res = _analysis( data, @@ -423,9 +425,10 @@ def test( threshold=threshold, n_perms=n_perms, seed=seed, - n_jobs=n_jobs, + n_jobs=n_jobs_, numba_parallel=numba_parallel, - **kwargs, + backend=backend, + show_progress_bar=show_progress_bar, ) res = { "means": _create_sparse_df( @@ -648,7 +651,19 @@ def ligrec( key_added: str | None = None, gene_symbols: str | None = None, device: Literal["cpu", "gpu"] | None = None, - **kwargs: Any, + # prepare params + interactions_params: Mapping[str, Any] | None = None, + transmitter_params: Mapping[str, Any] | None = None, + receiver_params: Mapping[str, Any] | None = None, + # test params + clusters: Cluster_t | None = None, + n_perms: int | None = None, + seed: int | None = None, + alpha: float | None = None, + numba_parallel: bool | None = None, + n_jobs: int | None = None, + backend: str | None = None, + show_progress_bar: bool | None = None, ) 
-> Mapping[str, pd.DataFrame] | None: """ %(PT_test.full_desc)s @@ -668,18 +683,44 @@ def ligrec( """ # noqa: D400 if isinstance(adata, SpatialData): adata = adata.table + + # Apply defaults for params that don't accept None in prepare/test + if interactions_params is None: + interactions_params = MappingProxyType({}) + if transmitter_params is None: + transmitter_params = MappingProxyType({"categories": "ligand"}) + if receiver_params is None: + receiver_params = MappingProxyType({"categories": "receptor"}) + if n_perms is None: + n_perms = 1000 + if alpha is None: + alpha = 0.05 + with _genesymbols(adata, key=gene_symbols, use_raw=use_raw, make_unique=False): return ( # type: ignore[no-any-return] PermutationTest(adata, use_raw=use_raw) - .prepare(interactions, complex_policy=complex_policy, **kwargs) + .prepare( + interactions, + complex_policy=complex_policy, + interactions_params=interactions_params, + transmitter_params=transmitter_params, + receiver_params=receiver_params, + ) .test( cluster_key=cluster_key, + clusters=clusters, + n_perms=n_perms, threshold=threshold, + seed=seed, corr_method=corr_method, corr_axis=corr_axis, + alpha=alpha, copy=copy, key_added=key_added, - **kwargs, + numba_parallel=numba_parallel, + n_jobs=n_jobs, + backend=backend, + show_progress_bar=show_progress_bar, ) ) @@ -694,7 +735,8 @@ def _analysis( seed: int | None = None, n_jobs: int = 1, numba_parallel: bool | None = None, - **kwargs: Any, + backend: str | None = None, + show_progress_bar: bool | None = None, ) -> TempResult: """ Run the analysis as described in :cite:`cellphonedb`. @@ -717,8 +759,10 @@ def _analysis( Number of parallel jobs to launch. numba_parallel Whether to use :func:`numba.prange` or not. If `None`, it's determined automatically. - kwargs - Keyword arguments for :func:`squidpy._utils.parallelize`, such as ``n_jobs`` or ``backend``. + backend + Parallelization backend to use. + show_progress_bar + Whether to show the progress bar. 
Returns ------- @@ -761,7 +805,8 @@ def extractor(res: Sequence[TempResult]) -> TempResult: n_jobs=n_jobs, unit="permutation", extractor=extractor, - **kwargs, + backend=backend, + show_progress_bar=show_progress_bar, )( data, mean, diff --git a/tests/test_settings.py b/tests/test_settings.py index 4d783313a..9b92c5a6d 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -202,22 +202,18 @@ def unregistered_func(x, device=None): assert unregistered_func(10, device="cpu") == 30 assert calls == [10] - def test_gpu_silent_fallback_when_unavailable(self): - """Test GPU silently falls back to CPU when unavailable.""" + def test_gpu_errors_when_unavailable(self): + """Test GPU raises error when unavailable.""" if settings.gpu_available(): pytest.skip("GPU is available") - calls = [] - @gpu_dispatch() def my_func(x, device=None): - calls.append(x) return x + 1 - # Should silently fall back to CPU - result = my_func(5, device="gpu") - assert result == 6 - assert calls == [5] + # Should raise error when GPU requested but unavailable + with pytest.raises(RuntimeError, match="GPU unavailable"): + my_func(5, device="gpu") def test_custom_validator_error(self, mock_gpu_module): """Test custom validator raises appropriate error.""" From 48851953072899ec130f62f5e75bb36001148fdf Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Tue, 27 Jan 2026 12:10:23 +0100 Subject: [PATCH 29/68] refactor and formatting --- src/squidpy/__init__.py | 2 +- .../{settings => _settings}/__init__.py | 4 ++-- .../{settings => _settings}/_dispatch.py | 2 +- .../{settings => _settings}/_settings.py | 0 src/squidpy/gr/_ligrec.py | 2 +- src/squidpy/gr/_ppatterns.py | 6 +---- tests/test_gpu.py | 2 +- tests/test_settings.py | 22 +++++++++---------- 8 files changed, 18 insertions(+), 22 deletions(-) rename src/squidpy/{settings => _settings}/__init__.py (50%) rename src/squidpy/{settings => _settings}/_dispatch.py (99%) rename src/squidpy/{settings => _settings}/_settings.py (100%) diff 
--git a/src/squidpy/__init__.py b/src/squidpy/__init__.py index 7313b6f30..1aaa80053 100644 --- a/src/squidpy/__init__.py +++ b/src/squidpy/__init__.py @@ -4,7 +4,7 @@ from importlib.metadata import PackageMetadata from squidpy import datasets, experimental, gr, im, pl, read, tl -from squidpy.settings import settings +from squidpy._settings import settings try: md: PackageMetadata = metadata.metadata(__name__) diff --git a/src/squidpy/settings/__init__.py b/src/squidpy/_settings/__init__.py similarity index 50% rename from src/squidpy/settings/__init__.py rename to src/squidpy/_settings/__init__.py index aa1bf6afe..5c0dbd920 100644 --- a/src/squidpy/settings/__init__.py +++ b/src/squidpy/_settings/__init__.py @@ -2,7 +2,7 @@ from __future__ import annotations -from squidpy.settings._dispatch import gpu_dispatch -from squidpy.settings._settings import DeviceType, settings +from squidpy._settings._dispatch import gpu_dispatch +from squidpy._settings._settings import DeviceType, settings __all__ = ["settings", "DeviceType", "gpu_dispatch"] diff --git a/src/squidpy/settings/_dispatch.py b/src/squidpy/_settings/_dispatch.py similarity index 99% rename from src/squidpy/settings/_dispatch.py rename to src/squidpy/_settings/_dispatch.py index 5d67e9afa..b2dabedc3 100644 --- a/src/squidpy/settings/_dispatch.py +++ b/src/squidpy/_settings/_dispatch.py @@ -9,8 +9,8 @@ from collections.abc import Callable from typing import Any, Literal, TypeVar +from squidpy._settings._settings import settings from squidpy.gr._gpu import GPU_PARAM_REGISTRY, apply_defaults, check_cpu_params, check_gpu_params -from squidpy.settings._settings import settings __all__ = ["gpu_dispatch"] diff --git a/src/squidpy/settings/_settings.py b/src/squidpy/_settings/_settings.py similarity index 100% rename from src/squidpy/settings/_settings.py rename to src/squidpy/_settings/_settings.py diff --git a/src/squidpy/gr/_ligrec.py b/src/squidpy/gr/_ligrec.py index 1c88a4b42..d74269fe1 100644 --- 
a/src/squidpy/gr/_ligrec.py +++ b/src/squidpy/gr/_ligrec.py @@ -20,6 +20,7 @@ from squidpy._constants._constants import ComplexPolicy, CorrAxis from squidpy._constants._pkg_constants import Key from squidpy._docs import d, inject_docs +from squidpy._settings import gpu_dispatch from squidpy._utils import NDArrayA, Signal, SigQueue, _get_n_cores, parallelize from squidpy.gr._utils import ( _assert_categorical_obs, @@ -29,7 +30,6 @@ _genesymbols, _save_data, ) -from squidpy.settings import gpu_dispatch __all__ = ["ligrec", "PermutationTest"] diff --git a/src/squidpy/gr/_ppatterns.py b/src/squidpy/gr/_ppatterns.py index f8d2e0944..a6733c47d 100644 --- a/src/squidpy/gr/_ppatterns.py +++ b/src/squidpy/gr/_ppatterns.py @@ -23,6 +23,7 @@ from squidpy._constants._constants import SpatialAutocorr from squidpy._constants._pkg_constants import Key from squidpy._docs import d, inject_docs +from squidpy._settings import gpu_dispatch from squidpy._utils import NDArrayA, Signal, SigQueue, _get_n_cores, parallelize from squidpy.gr._utils import ( _assert_categorical_obs, @@ -31,7 +32,6 @@ _assert_spatial_basis, _save_data, ) -from squidpy.settings import gpu_dispatch __all__ = ["spatial_autocorr", "co_occurrence"] @@ -141,10 +141,6 @@ def spatial_autocorr( adata = adata.table _assert_connectivity_key(adata, connectivity_key) - # Apply defaults for CPU-only params - if attr is None: - attr = "X" - def extract_X(adata: AnnData, genes: str | Sequence[str] | None) -> tuple[NDArrayA | spmatrix, Sequence[Any]]: if genes is None: if "highly_variable" in adata.var: diff --git a/tests/test_gpu.py b/tests/test_gpu.py index 7afc866aa..6147e020b 100644 --- a/tests/test_gpu.py +++ b/tests/test_gpu.py @@ -6,7 +6,7 @@ import pytest import squidpy as sq -from squidpy.settings import settings +from squidpy._settings import settings # Skip all tests in this module if GPU is not available pytestmark = pytest.mark.skipif( diff --git a/tests/test_settings.py b/tests/test_settings.py index 
9b92c5a6d..0422f25fc 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -1,4 +1,4 @@ -"""Tests for squidpy.settings module.""" +"""Tests for squidpy._settings module.""" from __future__ import annotations @@ -6,8 +6,8 @@ import pytest +from squidpy._settings import gpu_dispatch, settings from squidpy.gr._gpu import GpuParamSpec -from squidpy.settings import gpu_dispatch, settings class TestSettings: @@ -83,7 +83,7 @@ def my_func(x, device=None): return "cpu_result" with ( - patch("squidpy.settings._dispatch._resolve_device", return_value="gpu"), + patch("squidpy._settings._dispatch._resolve_device", return_value="gpu"), patch("importlib.import_module", return_value=mock_module), patch("squidpy.gr._gpu.GPU_PARAM_REGISTRY", {"my_func": {"cpu_only": {}, "gpu_only": {}}}), ): @@ -101,7 +101,7 @@ def my_func(x, device=None): return "cpu_result" with ( - patch("squidpy.settings._dispatch._resolve_device", return_value="gpu"), + patch("squidpy._settings._dispatch._resolve_device", return_value="gpu"), patch("importlib.import_module", return_value=mock_module), patch("squidpy.gr._gpu.GPU_PARAM_REGISTRY", {"my_func": {"cpu_only": {}, "gpu_only": {}}}), ): @@ -139,7 +139,7 @@ def test_cpu_only_params_error_on_gpu_if_provided(self, mock_gpu_module): # Need to patch in both modules: dispatch for registry lookup, _gpu for check functions with ( - patch("squidpy.settings._dispatch.GPU_PARAM_REGISTRY", registry), + patch("squidpy._settings._dispatch.GPU_PARAM_REGISTRY", registry), patch("squidpy.gr._gpu.GPU_PARAM_REGISTRY", registry), ): @@ -149,7 +149,7 @@ def my_func(x, n_jobs=None, device=None): # Not provided (None) - should work with ( - patch("squidpy.settings._dispatch._resolve_device", return_value="gpu"), + patch("squidpy._settings._dispatch._resolve_device", return_value="gpu"), patch("importlib.import_module", return_value=mock_module), ): my_func(42, device="gpu") @@ -158,7 +158,7 @@ def my_func(x, n_jobs=None, device=None): # Provided a value - 
should error with pytest.raises(ValueError, match="n_jobs.*only supported on CPU"): with ( - patch("squidpy.settings._dispatch._resolve_device", return_value="gpu"), + patch("squidpy._settings._dispatch._resolve_device", return_value="gpu"), patch("importlib.import_module", return_value=mock_module), ): my_func(42, n_jobs=4, device="gpu") @@ -174,7 +174,7 @@ def test_gpu_only_params_error_on_cpu_if_provided(self): # Need to patch in both modules: dispatch for registry lookup, _gpu for check functions with ( - patch("squidpy.settings._dispatch.GPU_PARAM_REGISTRY", registry), + patch("squidpy._settings._dispatch.GPU_PARAM_REGISTRY", registry), patch("squidpy.gr._gpu.GPU_PARAM_REGISTRY", registry), ): @@ -233,7 +233,7 @@ def my_validator(value): # Need to patch in both modules: dispatch for registry lookup, _gpu for check functions with ( - patch("squidpy.settings._dispatch.GPU_PARAM_REGISTRY", registry), + patch("squidpy._settings._dispatch.GPU_PARAM_REGISTRY", registry), patch("squidpy.gr._gpu.GPU_PARAM_REGISTRY", registry), ): @@ -243,7 +243,7 @@ def my_func(x, custom_param=None, device=None): # Allowed value - should work with ( - patch("squidpy.settings._dispatch._resolve_device", return_value="gpu"), + patch("squidpy._settings._dispatch._resolve_device", return_value="gpu"), patch("importlib.import_module", return_value=mock_module), ): my_func(42, custom_param="allowed", device="gpu") @@ -251,7 +251,7 @@ def my_func(x, custom_param=None, device=None): # Not allowed value - should error with pytest.raises(ValueError, match="value='bad' is not allowed on GPU"): with ( - patch("squidpy.settings._dispatch._resolve_device", return_value="gpu"), + patch("squidpy._settings._dispatch._resolve_device", return_value="gpu"), patch("importlib.import_module", return_value=mock_module), ): my_func(42, custom_param="bad", device="gpu") From ebb03c24d10a178134a6937656e2d0e374b931f9 Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Tue, 27 Jan 2026 13:47:50 +0100 Subject: 
[PATCH 30/68] arg refactor --- src/squidpy/gr/_gpu.py | 5 ----- src/squidpy/gr/_ligrec.py | 24 ++++++------------------ 2 files changed, 6 insertions(+), 23 deletions(-) diff --git a/src/squidpy/gr/_gpu.py b/src/squidpy/gr/_gpu.py index 9556ba96c..0108123fb 100644 --- a/src/squidpy/gr/_gpu.py +++ b/src/squidpy/gr/_gpu.py @@ -60,12 +60,7 @@ def _attr_validator(value: Any) -> str | None: }, "ligrec": { "cpu_only": { - "clusters": GpuParamSpec(None), "numba_parallel": GpuParamSpec(None), - "transmitter_params": GpuParamSpec(None), - "receiver_params": GpuParamSpec(None), - "interactions_params": GpuParamSpec(None), - "alpha": GpuParamSpec(0.05), **_SEED, **_PARALLELIZE, }, diff --git a/src/squidpy/gr/_ligrec.py b/src/squidpy/gr/_ligrec.py index d74269fe1..c7a07b152 100644 --- a/src/squidpy/gr/_ligrec.py +++ b/src/squidpy/gr/_ligrec.py @@ -583,7 +583,6 @@ def prepare( interactions_params: Mapping[str, Any] = MappingProxyType({}), transmitter_params: Mapping[str, Any] = MappingProxyType({"categories": "ligand"}), receiver_params: Mapping[str, Any] = MappingProxyType({"categories": "receptor"}), - **_: Any, ) -> PermutationTest: """ %(PT_prepare.full_desc)s @@ -652,15 +651,16 @@ def ligrec( gene_symbols: str | None = None, device: Literal["cpu", "gpu"] | None = None, # prepare params - interactions_params: Mapping[str, Any] | None = None, - transmitter_params: Mapping[str, Any] | None = None, - receiver_params: Mapping[str, Any] | None = None, + interactions_params: Mapping[str, Any] = MappingProxyType({}), + transmitter_params: Mapping[str, Any] = MappingProxyType({"categories": "ligand"}), + receiver_params: Mapping[str, Any] = MappingProxyType({"categories": "receptor"}), # test params clusters: Cluster_t | None = None, - n_perms: int | None = None, + n_perms: int = 1000, seed: int | None = None, - alpha: float | None = None, + alpha: float = 0.05, numba_parallel: bool | None = None, + # CPU-only params (must be None to allow dispatch to detect if user provided) 
n_jobs: int | None = None, backend: str | None = None, show_progress_bar: bool | None = None, @@ -684,18 +684,6 @@ def ligrec( if isinstance(adata, SpatialData): adata = adata.table - # Apply defaults for params that don't accept None in prepare/test - if interactions_params is None: - interactions_params = MappingProxyType({}) - if transmitter_params is None: - transmitter_params = MappingProxyType({"categories": "ligand"}) - if receiver_params is None: - receiver_params = MappingProxyType({"categories": "receptor"}) - if n_perms is None: - n_perms = 1000 - if alpha is None: - alpha = 0.05 - with _genesymbols(adata, key=gene_symbols, use_raw=use_raw, make_unique=False): return ( # type: ignore[no-any-return] PermutationTest(adata, use_raw=use_raw) From eb58e0e33a51f3a4f28c2c54ec3300c1619939f8 Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Tue, 27 Jan 2026 13:49:26 +0100 Subject: [PATCH 31/68] remove n_splits --- src/squidpy/gr/_gpu.py | 1 - src/squidpy/gr/_ppatterns.py | 4 ---- 2 files changed, 5 deletions(-) diff --git a/src/squidpy/gr/_gpu.py b/src/squidpy/gr/_gpu.py index 0108123fb..1c2a8679c 100644 --- a/src/squidpy/gr/_gpu.py +++ b/src/squidpy/gr/_gpu.py @@ -53,7 +53,6 @@ def _attr_validator(value: Any) -> str | None: }, "co_occurrence": { "cpu_only": { - "n_splits": GpuParamSpec(None), **_PARALLELIZE, }, "gpu_only": {}, diff --git a/src/squidpy/gr/_ppatterns.py b/src/squidpy/gr/_ppatterns.py index a6733c47d..451273be6 100644 --- a/src/squidpy/gr/_ppatterns.py +++ b/src/squidpy/gr/_ppatterns.py @@ -358,7 +358,6 @@ def co_occurrence( spatial_key: str = Key.obsm.spatial, interval: int | NDArrayA = 50, copy: bool = False, - n_splits: int | None = None, n_jobs: int | None = None, backend: str | None = None, show_progress_bar: bool | None = None, @@ -376,9 +375,6 @@ def co_occurrence( Distances interval at which co-occurrence is computed. If :class:`int`, uniformly spaced interval of the given size will be used. 
%(copy)s - n_splits - Number of splits in which to divide the spatial coordinates in - :attr:`anndata.AnnData.obsm` ``['{spatial_key}']``. Ignored when ``device='gpu'``. %(parallelize_device)s %(device)s From 3568694bb9530fe36686fc7211665305f0750cc0 Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Tue, 27 Jan 2026 14:02:59 +0100 Subject: [PATCH 32/68] update parallelize args --- src/squidpy/_docs.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/squidpy/_docs.py b/src/squidpy/_docs.py index 633be2124..f54d67d4a 100644 --- a/src/squidpy/_docs.py +++ b/src/squidpy/_docs.py @@ -121,11 +121,19 @@ def decorator2(obj: Any) -> Any: Whether to show the progress bar or not.""" _parallelize_device = """\ n_jobs - Number of parallel jobs. Ignored when ``device='gpu'``. + Number of parallel jobs to use. Ignored when ``device='gpu'``. + For ``backend="loky"``, the number of cores used by numba for + each job spawned by the backend will be set to 1 in order to + overcome the oversubscription issue in case you run + numba in your function to parallelize. + To set the absolute maximum number of threads in numba + for your python program, set the environment variable: + ``NUMBA_NUM_THREADS`` before running the program. backend - Parallelization backend. Ignored when ``device='gpu'``. + Parallelization backend to use. See :class:`joblib.Parallel` for available options. + Ignored when ``device='gpu'``. show_progress_bar - Whether to show the progress bar. Ignored when ``device='gpu'``.""" + Whether to show the progress bar or not. Ignored when ``device='gpu'``.""" _seed_device = """\ seed Random seed for reproducibility. 
Ignored when ``device='gpu'``.""" From 7b2aad367449a9b8095b9fa3858b4a4615546433 Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Mon, 2 Feb 2026 10:24:53 +0100 Subject: [PATCH 33/68] apply suggestion --- src/squidpy/_settings/_dispatch.py | 84 ++++++++------ src/squidpy/gr/_gpu.py | 168 ++++++++++----------------- tests/test_settings.py | 179 ++++++++++++++--------------- 3 files changed, 200 insertions(+), 231 deletions(-) diff --git a/src/squidpy/_settings/_dispatch.py b/src/squidpy/_settings/_dispatch.py index b2dabedc3..fd2bd3944 100644 --- a/src/squidpy/_settings/_dispatch.py +++ b/src/squidpy/_settings/_dispatch.py @@ -10,7 +10,7 @@ from typing import Any, Literal, TypeVar from squidpy._settings._settings import settings -from squidpy.gr._gpu import GPU_PARAM_REGISTRY, apply_defaults, check_cpu_params, check_gpu_params +from squidpy.gr._gpu import check_exclusive_params, get_exclusive_params __all__ = ["gpu_dispatch"] @@ -63,9 +63,11 @@ def gpu_dispatch( ) -> Callable[[F], F]: """Decorator to dispatch to GPU implementation when device='gpu'. - Uses the GPU_PARAM_REGISTRY from squidpy.gr._gpu to: - - Warn about CPU-only parameters that differ from defaults, then filter them out - - Filter out GPU-only parameters on CPU (they only affect GPU) + Automatically determines CPU-only and GPU-only parameters by comparing + function signatures. Errors if user explicitly provides a value for + an exclusive parameter on the wrong device. GPU-only + parameters are also present in the CPU signature but only to + provide a way for the user to pass the parameter to the GPU function. Also injects a GPU note into the function's docstring. 
@@ -84,6 +86,18 @@ def decorator(func: F) -> F: # Inject GPU note into docstring func.__doc__ = _inject_gpu_note(func.__doc__, _gpu_func_name, gpu_module) + # Cache for exclusive params (computed lazily on first GPU call) + _exclusive_params_cache: dict[str, Any] | None = None + + def _get_exclusive_params() -> tuple[dict[str, Any], dict[str, Any]]: + nonlocal _exclusive_params_cache + if _exclusive_params_cache is None: + module = importlib.import_module(gpu_module) + gpu_func = getattr(module, _gpu_func_name) + cpu_only, gpu_only = get_exclusive_params(func, gpu_func) + _exclusive_params_cache = {"cpu_only": cpu_only, "gpu_only": gpu_only} + return _exclusive_params_cache["cpu_only"], _exclusive_params_cache["gpu_only"] + @functools.wraps(func) def wrapper(*args: Any, **kwargs: Any) -> Any: sig = inspect.signature(func) @@ -92,47 +106,49 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: except TypeError: return func(*args, **kwargs) + # Track what user actually provided (before defaults) + user_provided = set(bound.arguments.keys()) + bound.apply_defaults() all_args = dict(bound.arguments) device = all_args.pop("device", None) + resolved_device = _resolve_device(device) - # Handle **kwargs from function signature: unpack instead of passing as kwargs=dict - variadic_kwargs = all_args.pop("kwargs", {}) - - # Get registry for this function - registry = GPU_PARAM_REGISTRY.get(func_name, {"cpu_only": {}, "gpu_only": {}}) + if resolved_device == "gpu": + cpu_only_params, gpu_only_params = _get_exclusive_params() - if _resolve_device(device) == "gpu": - # Collect CPU-only param values and check them (error if user provided) - cpu_only_values = {k: all_args.pop(k) for k in list(all_args) if k in registry["cpu_only"]} - cpu_only_values.update( - {k: variadic_kwargs.pop(k) for k in list(variadic_kwargs) if k in registry["cpu_only"]} - ) + # Check if user explicitly provided any CPU-only params + user_provided_cpu_only = user_provided & cpu_only_params + if 
user_provided_cpu_only: + check_exclusive_params(func_name, user_provided_cpu_only, all_args, "gpu") - check_gpu_params(func_name, **cpu_only_values) + # Remove CPU-only params before calling GPU func + for k in cpu_only_params: + all_args.pop(k, None) - # Apply defaults for GPU-only params that are None - apply_defaults(func_name, all_args, "gpu") - - # Import and call GPU function module = importlib.import_module(gpu_module) gpu_func = getattr(module, _gpu_func_name) + return gpu_func(**all_args) - return gpu_func(**all_args, **variadic_kwargs) - - # CPU path: check gpu_only params (error if user provided), then filter them out - gpu_only_values = {k: all_args.pop(k) for k in list(all_args) if k in registry["gpu_only"]} - gpu_only_values.update( - {k: variadic_kwargs.pop(k) for k in list(variadic_kwargs) if k in registry["gpu_only"]} - ) - - check_cpu_params(func_name, **gpu_only_values) - - # Apply defaults for CPU-only params that are None - apply_defaults(func_name, all_args, "cpu") - - return func(**all_args, **variadic_kwargs) + # CPU path + try: + cpu_only_params, gpu_only_params = _get_exclusive_params() + except (ImportError, AttributeError): + # GPU module not available, just run CPU function + return func(**all_args) + + # Check if user explicitly provided any GPU-only params + user_provided_gpu_only = user_provided & gpu_only_params + if user_provided_gpu_only: + gpu_only_values = {k: all_args[k] for k in user_provided_gpu_only} + check_exclusive_params(func_name, user_provided_gpu_only, gpu_only_values, "cpu") + + # Remove GPU-only params before calling CPU func + for k in gpu_only_params: + all_args.pop(k, None) + + return func(**all_args) return wrapper # type: ignore[return-value] diff --git a/src/squidpy/gr/_gpu.py b/src/squidpy/gr/_gpu.py index 1c2a8679c..ccc101332 100644 --- a/src/squidpy/gr/_gpu.py +++ b/src/squidpy/gr/_gpu.py @@ -1,25 +1,24 @@ -"""GPU parameter registry for squidpy.gr functions. 
+"""GPU parameter handling for squidpy.gr functions. -Defines which parameters are CPU-only (error on GPU if provided) and GPU-only (error on CPU if provided). -The gpu_dispatch decorator uses this registry to automatically handle parameter validation and filtering. +Automatically determines CPU-only and GPU-only parameters by introspecting function signatures. +Only special cases (custom validators) need explicit registry entries. """ from __future__ import annotations +import inspect from collections.abc import Callable from dataclasses import dataclass from typing import Any -__all__ = ["GPU_PARAM_REGISTRY", "GpuParamSpec", "check_gpu_params", "check_cpu_params", "apply_defaults"] +__all__ = ["SPECIAL_PARAM_REGISTRY", "ParamSpec", "check_exclusive_params", "get_exclusive_params"] @dataclass -class GpuParamSpec: - """Specification for a parameter's GPU compatibility.""" +class ParamSpec: + """Specification for a parameter with custom validation.""" - default: Any - message: str | None = None - validator: Callable[[Any], str | None] | None = None + validate_fn: Callable[[Any], str | None] def _attr_validator(value: Any) -> str | None: @@ -29,131 +28,88 @@ def _attr_validator(value: Any) -> str | None: return f"attr={value!r} is not supported on GPU. Set device='cpu' to use other attributes." 
-# Common CPU-only param specs (reusable) -_PARALLELIZE: dict[str, GpuParamSpec] = { - "n_jobs": GpuParamSpec(None), - "backend": GpuParamSpec("loky"), - "show_progress_bar": GpuParamSpec(True), -} -_SEED: dict[str, GpuParamSpec] = {"seed": GpuParamSpec(None)} - -# Registry: {func_name: {"cpu_only": {...}, "gpu_only": {...}}} -# - cpu_only: parameters only supported on CPU (error on GPU if user provided a non-None value) -# - gpu_only: parameters only supported on GPU (error on CPU if user provided a non-None value) -GPU_PARAM_REGISTRY: dict[str, dict[str, dict[str, GpuParamSpec]]] = { +# Minimal registry: only for params that need custom validators +# Format: {func_name: {"cpu_only": {param: ParamSpec}, "gpu_only": {param: ParamSpec}}} +SPECIAL_PARAM_REGISTRY: dict[str, dict[str, dict[str, ParamSpec]]] = { "spatial_autocorr": { "cpu_only": { - "attr": GpuParamSpec("X", validator=_attr_validator), - **_SEED, - **_PARALLELIZE, - }, - "gpu_only": { - "use_sparse": GpuParamSpec(True), - }, - }, - "co_occurrence": { - "cpu_only": { - **_PARALLELIZE, - }, - "gpu_only": {}, - }, - "ligrec": { - "cpu_only": { - "numba_parallel": GpuParamSpec(None), - **_SEED, - **_PARALLELIZE, + "attr": ParamSpec(validate_fn=_attr_validator), }, "gpu_only": {}, }, } -def check_gpu_params(func_name: str, **cpu_only_values: Any) -> None: - """Check CPU-only params on GPU, raise error if user provided a value. +def get_exclusive_params( + cpu_func: Callable[..., Any], gpu_func: Callable[..., Any] +) -> tuple[set[str], set[str]]: + """Get CPU-only and GPU-only params by comparing function signatures. Parameters ---------- - func_name - Name of the function in GPU_PARAM_REGISTRY. - **cpu_only_values - CPU-only parameter values to check. None means not provided by user. - - Raises - ------ - ValueError - If a CPU-only parameter was explicitly provided (not None) on GPU. + cpu_func + The CPU implementation function. + gpu_func + The GPU implementation function. 
+ + Returns + ------- + Tuple of (cpu_only_params, gpu_only_params) as sets of param names. """ - registry = GPU_PARAM_REGISTRY.get(func_name, {"cpu_only": {}, "gpu_only": {}}) + cpu_sig = inspect.signature(cpu_func) + gpu_sig = inspect.signature(gpu_func) - for name, spec in registry["cpu_only"].items(): - if name not in cpu_only_values: - continue - value = cpu_only_values[name] + cpu_params = set(cpu_sig.parameters.keys()) + gpu_params = set(gpu_sig.parameters.keys()) - # None means user didn't provide a value - that's fine - if value is None: - continue + # CPU-only: in CPU sig but not in GPU sig (excluding 'device' which is handled separately) + cpu_only = cpu_params - gpu_params - {"device"} - # Use custom validator if provided - if spec.validator: - msg = spec.validator(value) - if msg: - raise ValueError(msg.format(name=name, value=value)) - else: - # User explicitly provided a value for a CPU-only param on GPU - msg = ( - spec.message or f"{name}={value!r} is only supported on CPU. Use device='cpu' or remove this argument." - ) - raise ValueError(msg.format(name=name, value=value)) + # GPU-only: in GPU sig but not in CPU sig + gpu_only = gpu_params - cpu_params + + return cpu_only, gpu_only -def check_cpu_params(func_name: str, **gpu_only_values: Any) -> None: - """Check GPU-only params on CPU, raise error if user provided a value. +def check_exclusive_params( + func_name: str, + user_provided_exclusive: set[str], + param_values: dict[str, Any], + target_device: str, +) -> None: + """Check exclusive params, raise error if user explicitly provided any. Parameters ---------- func_name - Name of the function in GPU_PARAM_REGISTRY. - **gpu_only_values - GPU-only parameter values to check. None means not provided by user. + Name of the function (for registry lookup). + user_provided_exclusive + Set of param names that user explicitly provided AND are exclusive to other device. + param_values + All argument values (for error messages and custom validators). 
+ target_device + The device being used ('cpu' or 'gpu'). Raises ------ ValueError - If a GPU-only parameter was explicitly provided (not None) on CPU. + If user explicitly provided an exclusive parameter. """ - registry = GPU_PARAM_REGISTRY.get(func_name, {"cpu_only": {}, "gpu_only": {}}) + other_device = "gpu" if target_device == "cpu" else "cpu" + registry_key = "gpu_only" if target_device == "cpu" else "cpu_only" + registry = SPECIAL_PARAM_REGISTRY.get(func_name, {"cpu_only": {}, "gpu_only": {}}) - for name, _spec in registry["gpu_only"].items(): - if name not in gpu_only_values: - continue - value = gpu_only_values[name] + for name in user_provided_exclusive: + value = param_values.get(name) - # None means user didn't provide a value - that's fine - if value is None: + # Check special validate_fn first (they may allow certain values) + if name in registry[registry_key]: + spec = registry[registry_key][name] + msg = spec.validate_fn(value) + if msg: + raise ValueError(msg) continue - # User explicitly provided a value for a GPU-only param on CPU - msg = f"{name}={value!r} is only supported on GPU. Use device='gpu' or remove this argument." + # User explicitly provided an exclusive param - error + msg = f"{name}={value!r} is only supported on {other_device.upper()}. Use device={other_device!r} or remove this argument." raise ValueError(msg) - - -def apply_defaults(func_name: str, args: dict[str, Any], target: str) -> None: - """Apply registry defaults for params that are None. - - Parameters - ---------- - func_name - Name of the function in GPU_PARAM_REGISTRY. - args - Arguments dict to modify in place. - target - Either 'cpu' or 'gpu' - which defaults to apply. 
- """ - registry = GPU_PARAM_REGISTRY.get(func_name, {"cpu_only": {}, "gpu_only": {}}) - - # Apply defaults for the target's own params - param_key = "cpu_only" if target == "cpu" else "gpu_only" - for name, spec in registry[param_key].items(): - if name in args and args[name] is None: - args[name] = spec.default diff --git a/tests/test_settings.py b/tests/test_settings.py index 0422f25fc..7281310de 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -7,7 +7,7 @@ import pytest from squidpy._settings import gpu_dispatch, settings -from squidpy.gr._gpu import GpuParamSpec +from squidpy.gr._gpu import SPECIAL_PARAM_REGISTRY, ParamSpec class TestSettings: @@ -39,14 +39,6 @@ def test_set_device_gpu_without_rsc(self): class TestGpuDispatch: """Test the gpu_dispatch decorator.""" - @pytest.fixture - def mock_gpu_module(self): - """Create a mock GPU module with adapter function.""" - mock_adapter = MagicMock(return_value="gpu_result") - mock_module = MagicMock() - mock_module.my_func = mock_adapter - return mock_module, mock_adapter - def test_cpu_path(self): """Test CPU device calls original function.""" calls = [] @@ -74,9 +66,15 @@ def my_func(x, device=None): assert my_func(5, device="auto") == 10 assert calls == [5] - def test_gpu_path(self, mock_gpu_module): + def test_gpu_path(self): """Test GPU device dispatches to GPU module.""" - mock_module, mock_adapter = mock_gpu_module + mock_module = MagicMock() + + # Must use real function for signature introspection + def gpu_my_func(x): + return "gpu_result" + + mock_module.my_func = gpu_my_func @gpu_dispatch(gpu_module="test_module") def my_func(x, device=None): @@ -85,16 +83,18 @@ def my_func(x, device=None): with ( patch("squidpy._settings._dispatch._resolve_device", return_value="gpu"), patch("importlib.import_module", return_value=mock_module), - patch("squidpy.gr._gpu.GPU_PARAM_REGISTRY", {"my_func": {"cpu_only": {}, "gpu_only": {}}}), ): assert my_func(42, device="gpu") == "gpu_result" - 
mock_adapter.assert_called_once_with(x=42) - - def test_custom_gpu_func_name(self, mock_gpu_module): + def test_custom_gpu_func_name(self): """Test custom GPU function name.""" - mock_module, mock_adapter = mock_gpu_module - mock_module.custom_name = mock_adapter + mock_module = MagicMock() + + # Must use real function for signature introspection + def custom_name(x): + return "gpu_result" + + mock_module.custom_name = custom_name @gpu_dispatch(gpu_module="test_module", gpu_func_name="custom_name") def my_func(x, device=None): @@ -103,12 +103,9 @@ def my_func(x, device=None): with ( patch("squidpy._settings._dispatch._resolve_device", return_value="gpu"), patch("importlib.import_module", return_value=mock_module), - patch("squidpy.gr._gpu.GPU_PARAM_REGISTRY", {"my_func": {"cpu_only": {}, "gpu_only": {}}}), ): assert my_func(42, device="gpu") == "gpu_result" - mock_adapter.assert_called_once_with(x=42) - def test_preserves_function_metadata(self): """Test decorator preserves function name and injects GPU note.""" @@ -127,79 +124,81 @@ def documented_func(x, device=None): assert "Original docstring." 
in documented_func.__doc__ assert "GPU acceleration" in documented_func.__doc__ - def test_cpu_only_params_error_on_gpu_if_provided(self, mock_gpu_module): - """Test CPU-only params raise error on GPU if user provided a value.""" - mock_module, mock_adapter = mock_gpu_module - registry = { - "my_func": { - "cpu_only": {"n_jobs": GpuParamSpec(1)}, - "gpu_only": {}, - } - } + def test_cpu_only_params_error_on_gpu_if_provided(self): + """Test CPU-only params raise error on GPU if user explicitly provided them.""" + mock_module = MagicMock() + + # GPU function without n_jobs param (CPU-only) + def gpu_my_func(x): + return "gpu_result" + + mock_module.my_func = gpu_my_func + + @gpu_dispatch(gpu_module="test_module") + def my_func(x, n_jobs=1, device=None): + return "cpu_result" - # Need to patch in both modules: dispatch for registry lookup, _gpu for check functions with ( - patch("squidpy._settings._dispatch.GPU_PARAM_REGISTRY", registry), - patch("squidpy.gr._gpu.GPU_PARAM_REGISTRY", registry), + patch("squidpy._settings._dispatch._resolve_device", return_value="gpu"), + patch("importlib.import_module", return_value=mock_module), ): + # Not provided - should work + assert my_func(42, device="gpu") == "gpu_result" - @gpu_dispatch(gpu_module="test_module") - def my_func(x, n_jobs=None, device=None): - return "cpu_result" - - # Not provided (None) - should work - with ( - patch("squidpy._settings._dispatch._resolve_device", return_value="gpu"), - patch("importlib.import_module", return_value=mock_module), - ): - my_func(42, device="gpu") - mock_adapter.assert_called_once_with(x=42) + # Explicitly provided (even if same as default) - should error + with pytest.raises(ValueError, match="n_jobs.*only supported on CPU"): + my_func(42, n_jobs=1, device="gpu") - # Provided a value - should error + # Explicitly provided with different value - should also error with pytest.raises(ValueError, match="n_jobs.*only supported on CPU"): - with ( - 
patch("squidpy._settings._dispatch._resolve_device", return_value="gpu"), - patch("importlib.import_module", return_value=mock_module), - ): - my_func(42, n_jobs=4, device="gpu") + my_func(42, n_jobs=4, device="gpu") def test_gpu_only_params_error_on_cpu_if_provided(self): - """Test GPU-only params raise error on CPU if user provided a value.""" - registry = { - "my_func": { - "cpu_only": {}, - "gpu_only": {"use_sparse": GpuParamSpec(True)}, - } - } + """Test GPU-only params raise error on CPU if user explicitly provided them. - # Need to patch in both modules: dispatch for registry lookup, _gpu for check functions - with ( - patch("squidpy._settings._dispatch.GPU_PARAM_REGISTRY", registry), - patch("squidpy.gr._gpu.GPU_PARAM_REGISTRY", registry), - ): + GPU-only params are those in GPU signature but NOT in CPU signature. + If user tries to pass a GPU-only param on CPU, Python raises TypeError + (unexpected keyword argument) unless the CPU func accepts **kwargs. + """ + mock_module = MagicMock() - @gpu_dispatch(gpu_module="test_module") - def my_func(x, use_sparse=None, device=None): - return "cpu_result" + # GPU func has gpu_batch_size (GPU-only, not in CPU sig) + def gpu_my_func(x, gpu_batch_size=1000): + return "gpu_result" - # Not provided (None) - should work + mock_module.my_func = gpu_my_func + + # CPU func does NOT have gpu_batch_size + @gpu_dispatch(gpu_module="test_module") + def my_func(x, device=None): + return "cpu_result" + + with patch("importlib.import_module", return_value=mock_module): + # Not provided - should work assert my_func(42, device="cpu") == "cpu_result" - # Provided a value - should error - with pytest.raises(ValueError, match="use_sparse.*only supported on GPU"): - my_func(42, use_sparse=True, device="cpu") + # GPU-only param on CPU - Python raises TypeError (not in signature) + with pytest.raises(TypeError, match="unexpected keyword argument"): + my_func(42, gpu_batch_size=500, device="cpu") - def 
test_function_not_in_registry_works(self): - """Test that functions not in registry work transparently.""" + def test_function_with_no_exclusive_params(self): + """Test that functions with matching signatures work transparently.""" calls = [] + mock_module = MagicMock() - @gpu_dispatch() - def unregistered_func(x, device=None): + # GPU func has same signature + def gpu_func(x): + return "gpu_result" + + mock_module.my_func = gpu_func + + @gpu_dispatch(gpu_module="test_module") + def my_func(x, device=None): calls.append(x) return x * 3 # Should work on CPU without issues - assert unregistered_func(10, device="cpu") == 30 + assert my_func(10, device="cpu") == 30 assert calls == [10] def test_gpu_errors_when_unavailable(self): @@ -215,45 +214,43 @@ def my_func(x, device=None): with pytest.raises(RuntimeError, match="GPU unavailable"): my_func(5, device="gpu") - def test_custom_validator_error(self, mock_gpu_module): + def test_custom_validator_error(self): """Test custom validator raises appropriate error.""" - mock_module, mock_adapter = mock_gpu_module + mock_module = MagicMock() def my_validator(value): if value != "allowed": return f"value={value!r} is not allowed on GPU" return None + # GPU func without custom_param (CPU-only with validator) + def gpu_my_func(x): + return "gpu_result" + + mock_module.my_func = gpu_my_func + registry = { "my_func": { - "cpu_only": {"custom_param": GpuParamSpec("allowed", validator=my_validator)}, + "cpu_only": {"custom_param": ParamSpec(validate_fn=my_validator)}, "gpu_only": {}, } } - # Need to patch in both modules: dispatch for registry lookup, _gpu for check functions - with ( - patch("squidpy._settings._dispatch.GPU_PARAM_REGISTRY", registry), - patch("squidpy.gr._gpu.GPU_PARAM_REGISTRY", registry), - ): + with patch.dict(SPECIAL_PARAM_REGISTRY, registry): @gpu_dispatch(gpu_module="test_module") - def my_func(x, custom_param=None, device=None): + def my_func(x, custom_param="allowed", device=None): return "cpu_result" - # 
Allowed value - should work with ( patch("squidpy._settings._dispatch._resolve_device", return_value="gpu"), patch("importlib.import_module", return_value=mock_module), ): - my_func(42, custom_param="allowed", device="gpu") - - # Not allowed value - should error - with pytest.raises(ValueError, match="value='bad' is not allowed on GPU"): - with ( - patch("squidpy._settings._dispatch._resolve_device", return_value="gpu"), - patch("importlib.import_module", return_value=mock_module), - ): + # Allowed value - should work + assert my_func(42, custom_param="allowed", device="gpu") == "gpu_result" + + # Not allowed value - should error + with pytest.raises(ValueError, match="value='bad' is not allowed on GPU"): my_func(42, custom_param="bad", device="gpu") def test_docstring_uses_custom_gpu_module(self): From 4453d2b2f5bc3a21c5243272cecce74d19cc4855 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 2 Feb 2026 09:25:07 +0000 Subject: [PATCH 34/68] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/squidpy/gr/_gpu.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/squidpy/gr/_gpu.py b/src/squidpy/gr/_gpu.py index ccc101332..cb2a9cd5a 100644 --- a/src/squidpy/gr/_gpu.py +++ b/src/squidpy/gr/_gpu.py @@ -40,9 +40,7 @@ def _attr_validator(value: Any) -> str | None: } -def get_exclusive_params( - cpu_func: Callable[..., Any], gpu_func: Callable[..., Any] -) -> tuple[set[str], set[str]]: +def get_exclusive_params(cpu_func: Callable[..., Any], gpu_func: Callable[..., Any]) -> tuple[set[str], set[str]]: """Get CPU-only and GPU-only params by comparing function signatures. 
Parameters From 04829320f4b364c83171d6cb4c455d66add6fd4f Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Mon, 2 Feb 2026 10:31:59 +0100 Subject: [PATCH 35/68] update docs --- src/squidpy/gr/_ppatterns.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/squidpy/gr/_ppatterns.py b/src/squidpy/gr/_ppatterns.py index 451273be6..355c97ed9 100644 --- a/src/squidpy/gr/_ppatterns.py +++ b/src/squidpy/gr/_ppatterns.py @@ -56,7 +56,7 @@ def spatial_autocorr( n_perms: int | None = None, two_tailed: bool = False, corr_method: str | None = "fdr_bh", - attr: Literal["obs", "X", "obsm"] | None = None, + attr: Literal["obs", "X", "obsm"] = "X", layer: str | None = None, seed: int | None = None, use_raw: bool = False, @@ -108,7 +108,7 @@ def spatial_autocorr( Layer in :attr:`anndata.AnnData.layers` to use. If `None`, use :attr:`anndata.AnnData.X`. attr Which attribute of :class:`~anndata.AnnData` to access. See ``genes`` parameter for more information. - Ignored when ``device='gpu'``. + Can be only 'X' when effective device is 'gpu'. use_sparse If `True`, use sparse matrix representation for the input matrix. Only used when ``device='gpu'``. Defaults to `True` on GPU. @@ -137,6 +137,7 @@ def spatial_autocorr( - :attr:`anndata.AnnData.uns` ``['moranI']`` - the above mentioned dataframe, if ``mode = {sp.MORAN.s!r}``. - :attr:`anndata.AnnData.uns` ``['gearyC']`` - the above mentioned dataframe, if ``mode = {sp.GEARY.s!r}``. 
""" + del device, use_sparse # device and use_sparse are handled by the gpu_dispatch decorator if isinstance(adata, SpatialData): adata = adata.table _assert_connectivity_key(adata, connectivity_key) From f85479495e53691e68c80cdc8f37b43b93c16e1b Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Mon, 2 Feb 2026 10:51:07 +0100 Subject: [PATCH 36/68] fux docs --- src/squidpy/_docs.py | 32 ++++++++++++-------------------- src/squidpy/_utils.py | 10 ++++++++-- tests/graph/test_ppatterns.py | 8 ++++---- 3 files changed, 24 insertions(+), 26 deletions(-) diff --git a/src/squidpy/_docs.py b/src/squidpy/_docs.py index f54d67d4a..f263dbb84 100644 --- a/src/squidpy/_docs.py +++ b/src/squidpy/_docs.py @@ -105,7 +105,9 @@ def decorator2(obj: Any) -> Any: _plotting_returns = """\ Nothing, just plots the figure and optionally saves the plot. """ -_parallelize = """\ +_CPU_ONLY = " Only available when ``device='cpu'``." + +_n_jobs = """\ n_jobs Number of parallel jobs to use. For ``backend="loky"``, the number of cores used by numba for @@ -114,29 +116,19 @@ def decorator2(obj: Any) -> Any: numba in your function to parallelize. To set the absolute maximum number of threads in numba for your python program, set the environment variable: - ``NUMBA_NUM_THREADS`` before running the program. + ``NUMBA_NUM_THREADS`` before running the program.""" +_backend = """\ backend - Parallelization backend to use. See :class:`joblib.Parallel` for available options. + Parallelization backend to use. See :class:`joblib.Parallel` for available options.""" +_show_progress_bar = """\ show_progress_bar Whether to show the progress bar or not.""" -_parallelize_device = """\ -n_jobs - Number of parallel jobs to use. Ignored when ``device='gpu'``. - For ``backend="loky"``, the number of cores used by numba for - each job spawned by the backend will be set to 1 in order to - overcome the oversubscription issue in case you run - numba in your function to parallelize. 
- To set the absolute maximum number of threads in numba - for your python program, set the environment variable: - ``NUMBA_NUM_THREADS`` before running the program. -backend - Parallelization backend to use. See :class:`joblib.Parallel` for available options. - Ignored when ``device='gpu'``. -show_progress_bar - Whether to show the progress bar or not. Ignored when ``device='gpu'``.""" -_seed_device = """\ + +_parallelize = f"{_n_jobs}\n{_backend}\n{_show_progress_bar}" +_parallelize_device = f"{_n_jobs}{_CPU_ONLY}\n{_backend}{_CPU_ONLY}\n{_show_progress_bar}{_CPU_ONLY}" +_seed_device = f"""\ seed - Random seed for reproducibility. Ignored when ``device='gpu'``.""" + Random seed for reproducibility.{_CPU_ONLY}""" _device = """\ device Device for computation: ``'cpu'``, ``'gpu'``, or ``None`` (use ``squidpy.settings.device``). diff --git a/src/squidpy/_utils.py b/src/squidpy/_utils.py index 736c88172..ab34531c4 100644 --- a/src/squidpy/_utils.py +++ b/src/squidpy/_utils.py @@ -79,9 +79,9 @@ def parallelize( n_split: int | None = None, unit: str = "", use_ixs: bool = False, - backend: str = "loky", + backend: str | None = "loky", extractor: Callable[[Sequence[Any]], Any] | None = None, - show_progress_bar: bool = True, + show_progress_bar: bool | None = True, use_runner: bool = False, **_: Any, ) -> Any: @@ -119,6 +119,12 @@ def parallelize( ------- The result depending on ``callable``, ``extractor``. 
""" + # Apply defaults for None values (allows dispatch to pass through None) + if backend is None: + backend = "loky" + if show_progress_bar is None: + show_progress_bar = True + if show_progress_bar: try: import ipywidgets # noqa: F401 diff --git a/tests/graph/test_ppatterns.py b/tests/graph/test_ppatterns.py index 226fb2830..eb23e37f8 100644 --- a/tests/graph/test_ppatterns.py +++ b/tests/graph/test_ppatterns.py @@ -138,11 +138,11 @@ def test_co_occurrence(adata: AnnData): # @pytest.mark.parametrize(("ys", "xs"), [(10, 10), (None, None), (10, 20)]) -@pytest.mark.parametrize(("n_jobs", "n_splits"), [(1, 2), (2, 2)]) -def test_co_occurrence_reproducibility(adata: AnnData, n_jobs: int, n_splits: int): +@pytest.mark.parametrize("n_jobs", [1, 2]) +def test_co_occurrence_reproducibility(adata: AnnData, n_jobs: int): """Check co_occurrence reproducibility results.""" - arr_1, interval_1 = co_occurrence(adata, cluster_key="leiden", copy=True, n_jobs=n_jobs, n_splits=n_splits) - arr_2, interval_2 = co_occurrence(adata, cluster_key="leiden", copy=True, n_jobs=n_jobs, n_splits=n_splits) + arr_1, interval_1 = co_occurrence(adata, cluster_key="leiden", copy=True, n_jobs=n_jobs) + arr_2, interval_2 = co_occurrence(adata, cluster_key="leiden", copy=True, n_jobs=n_jobs) np.testing.assert_array_equal(sorted(interval_1), sorted(interval_2)) np.testing.assert_allclose(arr_1, arr_2) From 26e8ddc62865dc622f8c4c33b47053e42bd583e3 Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Mon, 2 Feb 2026 11:07:05 +0100 Subject: [PATCH 37/68] make gpu_availible an attr --- src/squidpy/_settings/_dispatch.py | 4 ++-- src/squidpy/_settings/_settings.py | 27 +++++++++++++++------------ tests/test_gpu.py | 2 +- tests/test_settings.py | 6 +++--- 4 files changed, 21 insertions(+), 18 deletions(-) diff --git a/src/squidpy/_settings/_dispatch.py b/src/squidpy/_settings/_dispatch.py index fd2bd3944..2411724fa 100644 --- a/src/squidpy/_settings/_dispatch.py +++ b/src/squidpy/_settings/_dispatch.py @@ 
-24,11 +24,11 @@ def _resolve_device(device: Literal["auto", "cpu", "gpu"] | None) -> Literal["cp if device == "cpu": return "cpu" if device == "gpu": - if not settings.gpu_available(): + if not settings.gpu_available: raise RuntimeError("GPU unavailable. Install with: pip install squidpy[gpu-cuda12]") return "gpu" # auto - return "gpu" if settings.gpu_available() else "cpu" + return "gpu" if settings.gpu_available else "cpu" def _make_gpu_note(func_name: str, gpu_module: str, indent: str = "") -> str: diff --git a/src/squidpy/_settings/_settings.py b/src/squidpy/_settings/_settings.py index fa584e6d3..98883f45a 100644 --- a/src/squidpy/_settings/_settings.py +++ b/src/squidpy/_settings/_settings.py @@ -11,9 +11,22 @@ _device_var: ContextVar[DeviceType] = ContextVar("device", default="auto") +def _check_gpu_available() -> bool: + """Check if GPU acceleration is available.""" + try: + import rapids_singlecell # noqa: F401 + + return True + except ImportError: + return False + + class SqSettings: """Global configuration for squidpy.""" + def __init__(self) -> None: + self.gpu_available: bool = _check_gpu_available() + @property def device(self) -> DeviceType: """Compute device: ``'auto'``, ``'cpu'``, or ``'gpu'``.""" @@ -23,19 +36,9 @@ def device(self) -> DeviceType: def device(self, value: DeviceType) -> None: if value not in get_args(DeviceType): raise ValueError(f"device must be one of {get_args(DeviceType)}, got {value!r}") - if value == "gpu" and not self.gpu_available(): - raise RuntimeError("GPU unavailable. Install: pip install squidpy[gpu-cuda12]") + if value == "gpu" and not self.gpu_available: + raise RuntimeError("GPU unavailable. 
Install: pip install squidpy[gpu-cuda12] or with [gpu-cuda11] for CUDA 11 support.") _device_var.set(value) - @staticmethod - def gpu_available() -> bool: - """Check if GPU acceleration is available.""" - try: - import rapids_singlecell # noqa: F401 - - return True - except ImportError: - return False - settings = SqSettings() diff --git a/tests/test_gpu.py b/tests/test_gpu.py index 6147e020b..a5e927b4c 100644 --- a/tests/test_gpu.py +++ b/tests/test_gpu.py @@ -10,7 +10,7 @@ # Skip all tests in this module if GPU is not available pytestmark = pytest.mark.skipif( - not settings.gpu_available(), + not settings.gpu_available, reason="GPU tests require rapids-singlecell to be installed", ) diff --git a/tests/test_settings.py b/tests/test_settings.py index 7281310de..c966d79e2 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -31,7 +31,7 @@ def test_set_device_invalid(self): def test_set_device_gpu_without_rsc(self): """Test that setting device to 'gpu' without rapids-singlecell raises RuntimeError.""" - if not settings.gpu_available(): + if not settings.gpu_available: with pytest.raises(RuntimeError, match="GPU unavailable"): settings.device = "gpu" @@ -53,7 +53,7 @@ def my_func(x, y, *, n_jobs=1, device=None): def test_auto_device_falls_back_to_cpu(self): """Test auto device falls back to CPU when GPU unavailable.""" - if settings.gpu_available(): + if settings.gpu_available: pytest.skip("GPU is available") calls = [] @@ -203,7 +203,7 @@ def my_func(x, device=None): def test_gpu_errors_when_unavailable(self): """Test GPU raises error when unavailable.""" - if settings.gpu_available(): + if settings.gpu_available: pytest.skip("GPU is available") @gpu_dispatch() From 5034343dfbf22e0151ffc3cf70090399fab6b369 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 2 Feb 2026 10:07:44 +0000 Subject: [PATCH 38/68] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see 
https://pre-commit.ci --- src/squidpy/_settings/_settings.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/squidpy/_settings/_settings.py b/src/squidpy/_settings/_settings.py index 98883f45a..038abe62b 100644 --- a/src/squidpy/_settings/_settings.py +++ b/src/squidpy/_settings/_settings.py @@ -37,7 +37,9 @@ def device(self, value: DeviceType) -> None: if value not in get_args(DeviceType): raise ValueError(f"device must be one of {get_args(DeviceType)}, got {value!r}") if value == "gpu" and not self.gpu_available: - raise RuntimeError("GPU unavailable. Install: pip install squidpy[gpu-cuda12] or with [gpu-cuda11] for CUDA 11 support.") + raise RuntimeError( + "GPU unavailable. Install: pip install squidpy[gpu-cuda12] or with [gpu-cuda11] for CUDA 11 support." + ) _device_var.set(value) From 8aae2eed7a945a72fa84478a959315c1bccc5cd7 Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Mon, 2 Feb 2026 13:31:56 +0100 Subject: [PATCH 39/68] revisions --- src/squidpy/_docs.py | 6 ++++-- src/squidpy/_settings/_settings.py | 5 +++-- src/squidpy/gr/_gpu.py | 2 +- src/squidpy/gr/_ppatterns.py | 1 + 4 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/squidpy/_docs.py b/src/squidpy/_docs.py index f263dbb84..32019b177 100644 --- a/src/squidpy/_docs.py +++ b/src/squidpy/_docs.py @@ -128,11 +128,13 @@ def decorator2(obj: Any) -> Any: _parallelize_device = f"{_n_jobs}{_CPU_ONLY}\n{_backend}{_CPU_ONLY}\n{_show_progress_bar}{_CPU_ONLY}" _seed_device = f"""\ seed - Random seed for reproducibility.{_CPU_ONLY}""" + Random seed for reproducibility.{_CPU_ONLY} +""" _device = """\ device Device for computation: ``'cpu'``, ``'gpu'``, or ``None`` (use ``squidpy.settings.device``). - When ``'gpu'``, dispatches to :doc:`rapids_singlecell ` for GPU-accelerated computation.""" + When ``'gpu'``, dispatches to :doc:`rapids_singlecell ` for GPU-accelerated computation. +""" _channels = """\ channels Channels for this feature is computed. 
If `None`, use all channels.""" diff --git a/src/squidpy/_settings/_settings.py b/src/squidpy/_settings/_settings.py index 98883f45a..1f2430438 100644 --- a/src/squidpy/_settings/_settings.py +++ b/src/squidpy/_settings/_settings.py @@ -5,9 +5,10 @@ from contextvars import ContextVar from typing import Literal, get_args -__all__ = ["settings", "DeviceType"] +__all__ = ["settings", "DeviceType", "GPU_UNAVAILABLE_MSG"] DeviceType = Literal["auto", "cpu", "gpu"] +GPU_UNAVAILABLE_MSG = "GPU unavailable. Install: pip install squidpy[gpu-cuda12] or with [gpu-cuda11] for CUDA 11 support." _device_var: ContextVar[DeviceType] = ContextVar("device", default="auto") @@ -37,7 +38,7 @@ def device(self, value: DeviceType) -> None: if value not in get_args(DeviceType): raise ValueError(f"device must be one of {get_args(DeviceType)}, got {value!r}") if value == "gpu" and not self.gpu_available: - raise RuntimeError("GPU unavailable. Install: pip install squidpy[gpu-cuda12] or with [gpu-cuda11] for CUDA 11 support.") + raise RuntimeError(GPU_UNAVAILABLE_MSG) _device_var.set(value) diff --git a/src/squidpy/gr/_gpu.py b/src/squidpy/gr/_gpu.py index cb2a9cd5a..997e6e193 100644 --- a/src/squidpy/gr/_gpu.py +++ b/src/squidpy/gr/_gpu.py @@ -1,4 +1,4 @@ -"""GPU parameter handling for squidpy.gr functions. +"""GPU parameter handling for squidpy functions with GPU acceleration. Automatically determines CPU-only and GPU-only parameters by introspecting function signatures. Only special cases (custom validators) need explicit registry entries. diff --git a/src/squidpy/gr/_ppatterns.py b/src/squidpy/gr/_ppatterns.py index 355c97ed9..5a7dfcd53 100644 --- a/src/squidpy/gr/_ppatterns.py +++ b/src/squidpy/gr/_ppatterns.py @@ -390,6 +390,7 @@ def co_occurrence( - :attr:`anndata.AnnData.uns` ``['{cluster_key}_co_occurrence']['interval']`` - the distance thresholds computed at ``interval``. 
""" + del device, n_jobs, backend, show_progress_bar # handled by gpu_dispatch decorator or unused on CPU if isinstance(adata, SpatialData): adata = adata.table _assert_categorical_obs(adata, key=cluster_key) From 25ebbb03e7e3b80f8170e66e0f8f307b3b8fc774 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 2 Feb 2026 12:33:12 +0000 Subject: [PATCH 40/68] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/squidpy/_settings/_settings.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/squidpy/_settings/_settings.py b/src/squidpy/_settings/_settings.py index 1f2430438..ad8f71494 100644 --- a/src/squidpy/_settings/_settings.py +++ b/src/squidpy/_settings/_settings.py @@ -8,7 +8,9 @@ __all__ = ["settings", "DeviceType", "GPU_UNAVAILABLE_MSG"] DeviceType = Literal["auto", "cpu", "gpu"] -GPU_UNAVAILABLE_MSG = "GPU unavailable. Install: pip install squidpy[gpu-cuda12] or with [gpu-cuda11] for CUDA 11 support." +GPU_UNAVAILABLE_MSG = ( + "GPU unavailable. Install: pip install squidpy[gpu-cuda12] or with [gpu-cuda11] for CUDA 11 support." 
+) _device_var: ContextVar[DeviceType] = ContextVar("device", default="auto") From 52e0b4e7e1d40a3b1eed8f243727fbff459d4e45 Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Mon, 2 Feb 2026 14:09:06 +0100 Subject: [PATCH 41/68] cleanup --- src/squidpy/_settings/_dispatch.py | 64 ++++++++++++------------------ 1 file changed, 26 insertions(+), 38 deletions(-) diff --git a/src/squidpy/_settings/_dispatch.py b/src/squidpy/_settings/_dispatch.py index 2411724fa..f631cfbd1 100644 --- a/src/squidpy/_settings/_dispatch.py +++ b/src/squidpy/_settings/_dispatch.py @@ -9,8 +9,8 @@ from collections.abc import Callable from typing import Any, Literal, TypeVar -from squidpy._settings._settings import settings -from squidpy.gr._gpu import check_exclusive_params, get_exclusive_params +from squidpy._settings._settings import GPU_UNAVAILABLE_MSG, settings +from squidpy.gr._gpu import GPU_PARAM_REGISTRY, check_exclusive_params, get_or_create_registry_entry __all__ = ["gpu_dispatch"] @@ -25,7 +25,7 @@ def _resolve_device(device: Literal["auto", "cpu", "gpu"] | None) -> Literal["cp return "cpu" if device == "gpu": if not settings.gpu_available: - raise RuntimeError("GPU unavailable. Install with: pip install squidpy[gpu-cuda12]") + raise RuntimeError(GPU_UNAVAILABLE_MSG) return "gpu" # auto return "gpu" if settings.gpu_available else "cpu" @@ -65,9 +65,7 @@ def gpu_dispatch( Automatically determines CPU-only and GPU-only parameters by comparing function signatures. Errors if user explicitly provides a value for - an exclusive parameter on the wrong device. GPU-only - parameters are also present in the CPU signature but only to - provide a way for the user to pass the parameter to the GPU function. + an exclusive parameter on the wrong device. Also injects a GPU note into the function's docstring. 
@@ -86,18 +84,6 @@ def decorator(func: F) -> F: # Inject GPU note into docstring func.__doc__ = _inject_gpu_note(func.__doc__, _gpu_func_name, gpu_module) - # Cache for exclusive params (computed lazily on first GPU call) - _exclusive_params_cache: dict[str, Any] | None = None - - def _get_exclusive_params() -> tuple[dict[str, Any], dict[str, Any]]: - nonlocal _exclusive_params_cache - if _exclusive_params_cache is None: - module = importlib.import_module(gpu_module) - gpu_func = getattr(module, _gpu_func_name) - cpu_only, gpu_only = get_exclusive_params(func, gpu_func) - _exclusive_params_cache = {"cpu_only": cpu_only, "gpu_only": gpu_only} - return _exclusive_params_cache["cpu_only"], _exclusive_params_cache["gpu_only"] - @functools.wraps(func) def wrapper(*args: Any, **kwargs: Any) -> Any: sig = inspect.signature(func) @@ -115,37 +101,39 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: device = all_args.pop("device", None) resolved_device = _resolve_device(device) - if resolved_device == "gpu": - cpu_only_params, gpu_only_params = _get_exclusive_params() + # Get or create registry entry (populated once per function) + key = (gpu_module, _gpu_func_name) + if key not in GPU_PARAM_REGISTRY: + try: + module = importlib.import_module(gpu_module) + gpu_func = getattr(module, _gpu_func_name) + get_or_create_registry_entry(gpu_module, _gpu_func_name, func, gpu_func) + except (ImportError, AttributeError): + # GPU module not available, just run CPU function + return func(**all_args) + + entry = GPU_PARAM_REGISTRY[key] + if resolved_device == "gpu": # Check if user explicitly provided any CPU-only params - user_provided_cpu_only = user_provided & cpu_only_params - if user_provided_cpu_only: - check_exclusive_params(func_name, user_provided_cpu_only, all_args, "gpu") + cpu_only_names = set(entry.cpu_only_params.keys()) + user_provided_cpu_only = user_provided & cpu_only_names + check_exclusive_params(func_name, user_provided_cpu_only, all_args, "gpu", entry) # Remove 
CPU-only params before calling GPU func - for k in cpu_only_params: + for k in cpu_only_names: all_args.pop(k, None) - module = importlib.import_module(gpu_module) - gpu_func = getattr(module, _gpu_func_name) - return gpu_func(**all_args) + return entry.gpu_func(**all_args) # CPU path - try: - cpu_only_params, gpu_only_params = _get_exclusive_params() - except (ImportError, AttributeError): - # GPU module not available, just run CPU function - return func(**all_args) - # Check if user explicitly provided any GPU-only params - user_provided_gpu_only = user_provided & gpu_only_params - if user_provided_gpu_only: - gpu_only_values = {k: all_args[k] for k in user_provided_gpu_only} - check_exclusive_params(func_name, user_provided_gpu_only, gpu_only_values, "cpu") + gpu_only_names = set(entry.gpu_only_params.keys()) + user_provided_gpu_only = user_provided & gpu_only_names + check_exclusive_params(func_name, user_provided_gpu_only, all_args, "cpu", entry) # Remove GPU-only params before calling CPU func - for k in gpu_only_params: + for k in gpu_only_names: all_args.pop(k, None) return func(**all_args) From 60f1eba3f6f359a46258fcca06223512a2a4d757 Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Mon, 2 Feb 2026 14:22:03 +0100 Subject: [PATCH 42/68] revert to old file --- src/squidpy/_settings/_dispatch.py | 8 ++++---- src/squidpy/gr/_gpu.py | 10 +++++----- tests/test_settings.py | 4 ++-- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/squidpy/_settings/_dispatch.py b/src/squidpy/_settings/_dispatch.py index f631cfbd1..5668599c0 100644 --- a/src/squidpy/_settings/_dispatch.py +++ b/src/squidpy/_settings/_dispatch.py @@ -10,7 +10,7 @@ from typing import Any, Literal, TypeVar from squidpy._settings._settings import GPU_UNAVAILABLE_MSG, settings -from squidpy.gr._gpu import GPU_PARAM_REGISTRY, check_exclusive_params, get_or_create_registry_entry +from squidpy.gr._gpu import SPECIAL_PARAM_REGISTRY, check_exclusive_params, get_exclusive_params 
__all__ = ["gpu_dispatch"] @@ -103,16 +103,16 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: # Get or create registry entry (populated once per function) key = (gpu_module, _gpu_func_name) - if key not in GPU_PARAM_REGISTRY: + if key not in SPECIAL_PARAM_REGISTRY: try: module = importlib.import_module(gpu_module) gpu_func = getattr(module, _gpu_func_name) - get_or_create_registry_entry(gpu_module, _gpu_func_name, func, gpu_func) + cpu_only, gpu_only = get_exclusive_params(func, gpu_func) except (ImportError, AttributeError): # GPU module not available, just run CPU function return func(**all_args) - entry = GPU_PARAM_REGISTRY[key] + entry = SPECIAL_PARAM_REGISTRY[key] if resolved_device == "gpu": # Check if user explicitly provided any CPU-only params diff --git a/src/squidpy/gr/_gpu.py b/src/squidpy/gr/_gpu.py index 997e6e193..d1288a117 100644 --- a/src/squidpy/gr/_gpu.py +++ b/src/squidpy/gr/_gpu.py @@ -11,11 +11,11 @@ from dataclasses import dataclass from typing import Any -__all__ = ["SPECIAL_PARAM_REGISTRY", "ParamSpec", "check_exclusive_params", "get_exclusive_params"] +__all__ = ["SPECIAL_PARAM_REGISTRY", "check_exclusive_params", "get_exclusive_params"] @dataclass -class ParamSpec: +class GpuParamSpec: """Specification for a parameter with custom validation.""" validate_fn: Callable[[Any], str | None] @@ -29,11 +29,11 @@ def _attr_validator(value: Any) -> str | None: # Minimal registry: only for params that need custom validators -# Format: {func_name: {"cpu_only": {param: ParamSpec}, "gpu_only": {param: ParamSpec}}} -SPECIAL_PARAM_REGISTRY: dict[str, dict[str, dict[str, ParamSpec]]] = { +# Format: {func_name: {"cpu_only": {param: GpuParamSpec}, "gpu_only": {param: GpuParamSpec}}} +SPECIAL_PARAM_REGISTRY: dict[str, dict[str, dict[str, GpuParamSpec]]] = { "spatial_autocorr": { "cpu_only": { - "attr": ParamSpec(validate_fn=_attr_validator), + "attr": GpuParamSpec(validate_fn=_attr_validator), }, "gpu_only": {}, }, diff --git a/tests/test_settings.py 
b/tests/test_settings.py index c966d79e2..64a1d5d7e 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -7,7 +7,7 @@ import pytest from squidpy._settings import gpu_dispatch, settings -from squidpy.gr._gpu import SPECIAL_PARAM_REGISTRY, ParamSpec +from squidpy.gr._gpu import SPECIAL_PARAM_REGISTRY, GpuParamSpec class TestSettings: @@ -231,7 +231,7 @@ def gpu_my_func(x): registry = { "my_func": { - "cpu_only": {"custom_param": ParamSpec(validate_fn=my_validator)}, + "cpu_only": {"custom_param": GpuParamSpec(validate_fn=my_validator)}, "gpu_only": {}, } } From 99d55e90a04a530374e6292733edaae8593955ff Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Tue, 3 Feb 2026 13:45:23 +0100 Subject: [PATCH 43/68] remove device and rely on context manager --- src/squidpy/_settings/_dispatch.py | 12 +++--- src/squidpy/gr/_ligrec.py | 2 - src/squidpy/gr/_ppatterns.py | 6 +-- tests/test_settings.py | 65 +++++++++++++++++------------- 4 files changed, 42 insertions(+), 43 deletions(-) diff --git a/src/squidpy/_settings/_dispatch.py b/src/squidpy/_settings/_dispatch.py index 5668599c0..e485eae72 100644 --- a/src/squidpy/_settings/_dispatch.py +++ b/src/squidpy/_settings/_dispatch.py @@ -7,7 +7,7 @@ import inspect import re from collections.abc import Callable -from typing import Any, Literal, TypeVar +from typing import Any, TypeVar from squidpy._settings._settings import GPU_UNAVAILABLE_MSG, settings from squidpy.gr._gpu import SPECIAL_PARAM_REGISTRY, check_exclusive_params, get_exclusive_params @@ -17,10 +17,9 @@ F = TypeVar("F", bound=Callable[..., Any]) -def _resolve_device(device: Literal["auto", "cpu", "gpu"] | None) -> Literal["cpu", "gpu"]: - """Resolve device arg to 'cpu' or 'gpu'.""" - if device is None: - device = settings.device +def _get_effective_device() -> str: + """Get effective device from settings, resolving 'auto'.""" + device = settings.device if device == "cpu": return "cpu" if device == "gpu": @@ -98,8 +97,7 @@ def wrapper(*args: Any, 
**kwargs: Any) -> Any: bound.apply_defaults() all_args = dict(bound.arguments) - device = all_args.pop("device", None) - resolved_device = _resolve_device(device) + resolved_device = _get_effective_device() # Get or create registry entry (populated once per function) key = (gpu_module, _gpu_func_name) diff --git a/src/squidpy/gr/_ligrec.py b/src/squidpy/gr/_ligrec.py index c7a07b152..611a85047 100644 --- a/src/squidpy/gr/_ligrec.py +++ b/src/squidpy/gr/_ligrec.py @@ -649,7 +649,6 @@ def ligrec( copy: bool = False, key_added: str | None = None, gene_symbols: str | None = None, - device: Literal["cpu", "gpu"] | None = None, # prepare params interactions_params: Mapping[str, Any] = MappingProxyType({}), transmitter_params: Mapping[str, Any] = MappingProxyType({"categories": "ligand"}), @@ -675,7 +674,6 @@ def ligrec( %(PT_test.parameters)s gene_symbols Key in :attr:`anndata.AnnData.var` to use instead of :attr:`anndata.AnnData.var_names`. - %(device)s Returns ------- diff --git a/src/squidpy/gr/_ppatterns.py b/src/squidpy/gr/_ppatterns.py index 5a7dfcd53..dbaade017 100644 --- a/src/squidpy/gr/_ppatterns.py +++ b/src/squidpy/gr/_ppatterns.py @@ -65,7 +65,6 @@ def spatial_autocorr( n_jobs: int | None = None, backend: str | None = None, show_progress_bar: bool | None = None, - device: Literal["cpu", "gpu"] | None = None, ) -> pd.DataFrame | None: """ Calculate Global Autocorrelation Statistic (Moran’s I or Geary's C). @@ -115,7 +114,6 @@ def spatial_autocorr( %(seed_device)s %(copy)s %(parallelize_device)s - %(device)s Returns ------- @@ -362,7 +360,6 @@ def co_occurrence( n_jobs: int | None = None, backend: str | None = None, show_progress_bar: bool | None = None, - device: Literal["cpu", "gpu"] | None = None, ) -> tuple[NDArrayA, NDArrayA] | None: """ Compute co-occurrence probability of clusters. @@ -377,7 +374,6 @@ def co_occurrence( of the given size will be used. 
%(copy)s %(parallelize_device)s - %(device)s Returns ------- @@ -390,7 +386,7 @@ def co_occurrence( - :attr:`anndata.AnnData.uns` ``['{cluster_key}_co_occurrence']['interval']`` - the distance thresholds computed at ``interval``. """ - del device, n_jobs, backend, show_progress_bar # handled by gpu_dispatch decorator or unused on CPU + del n_jobs, backend, show_progress_bar # handled by gpu_dispatch decorator or unused on CPU if isinstance(adata, SpatialData): adata = adata.table _assert_categorical_obs(adata, key=cluster_key) diff --git a/tests/test_settings.py b/tests/test_settings.py index 64a1d5d7e..54687f0da 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -44,12 +44,14 @@ def test_cpu_path(self): calls = [] @gpu_dispatch() - def my_func(x, y, *, n_jobs=1, device=None): + def my_func(x, y, *, n_jobs=1): calls.append((x, y, n_jobs)) return x + y - assert my_func(1, 2, device="cpu") == 3 + settings.device = "cpu" + assert my_func(1, 2) == 3 assert calls == [(1, 2, 1)] + settings.device = "auto" # reset def test_auto_device_falls_back_to_cpu(self): """Test auto device falls back to CPU when GPU unavailable.""" @@ -59,11 +61,12 @@ def test_auto_device_falls_back_to_cpu(self): calls = [] @gpu_dispatch() - def my_func(x, device=None): + def my_func(x): calls.append(x) return x * 2 - assert my_func(5, device="auto") == 10 + settings.device = "auto" + assert my_func(5) == 10 assert calls == [5] def test_gpu_path(self): @@ -77,14 +80,14 @@ def gpu_my_func(x): mock_module.my_func = gpu_my_func @gpu_dispatch(gpu_module="test_module") - def my_func(x, device=None): + def my_func(x): return "cpu_result" with ( - patch("squidpy._settings._dispatch._resolve_device", return_value="gpu"), + patch("squidpy._settings._dispatch._get_effective_device", return_value="gpu"), patch("importlib.import_module", return_value=mock_module), ): - assert my_func(42, device="gpu") == "gpu_result" + assert my_func(42) == "gpu_result" def test_custom_gpu_func_name(self): 
"""Test custom GPU function name.""" @@ -97,20 +100,20 @@ def custom_name(x): mock_module.custom_name = custom_name @gpu_dispatch(gpu_module="test_module", gpu_func_name="custom_name") - def my_func(x, device=None): + def my_func(x): return "cpu_result" with ( - patch("squidpy._settings._dispatch._resolve_device", return_value="gpu"), + patch("squidpy._settings._dispatch._get_effective_device", return_value="gpu"), patch("importlib.import_module", return_value=mock_module), ): - assert my_func(42, device="gpu") == "gpu_result" + assert my_func(42) == "gpu_result" def test_preserves_function_metadata(self): """Test decorator preserves function name and injects GPU note.""" @gpu_dispatch() - def documented_func(x, device=None): + def documented_func(x): """Original docstring. Parameters @@ -135,23 +138,23 @@ def gpu_my_func(x): mock_module.my_func = gpu_my_func @gpu_dispatch(gpu_module="test_module") - def my_func(x, n_jobs=1, device=None): + def my_func(x, n_jobs=1): return "cpu_result" with ( - patch("squidpy._settings._dispatch._resolve_device", return_value="gpu"), + patch("squidpy._settings._dispatch._get_effective_device", return_value="gpu"), patch("importlib.import_module", return_value=mock_module), ): # Not provided - should work - assert my_func(42, device="gpu") == "gpu_result" + assert my_func(42) == "gpu_result" # Explicitly provided (even if same as default) - should error with pytest.raises(ValueError, match="n_jobs.*only supported on CPU"): - my_func(42, n_jobs=1, device="gpu") + my_func(42, n_jobs=1) # Explicitly provided with different value - should also error with pytest.raises(ValueError, match="n_jobs.*only supported on CPU"): - my_func(42, n_jobs=4, device="gpu") + my_func(42, n_jobs=4) def test_gpu_only_params_error_on_cpu_if_provided(self): """Test GPU-only params raise error on CPU if user explicitly provided them. 
@@ -170,16 +173,18 @@ def gpu_my_func(x, gpu_batch_size=1000): # CPU func does NOT have gpu_batch_size @gpu_dispatch(gpu_module="test_module") - def my_func(x, device=None): + def my_func(x): return "cpu_result" + settings.device = "cpu" with patch("importlib.import_module", return_value=mock_module): # Not provided - should work - assert my_func(42, device="cpu") == "cpu_result" + assert my_func(42) == "cpu_result" # GPU-only param on CPU - Python raises TypeError (not in signature) with pytest.raises(TypeError, match="unexpected keyword argument"): - my_func(42, gpu_batch_size=500, device="cpu") + my_func(42, gpu_batch_size=500) + settings.device = "auto" # reset def test_function_with_no_exclusive_params(self): """Test that functions with matching signatures work transparently.""" @@ -193,13 +198,15 @@ def gpu_func(x): mock_module.my_func = gpu_func @gpu_dispatch(gpu_module="test_module") - def my_func(x, device=None): + def my_func(x): calls.append(x) return x * 3 + settings.device = "cpu" # Should work on CPU without issues - assert my_func(10, device="cpu") == 30 + assert my_func(10) == 30 assert calls == [10] + settings.device = "auto" # reset def test_gpu_errors_when_unavailable(self): """Test GPU raises error when unavailable.""" @@ -207,12 +214,12 @@ def test_gpu_errors_when_unavailable(self): pytest.skip("GPU is available") @gpu_dispatch() - def my_func(x, device=None): + def my_func(x): return x + 1 # Should raise error when GPU requested but unavailable with pytest.raises(RuntimeError, match="GPU unavailable"): - my_func(5, device="gpu") + settings.device = "gpu" def test_custom_validator_error(self): """Test custom validator raises appropriate error.""" @@ -239,25 +246,25 @@ def gpu_my_func(x): with patch.dict(SPECIAL_PARAM_REGISTRY, registry): @gpu_dispatch(gpu_module="test_module") - def my_func(x, custom_param="allowed", device=None): + def my_func(x, custom_param="allowed"): return "cpu_result" with ( - 
patch("squidpy._settings._dispatch._resolve_device", return_value="gpu"), + patch("squidpy._settings._dispatch._get_effective_device", return_value="gpu"), patch("importlib.import_module", return_value=mock_module), ): # Allowed value - should work - assert my_func(42, custom_param="allowed", device="gpu") == "gpu_result" + assert my_func(42, custom_param="allowed") == "gpu_result" # Not allowed value - should error with pytest.raises(ValueError, match="value='bad' is not allowed on GPU"): - my_func(42, custom_param="bad", device="gpu") + my_func(42, custom_param="bad") def test_docstring_uses_custom_gpu_module(self): """Test that docstring GPU note uses the specified gpu_module.""" @gpu_dispatch(gpu_module="custom.module.path") - def my_func(x, device=None): + def my_func(x): """My function. Parameters @@ -273,7 +280,7 @@ def test_docstring_uses_custom_gpu_func_name(self): """Test that docstring GPU note uses the specified gpu_func_name.""" @gpu_dispatch(gpu_module="some.module", gpu_func_name="different_name") - def my_func(x, device=None): + def my_func(x): """My function. 
Parameters From 22edd86a886a67ca50cf69ff1ce9cb1c898d9184 Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Tue, 3 Feb 2026 13:56:04 +0100 Subject: [PATCH 44/68] add device_kwargs --- src/squidpy/_settings/_dispatch.py | 72 ++++++------------ src/squidpy/gr/_gpu.py | 113 ----------------------------- src/squidpy/gr/_ligrec.py | 6 +- src/squidpy/gr/_ppatterns.py | 14 ++-- tests/test_settings.py | 6 +- 5 files changed, 42 insertions(+), 169 deletions(-) delete mode 100644 src/squidpy/gr/_gpu.py diff --git a/src/squidpy/_settings/_dispatch.py b/src/squidpy/_settings/_dispatch.py index e485eae72..81947a079 100644 --- a/src/squidpy/_settings/_dispatch.py +++ b/src/squidpy/_settings/_dispatch.py @@ -4,13 +4,11 @@ import functools import importlib -import inspect import re from collections.abc import Callable from typing import Any, TypeVar from squidpy._settings._settings import GPU_UNAVAILABLE_MSG, settings -from squidpy.gr._gpu import SPECIAL_PARAM_REGISTRY, check_exclusive_params, get_exclusive_params __all__ = ["gpu_dispatch"] @@ -56,17 +54,18 @@ def _inject_gpu_note(doc: str | None, func_name: str, gpu_module: str) -> str | return doc + "\n\n" + _make_gpu_note(func_name, gpu_module) +# Cache for GPU functions +_GPU_FUNC_CACHE: dict[tuple[str, str], Callable[..., Any]] = {} + + def gpu_dispatch( gpu_module: str = "rapids_singlecell.gr", gpu_func_name: str | None = None, ) -> Callable[[F], F]: - """Decorator to dispatch to GPU implementation when device='gpu'. + """Decorator to dispatch to GPU implementation based on settings.device. - Automatically determines CPU-only and GPU-only parameters by comparing - function signatures. Errors if user explicitly provides a value for - an exclusive parameter on the wrong device. - - Also injects a GPU note into the function's docstring. + When device is 'gpu', calls the GPU implementation from the specified module. + The `device_kwargs` parameter (if present) is passed to the GPU function. 
Parameters ---------- @@ -85,56 +84,31 @@ def decorator(func: F) -> F: @functools.wraps(func) def wrapper(*args: Any, **kwargs: Any) -> Any: - sig = inspect.signature(func) - try: - bound = sig.bind(*args, **kwargs) - except TypeError: - return func(*args, **kwargs) - - # Track what user actually provided (before defaults) - user_provided = set(bound.arguments.keys()) - - bound.apply_defaults() - all_args = dict(bound.arguments) - resolved_device = _get_effective_device() - # Get or create registry entry (populated once per function) + if resolved_device == "cpu": + # CPU path - just remove device_kwargs if present and call CPU func + kwargs.pop("device_kwargs", None) + return func(*args, **kwargs) + + # GPU path key = (gpu_module, _gpu_func_name) - if key not in SPECIAL_PARAM_REGISTRY: + if key not in _GPU_FUNC_CACHE: try: module = importlib.import_module(gpu_module) - gpu_func = getattr(module, _gpu_func_name) - cpu_only, gpu_only = get_exclusive_params(func, gpu_func) + _GPU_FUNC_CACHE[key] = getattr(module, _gpu_func_name) except (ImportError, AttributeError): - # GPU module not available, just run CPU function - return func(**all_args) - - entry = SPECIAL_PARAM_REGISTRY[key] - - if resolved_device == "gpu": - # Check if user explicitly provided any CPU-only params - cpu_only_names = set(entry.cpu_only_params.keys()) - user_provided_cpu_only = user_provided & cpu_only_names - check_exclusive_params(func_name, user_provided_cpu_only, all_args, "gpu", entry) - - # Remove CPU-only params before calling GPU func - for k in cpu_only_names: - all_args.pop(k, None) - - return entry.gpu_func(**all_args) + # GPU module not available, fall back to CPU + kwargs.pop("device_kwargs", None) + return func(*args, **kwargs) - # CPU path - # Check if user explicitly provided any GPU-only params - gpu_only_names = set(entry.gpu_only_params.keys()) - user_provided_gpu_only = user_provided & gpu_only_names - check_exclusive_params(func_name, user_provided_gpu_only, all_args, "cpu", 
entry) + gpu_func = _GPU_FUNC_CACHE[key] - # Remove GPU-only params before calling CPU func - for k in gpu_only_names: - all_args.pop(k, None) + # Extract device_kwargs and merge into kwargs for GPU call + device_kwargs = kwargs.pop("device_kwargs", None) or {} + kwargs.update(device_kwargs) - return func(**all_args) + return gpu_func(*args, **kwargs) return wrapper # type: ignore[return-value] diff --git a/src/squidpy/gr/_gpu.py b/src/squidpy/gr/_gpu.py deleted file mode 100644 index d1288a117..000000000 --- a/src/squidpy/gr/_gpu.py +++ /dev/null @@ -1,113 +0,0 @@ -"""GPU parameter handling for squidpy functions with GPU acceleration. - -Automatically determines CPU-only and GPU-only parameters by introspecting function signatures. -Only special cases (custom validators) need explicit registry entries. -""" - -from __future__ import annotations - -import inspect -from collections.abc import Callable -from dataclasses import dataclass -from typing import Any - -__all__ = ["SPECIAL_PARAM_REGISTRY", "check_exclusive_params", "get_exclusive_params"] - - -@dataclass -class GpuParamSpec: - """Specification for a parameter with custom validation.""" - - validate_fn: Callable[[Any], str | None] - - -def _attr_validator(value: Any) -> str | None: - """Validator for attr param - error if not 'X' on GPU.""" - if value == "X": - return None - return f"attr={value!r} is not supported on GPU. Set device='cpu' to use other attributes." 
- - -# Minimal registry: only for params that need custom validators -# Format: {func_name: {"cpu_only": {param: GpuParamSpec}, "gpu_only": {param: GpuParamSpec}}} -SPECIAL_PARAM_REGISTRY: dict[str, dict[str, dict[str, GpuParamSpec]]] = { - "spatial_autocorr": { - "cpu_only": { - "attr": GpuParamSpec(validate_fn=_attr_validator), - }, - "gpu_only": {}, - }, -} - - -def get_exclusive_params(cpu_func: Callable[..., Any], gpu_func: Callable[..., Any]) -> tuple[set[str], set[str]]: - """Get CPU-only and GPU-only params by comparing function signatures. - - Parameters - ---------- - cpu_func - The CPU implementation function. - gpu_func - The GPU implementation function. - - Returns - ------- - Tuple of (cpu_only_params, gpu_only_params) as sets of param names. - """ - cpu_sig = inspect.signature(cpu_func) - gpu_sig = inspect.signature(gpu_func) - - cpu_params = set(cpu_sig.parameters.keys()) - gpu_params = set(gpu_sig.parameters.keys()) - - # CPU-only: in CPU sig but not in GPU sig (excluding 'device' which is handled separately) - cpu_only = cpu_params - gpu_params - {"device"} - - # GPU-only: in GPU sig but not in CPU sig - gpu_only = gpu_params - cpu_params - - return cpu_only, gpu_only - - -def check_exclusive_params( - func_name: str, - user_provided_exclusive: set[str], - param_values: dict[str, Any], - target_device: str, -) -> None: - """Check exclusive params, raise error if user explicitly provided any. - - Parameters - ---------- - func_name - Name of the function (for registry lookup). - user_provided_exclusive - Set of param names that user explicitly provided AND are exclusive to other device. - param_values - All argument values (for error messages and custom validators). - target_device - The device being used ('cpu' or 'gpu'). - - Raises - ------ - ValueError - If user explicitly provided an exclusive parameter. 
- """ - other_device = "gpu" if target_device == "cpu" else "cpu" - registry_key = "gpu_only" if target_device == "cpu" else "cpu_only" - registry = SPECIAL_PARAM_REGISTRY.get(func_name, {"cpu_only": {}, "gpu_only": {}}) - - for name in user_provided_exclusive: - value = param_values.get(name) - - # Check special validate_fn first (they may allow certain values) - if name in registry[registry_key]: - spec = registry[registry_key][name] - msg = spec.validate_fn(value) - if msg: - raise ValueError(msg) - continue - - # User explicitly provided an exclusive param - error - msg = f"{name}={value!r} is only supported on {other_device.upper()}. Use device={other_device!r} or remove this argument." - raise ValueError(msg) diff --git a/src/squidpy/gr/_ligrec.py b/src/squidpy/gr/_ligrec.py index 611a85047..88c10fe49 100644 --- a/src/squidpy/gr/_ligrec.py +++ b/src/squidpy/gr/_ligrec.py @@ -659,10 +659,10 @@ def ligrec( seed: int | None = None, alpha: float = 0.05, numba_parallel: bool | None = None, - # CPU-only params (must be None to allow dispatch to detect if user provided) n_jobs: int | None = None, backend: str | None = None, show_progress_bar: bool | None = None, + device_kwargs: dict[str, Any] | None = None, ) -> Mapping[str, pd.DataFrame] | None: """ %(PT_test.full_desc)s @@ -674,11 +674,15 @@ def ligrec( %(PT_test.parameters)s gene_symbols Key in :attr:`anndata.AnnData.var` to use instead of :attr:`anndata.AnnData.var_names`. + device_kwargs + Additional keyword arguments passed to the GPU implementation when ``squidpy.settings.device`` + is set to ``'gpu'``. Ignored on CPU. 
Returns ------- %(ligrec_test_returns)s """ # noqa: D400 + del device_kwargs # handled by gpu_dispatch decorator if isinstance(adata, SpatialData): adata = adata.table diff --git a/src/squidpy/gr/_ppatterns.py b/src/squidpy/gr/_ppatterns.py index dbaade017..9e6cd54a7 100644 --- a/src/squidpy/gr/_ppatterns.py +++ b/src/squidpy/gr/_ppatterns.py @@ -60,11 +60,11 @@ def spatial_autocorr( layer: str | None = None, seed: int | None = None, use_raw: bool = False, - use_sparse: bool | None = None, copy: bool = False, n_jobs: int | None = None, backend: str | None = None, show_progress_bar: bool | None = None, + device_kwargs: dict[str, Any] | None = None, ) -> pd.DataFrame | None: """ Calculate Global Autocorrelation Statistic (Moran’s I or Geary's C). @@ -108,12 +108,12 @@ def spatial_autocorr( attr Which attribute of :class:`~anndata.AnnData` to access. See ``genes`` parameter for more information. Can be only 'X' when effective device is 'gpu'. - use_sparse - If `True`, use sparse matrix representation for the input matrix. - Only used when ``device='gpu'``. Defaults to `True` on GPU. %(seed_device)s %(copy)s %(parallelize_device)s + device_kwargs + Additional keyword arguments passed to the GPU implementation when ``squidpy.settings.device`` + is set to ``'gpu'``. Ignored on CPU. Returns ------- @@ -360,6 +360,7 @@ def co_occurrence( n_jobs: int | None = None, backend: str | None = None, show_progress_bar: bool | None = None, + device_kwargs: dict[str, Any] | None = None, ) -> tuple[NDArrayA, NDArrayA] | None: """ Compute co-occurrence probability of clusters. @@ -374,6 +375,9 @@ def co_occurrence( of the given size will be used. %(copy)s %(parallelize_device)s + device_kwargs + Additional keyword arguments passed to the GPU implementation when ``squidpy.settings.device`` + is set to ``'gpu'``. Ignored on CPU. 
Returns ------- @@ -386,7 +390,7 @@ def co_occurrence( - :attr:`anndata.AnnData.uns` ``['{cluster_key}_co_occurrence']['interval']`` - the distance thresholds computed at ``interval``. """ - del n_jobs, backend, show_progress_bar # handled by gpu_dispatch decorator or unused on CPU + del n_jobs, backend, show_progress_bar, device_kwargs # handled by gpu_dispatch decorator or unused on CPU if isinstance(adata, SpatialData): adata = adata.table _assert_categorical_obs(adata, key=cluster_key) diff --git a/tests/test_settings.py b/tests/test_settings.py index 54687f0da..5f4d8a3a3 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -7,7 +7,7 @@ import pytest from squidpy._settings import gpu_dispatch, settings -from squidpy.gr._gpu import SPECIAL_PARAM_REGISTRY, GpuParamSpec +from squidpy._settings._dispatch import _DISPATCH_CACHE class TestSettings: @@ -39,6 +39,10 @@ def test_set_device_gpu_without_rsc(self): class TestGpuDispatch: """Test the gpu_dispatch decorator.""" + def setup_method(self): + """Clear dispatch cache before each test.""" + _DISPATCH_CACHE.clear() + def test_cpu_path(self): """Test CPU device calls original function.""" calls = [] From 5bc8ca40e33b4c8ee1db59e922de99c4bfe13bac Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Tue, 3 Feb 2026 14:05:07 +0100 Subject: [PATCH 45/68] pass args directly --- src/squidpy/_settings/_dispatch.py | 27 ++++++++++++++++++--------- src/squidpy/gr/_ppatterns.py | 11 ++++++++++- 2 files changed, 28 insertions(+), 10 deletions(-) diff --git a/src/squidpy/_settings/_dispatch.py b/src/squidpy/_settings/_dispatch.py index 81947a079..0de6f5251 100644 --- a/src/squidpy/_settings/_dispatch.py +++ b/src/squidpy/_settings/_dispatch.py @@ -5,7 +5,7 @@ import functools import importlib import re -from collections.abc import Callable +from collections.abc import Callable, Mapping from typing import Any, TypeVar from squidpy._settings._settings import GPU_UNAVAILABLE_MSG, settings @@ -45,7 +45,7 @@ def 
_inject_gpu_note(doc: str | None, func_name: str, gpu_module: str) -> str | # Find "Parameters\n ----------" and capture the indentation (spaces only, not newline) match = re.search(r"\n([ \t]*)Parameters\s*\n\s*-+", doc) if match: - indent = match.group(1) # Capture only the spaces/tabs before Parameters + indent = match.group(1) gpu_note = _make_gpu_note(func_name, gpu_module, indent) insert_pos = match.start() return doc[:insert_pos] + "\n\n" + gpu_note + "\n" + doc[insert_pos:] @@ -61,11 +61,13 @@ def _inject_gpu_note(doc: str | None, func_name: str, gpu_module: str) -> str | def gpu_dispatch( gpu_module: str = "rapids_singlecell.gr", gpu_func_name: str | None = None, + validate_args: Mapping[str, Callable[[Any], None]] | None = None, ) -> Callable[[F], F]: """Decorator to dispatch to GPU implementation based on settings.device. - When device is 'gpu', calls the GPU implementation from the specified module. - The `device_kwargs` parameter (if present) is passed to the GPU function. + When device is 'gpu', calls the GPU implementation from the specified module, + passing all arguments through. The `device_kwargs` parameter is merged into + the call for GPU-specific options. Parameters ---------- @@ -73,13 +75,17 @@ def gpu_dispatch( Module path containing the GPU implementation. gpu_func_name Name of GPU function. Defaults to same name as decorated function. + validate_args + Mapping of parameter names to validation functions. Each validator is called + with the parameter value before GPU dispatch and should raise ValueError + if the value is not supported on GPU. Only called when dispatching to GPU. 
""" + _validate_args = validate_args or {} def decorator(func: F) -> F: func_name = func.__name__ _gpu_func_name = gpu_func_name or func_name - # Inject GPU note into docstring func.__doc__ = _inject_gpu_note(func.__doc__, _gpu_func_name, gpu_module) @functools.wraps(func) @@ -87,24 +93,27 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: resolved_device = _get_effective_device() if resolved_device == "cpu": - # CPU path - just remove device_kwargs if present and call CPU func kwargs.pop("device_kwargs", None) return func(*args, **kwargs) - # GPU path + # GPU path - run validators + for param_name, validator in _validate_args.items(): + if param_name in kwargs: + validator(kwargs[param_name]) + + # Get GPU function key = (gpu_module, _gpu_func_name) if key not in _GPU_FUNC_CACHE: try: module = importlib.import_module(gpu_module) _GPU_FUNC_CACHE[key] = getattr(module, _gpu_func_name) except (ImportError, AttributeError): - # GPU module not available, fall back to CPU kwargs.pop("device_kwargs", None) return func(*args, **kwargs) gpu_func = _GPU_FUNC_CACHE[key] - # Extract device_kwargs and merge into kwargs for GPU call + # Merge device_kwargs and call GPU function device_kwargs = kwargs.pop("device_kwargs", None) or {} kwargs.update(device_kwargs) diff --git a/src/squidpy/gr/_ppatterns.py b/src/squidpy/gr/_ppatterns.py index 9e6cd54a7..1a7674e41 100644 --- a/src/squidpy/gr/_ppatterns.py +++ b/src/squidpy/gr/_ppatterns.py @@ -36,6 +36,15 @@ __all__ = ["spatial_autocorr", "co_occurrence"] +def _validate_attr_for_gpu(value: Any) -> None: + """Validate attr parameter for GPU dispatch.""" + if value != "X": + raise ValueError( + f"attr={value!r} is not supported on GPU. " + "Use `squidpy.settings.device = 'cpu'` to use other attributes." 
+ ) + + it = nt.int32 ft = nt.float32 tt = nt.UniTuple @@ -46,7 +55,7 @@ @d.dedent @inject_docs(key=Key.obsp.spatial_conn(), sp=SpatialAutocorr) -@gpu_dispatch() +@gpu_dispatch(validate_args={"attr": _validate_attr_for_gpu}) def spatial_autocorr( adata: AnnData | SpatialData, connectivity_key: str = Key.obsp.spatial_conn(), From 1f118da9e9e87b437451dd331094bb9cd7727e22 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Feb 2026 08:30:23 +0000 Subject: [PATCH 46/68] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/squidpy/gr/_ppatterns.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/squidpy/gr/_ppatterns.py b/src/squidpy/gr/_ppatterns.py index 1a7674e41..212b9cdf2 100644 --- a/src/squidpy/gr/_ppatterns.py +++ b/src/squidpy/gr/_ppatterns.py @@ -40,8 +40,7 @@ def _validate_attr_for_gpu(value: Any) -> None: """Validate attr parameter for GPU dispatch.""" if value != "X": raise ValueError( - f"attr={value!r} is not supported on GPU. " - "Use `squidpy.settings.device = 'cpu'` to use other attributes." + f"attr={value!r} is not supported on GPU. Use `squidpy.settings.device = 'cpu'` to use other attributes." 
) From 97f5355490f79b105cb19ccf54baae032f6ee041 Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Wed, 4 Feb 2026 09:41:05 +0100 Subject: [PATCH 47/68] non-thread version of gpu dispatch --- src/squidpy/_settings/_dispatch.py | 26 ++-- src/squidpy/_settings/_settings.py | 18 ++- tests/test_settings.py | 210 ++++++++--------------------- 3 files changed, 88 insertions(+), 166 deletions(-) diff --git a/src/squidpy/_settings/_dispatch.py b/src/squidpy/_settings/_dispatch.py index 0de6f5251..461affd51 100644 --- a/src/squidpy/_settings/_dispatch.py +++ b/src/squidpy/_settings/_dispatch.py @@ -60,33 +60,31 @@ def _inject_gpu_note(doc: str | None, func_name: str, gpu_module: str) -> str | def gpu_dispatch( gpu_module: str = "rapids_singlecell.gr", - gpu_func_name: str | None = None, validate_args: Mapping[str, Callable[[Any], None]] | None = None, ) -> Callable[[F], F]: """Decorator to dispatch to GPU implementation based on settings.device. - When device is 'gpu', calls the GPU implementation from the specified module, - passing all arguments through. The `device_kwargs` parameter is merged into - the call for GPU-specific options. + When device is 'gpu', calls the GPU implementation from the specified module. + The ``device_kwargs`` parameter from the decorated function is merged into the + call for GPU-specific options. Arguments with ``None`` values are filtered out + to let the GPU function use its defaults. Parameters ---------- gpu_module Module path containing the GPU implementation. - gpu_func_name - Name of GPU function. Defaults to same name as decorated function. validate_args Mapping of parameter names to validation functions. Each validator is called with the parameter value before GPU dispatch and should raise ValueError - if the value is not supported on GPU. Only called when dispatching to GPU. + if the value is not supported on GPU. Validated arguments are removed from + kwargs before calling the GPU function. Only called when dispatching to GPU. 
""" _validate_args = validate_args or {} def decorator(func: F) -> F: func_name = func.__name__ - _gpu_func_name = gpu_func_name or func_name - func.__doc__ = _inject_gpu_note(func.__doc__, _gpu_func_name, gpu_module) + func.__doc__ = _inject_gpu_note(func.__doc__, func_name, gpu_module) @functools.wraps(func) def wrapper(*args: Any, **kwargs: Any) -> Any: @@ -96,17 +94,18 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: kwargs.pop("device_kwargs", None) return func(*args, **kwargs) - # GPU path - run validators + # GPU path - run validators and remove validated args for param_name, validator in _validate_args.items(): if param_name in kwargs: validator(kwargs[param_name]) + kwargs.pop(param_name) # Get GPU function - key = (gpu_module, _gpu_func_name) + key = (gpu_module, func_name) if key not in _GPU_FUNC_CACHE: try: module = importlib.import_module(gpu_module) - _GPU_FUNC_CACHE[key] = getattr(module, _gpu_func_name) + _GPU_FUNC_CACHE[key] = getattr(module, func_name) except (ImportError, AttributeError): kwargs.pop("device_kwargs", None) return func(*args, **kwargs) @@ -117,6 +116,9 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: device_kwargs = kwargs.pop("device_kwargs", None) or {} kwargs.update(device_kwargs) + # Filter out None values to let GPU function use its defaults + kwargs = {k: v for k, v in kwargs.items() if v is not None} + return gpu_func(*args, **kwargs) return wrapper # type: ignore[return-value] diff --git a/src/squidpy/_settings/_settings.py b/src/squidpy/_settings/_settings.py index ad8f71494..1efcb1997 100644 --- a/src/squidpy/_settings/_settings.py +++ b/src/squidpy/_settings/_settings.py @@ -25,14 +25,28 @@ def _check_gpu_available() -> bool: class SqSettings: - """Global configuration for squidpy.""" + """Global configuration for squidpy. + + Attributes + ---------- + gpu_available + Whether GPU acceleration via rapids-singlecell is available. + device + Compute device setting: ``'auto'`` (default), ``'cpu'``, or ``'gpu'``. 
+ When ``'auto'``, GPU is used if available, otherwise CPU. + """ def __init__(self) -> None: self.gpu_available: bool = _check_gpu_available() @property def device(self) -> DeviceType: - """Compute device: ``'auto'``, ``'cpu'``, or ``'gpu'``.""" + """Compute device: ``'auto'``, ``'cpu'``, or ``'gpu'``. + + When set to ``'auto'`` (default), GPU is used if rapids-singlecell + is installed, otherwise falls back to CPU. Setting to ``'gpu'`` + when GPU is unavailable raises a RuntimeError. + """ return _device_var.get() @device.setter diff --git a/tests/test_settings.py b/tests/test_settings.py index 5f4d8a3a3..5974a6352 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -7,7 +7,15 @@ import pytest from squidpy._settings import gpu_dispatch, settings -from squidpy._settings._dispatch import _DISPATCH_CACHE +from squidpy._settings._dispatch import _GPU_FUNC_CACHE + + +@pytest.fixture(autouse=True) +def clear_gpu_cache(): + """Clear GPU function cache before each test.""" + _GPU_FUNC_CACHE.clear() + yield + _GPU_FUNC_CACHE.clear() class TestSettings: @@ -39,10 +47,6 @@ def test_set_device_gpu_without_rsc(self): class TestGpuDispatch: """Test the gpu_dispatch decorator.""" - def setup_method(self): - """Clear dispatch cache before each test.""" - _DISPATCH_CACHE.clear() - def test_cpu_path(self): """Test CPU device calls original function.""" calls = [] @@ -77,7 +81,6 @@ def test_gpu_path(self): """Test GPU device dispatches to GPU module.""" mock_module = MagicMock() - # Must use real function for signature introspection def gpu_my_func(x): return "gpu_result" @@ -93,177 +96,95 @@ def my_func(x): ): assert my_func(42) == "gpu_result" - def test_custom_gpu_func_name(self): - """Test custom GPU function name.""" + def test_device_kwargs_passed_to_gpu(self): + """Test device_kwargs are merged and passed to GPU function.""" mock_module = MagicMock() + received_kwargs = {} - # Must use real function for signature introspection - def custom_name(x): + def 
gpu_my_func(x, use_sparse=False): + received_kwargs.update({"x": x, "use_sparse": use_sparse}) return "gpu_result" - mock_module.custom_name = custom_name + mock_module.my_func = gpu_my_func - @gpu_dispatch(gpu_module="test_module", gpu_func_name="custom_name") - def my_func(x): + @gpu_dispatch(gpu_module="test_module") + def my_func(x, device_kwargs=None): return "cpu_result" with ( patch("squidpy._settings._dispatch._get_effective_device", return_value="gpu"), patch("importlib.import_module", return_value=mock_module), ): - assert my_func(42) == "gpu_result" + result = my_func(42, device_kwargs={"use_sparse": True}) + assert result == "gpu_result" + assert received_kwargs == {"x": 42, "use_sparse": True} - def test_preserves_function_metadata(self): - """Test decorator preserves function name and injects GPU note.""" + def test_device_kwargs_ignored_on_cpu(self): + """Test device_kwargs are stripped on CPU path.""" + calls = [] @gpu_dispatch() - def documented_func(x): - """Original docstring. - - Parameters - ---------- - x - Input value. - """ - return x + def my_func(x, device_kwargs=None): + calls.append(x) + return x * 2 - assert documented_func.__name__ == "documented_func" - assert "Original docstring." 
in documented_func.__doc__ - assert "GPU acceleration" in documented_func.__doc__ + settings.device = "cpu" + # device_kwargs should be stripped, not cause an error + assert my_func(5, device_kwargs={"use_sparse": True}) == 10 + assert calls == [5] + settings.device = "auto" # reset - def test_cpu_only_params_error_on_gpu_if_provided(self): - """Test CPU-only params raise error on GPU if user explicitly provided them.""" + def test_validate_args_on_gpu(self): + """Test validate_args runs validators before GPU dispatch.""" mock_module = MagicMock() + mock_module.my_func = MagicMock(return_value="gpu_result") - # GPU function without n_jobs param (CPU-only) - def gpu_my_func(x): - return "gpu_result" + def validate_attr(value): + if value != "X": + raise ValueError(f"attr={value!r} not supported on GPU") - mock_module.my_func = gpu_my_func - - @gpu_dispatch(gpu_module="test_module") - def my_func(x, n_jobs=1): + @gpu_dispatch(gpu_module="test_module", validate_args={"attr": validate_attr}) + def my_func(x, attr="X"): return "cpu_result" with ( patch("squidpy._settings._dispatch._get_effective_device", return_value="gpu"), patch("importlib.import_module", return_value=mock_module), ): - # Not provided - should work - assert my_func(42) == "gpu_result" - - # Explicitly provided (even if same as default) - should error - with pytest.raises(ValueError, match="n_jobs.*only supported on CPU"): - my_func(42, n_jobs=1) - - # Explicitly provided with different value - should also error - with pytest.raises(ValueError, match="n_jobs.*only supported on CPU"): - my_func(42, n_jobs=4) + # Valid value should work + assert my_func(42, attr="X") == "gpu_result" - def test_gpu_only_params_error_on_cpu_if_provided(self): - """Test GPU-only params raise error on CPU if user explicitly provided them. 
+ # Invalid value should raise + with pytest.raises(ValueError, match="attr='obs' not supported on GPU"): + my_func(42, attr="obs") - GPU-only params are those in GPU signature but NOT in CPU signature. - If user tries to pass a GPU-only param on CPU, Python raises TypeError - (unexpected keyword argument) unless the CPU func accepts **kwargs. - """ - mock_module = MagicMock() - - # GPU func has gpu_batch_size (GPU-only, not in CPU sig) - def gpu_my_func(x, gpu_batch_size=1000): - return "gpu_result" - - mock_module.my_func = gpu_my_func - - # CPU func does NOT have gpu_batch_size - @gpu_dispatch(gpu_module="test_module") - def my_func(x): - return "cpu_result" - - settings.device = "cpu" - with patch("importlib.import_module", return_value=mock_module): - # Not provided - should work - assert my_func(42) == "cpu_result" - - # GPU-only param on CPU - Python raises TypeError (not in signature) - with pytest.raises(TypeError, match="unexpected keyword argument"): - my_func(42, gpu_batch_size=500) - settings.device = "auto" # reset - - def test_function_with_no_exclusive_params(self): - """Test that functions with matching signatures work transparently.""" - calls = [] - mock_module = MagicMock() - - # GPU func has same signature - def gpu_func(x): - return "gpu_result" + def test_preserves_function_metadata(self): + """Test decorator preserves function name and injects GPU note.""" - mock_module.my_func = gpu_func + @gpu_dispatch() + def documented_func(x): + """Original docstring. - @gpu_dispatch(gpu_module="test_module") - def my_func(x): - calls.append(x) - return x * 3 + Parameters + ---------- + x + Input value. + """ + return x - settings.device = "cpu" - # Should work on CPU without issues - assert my_func(10) == 30 - assert calls == [10] - settings.device = "auto" # reset + assert documented_func.__name__ == "documented_func" + assert "Original docstring." 
in documented_func.__doc__ + assert "GPU acceleration" in documented_func.__doc__ def test_gpu_errors_when_unavailable(self): """Test GPU raises error when unavailable.""" if settings.gpu_available: pytest.skip("GPU is available") - @gpu_dispatch() - def my_func(x): - return x + 1 - # Should raise error when GPU requested but unavailable with pytest.raises(RuntimeError, match="GPU unavailable"): settings.device = "gpu" - def test_custom_validator_error(self): - """Test custom validator raises appropriate error.""" - mock_module = MagicMock() - - def my_validator(value): - if value != "allowed": - return f"value={value!r} is not allowed on GPU" - return None - - # GPU func without custom_param (CPU-only with validator) - def gpu_my_func(x): - return "gpu_result" - - mock_module.my_func = gpu_my_func - - registry = { - "my_func": { - "cpu_only": {"custom_param": GpuParamSpec(validate_fn=my_validator)}, - "gpu_only": {}, - } - } - - with patch.dict(SPECIAL_PARAM_REGISTRY, registry): - - @gpu_dispatch(gpu_module="test_module") - def my_func(x, custom_param="allowed"): - return "cpu_result" - - with ( - patch("squidpy._settings._dispatch._get_effective_device", return_value="gpu"), - patch("importlib.import_module", return_value=mock_module), - ): - # Allowed value - should work - assert my_func(42, custom_param="allowed") == "gpu_result" - - # Not allowed value - should error - with pytest.raises(ValueError, match="value='bad' is not allowed on GPU"): - my_func(42, custom_param="bad") - def test_docstring_uses_custom_gpu_module(self): """Test that docstring GPU note uses the specified gpu_module.""" @@ -280,18 +201,3 @@ def my_func(x): assert "custom.module.path.my_func" in my_func.__doc__ - def test_docstring_uses_custom_gpu_func_name(self): - """Test that docstring GPU note uses the specified gpu_func_name.""" - - @gpu_dispatch(gpu_module="some.module", gpu_func_name="different_name") - def my_func(x): - """My function. - - Parameters - ---------- - x - Input. 
- """ - return x - - assert "some.module.different_name" in my_func.__doc__ From b192fbe4531ac5bba32cbb3d9d44f0bce1436163 Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Wed, 4 Feb 2026 09:48:17 +0100 Subject: [PATCH 48/68] gpu func cache --- src/squidpy/_settings/_dispatch.py | 36 +++++++++++++++++------------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/src/squidpy/_settings/_dispatch.py b/src/squidpy/_settings/_dispatch.py index 461affd51..d9985a7fc 100644 --- a/src/squidpy/_settings/_dispatch.py +++ b/src/squidpy/_settings/_dispatch.py @@ -54,8 +54,20 @@ def _inject_gpu_note(doc: str | None, func_name: str, gpu_module: str) -> str | return doc + "\n\n" + _make_gpu_note(func_name, gpu_module) -# Cache for GPU functions -_GPU_FUNC_CACHE: dict[tuple[str, str], Callable[..., Any]] = {} +@functools.cache +def _get_gpu_func(gpu_module: str, func_name: str) -> Callable[..., Any]: + """Get GPU function from module, with caching. + + + Raises + ------ + ImportError + If the GPU module cannot be imported. + AttributeError + If the function does not exist in the GPU module. 
+ """ + module = importlib.import_module(gpu_module) + return getattr(module, func_name) def gpu_dispatch( @@ -94,29 +106,21 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: kwargs.pop("device_kwargs", None) return func(*args, **kwargs) - # GPU path - run validators and remove validated args + # GPU path + # run validators and remove validated args for param_name, validator in _validate_args.items(): if param_name in kwargs: validator(kwargs[param_name]) kwargs.pop(param_name) - # Get GPU function - key = (gpu_module, func_name) - if key not in _GPU_FUNC_CACHE: - try: - module = importlib.import_module(gpu_module) - _GPU_FUNC_CACHE[key] = getattr(module, func_name) - except (ImportError, AttributeError): - kwargs.pop("device_kwargs", None) - return func(*args, **kwargs) - - gpu_func = _GPU_FUNC_CACHE[key] + # get GPU function + gpu_func = _get_gpu_func(gpu_module, func_name) - # Merge device_kwargs and call GPU function + # merge device_kwargs and call GPU function device_kwargs = kwargs.pop("device_kwargs", None) or {} kwargs.update(device_kwargs) - # Filter out None values to let GPU function use its defaults + # filter out None values to let GPU function use its defaults kwargs = {k: v for k, v in kwargs.items() if v is not None} return gpu_func(*args, **kwargs) From 635957726a1b4a33458c16f4efdfb4c362a6b6c0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Feb 2026 08:48:31 +0000 Subject: [PATCH 49/68] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_settings.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_settings.py b/tests/test_settings.py index 5974a6352..9bb4debee 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -200,4 +200,3 @@ def my_func(x): return x assert "custom.module.path.my_func" in my_func.__doc__ - From 9dca51b18cd6a14feb6fe818e9cef4396c2a9c9d Mon Sep 17 00:00:00 2001 From: 
"selman.ozleyen" Date: Wed, 4 Feb 2026 10:02:27 +0100 Subject: [PATCH 50/68] context manager --- src/squidpy/_settings/_dispatch.py | 17 +- src/squidpy/_settings/_settings.py | 49 ++++-- tests/test_settings.py | 239 +++++++++++++++++++++-------- 3 files changed, 213 insertions(+), 92 deletions(-) diff --git a/src/squidpy/_settings/_dispatch.py b/src/squidpy/_settings/_dispatch.py index d9985a7fc..efff66f06 100644 --- a/src/squidpy/_settings/_dispatch.py +++ b/src/squidpy/_settings/_dispatch.py @@ -8,26 +8,13 @@ from collections.abc import Callable, Mapping from typing import Any, TypeVar -from squidpy._settings._settings import GPU_UNAVAILABLE_MSG, settings +from squidpy._settings._settings import settings __all__ = ["gpu_dispatch"] F = TypeVar("F", bound=Callable[..., Any]) -def _get_effective_device() -> str: - """Get effective device from settings, resolving 'auto'.""" - device = settings.device - if device == "cpu": - return "cpu" - if device == "gpu": - if not settings.gpu_available: - raise RuntimeError(GPU_UNAVAILABLE_MSG) - return "gpu" - # auto - return "gpu" if settings.gpu_available else "cpu" - - def _make_gpu_note(func_name: str, gpu_module: str, indent: str = "") -> str: lines = [ ".. 
note::", @@ -100,7 +87,7 @@ def decorator(func: F) -> F: @functools.wraps(func) def wrapper(*args: Any, **kwargs: Any) -> Any: - resolved_device = _get_effective_device() + resolved_device = settings.device if resolved_device == "cpu": kwargs.pop("device_kwargs", None) diff --git a/src/squidpy/_settings/_settings.py b/src/squidpy/_settings/_settings.py index 1efcb1997..900be5dc0 100644 --- a/src/squidpy/_settings/_settings.py +++ b/src/squidpy/_settings/_settings.py @@ -2,16 +2,20 @@ from __future__ import annotations -from contextvars import ContextVar -from typing import Literal, get_args +from contextlib import contextmanager +from contextvars import ContextVar, Token +from typing import TYPE_CHECKING, Literal, get_args + +if TYPE_CHECKING: + from collections.abc import Generator __all__ = ["settings", "DeviceType", "GPU_UNAVAILABLE_MSG"] -DeviceType = Literal["auto", "cpu", "gpu"] +DeviceType = Literal["cpu", "gpu"] GPU_UNAVAILABLE_MSG = ( "GPU unavailable. Install: pip install squidpy[gpu-cuda12] or with [gpu-cuda11] for CUDA 11 support." ) -_device_var: ContextVar[DeviceType] = ContextVar("device", default="auto") +_device_var: ContextVar[DeviceType | None] = ContextVar("device", default=None) def _check_gpu_available() -> bool: @@ -32,8 +36,8 @@ class SqSettings: gpu_available Whether GPU acceleration via rapids-singlecell is available. device - Compute device setting: ``'auto'`` (default), ``'cpu'``, or ``'gpu'``. - When ``'auto'``, GPU is used if available, otherwise CPU. + Compute device. + Defaults to ``'gpu'`` if available, otherwise ``'cpu'``. """ def __init__(self) -> None: @@ -41,13 +45,15 @@ def __init__(self) -> None: @property def device(self) -> DeviceType: - """Compute device: ``'auto'``, ``'cpu'``, or ``'gpu'``. + """Compute device: ``'cpu'`` or ``'gpu'``. - When set to ``'auto'`` (default), GPU is used if rapids-singlecell - is installed, otherwise falls back to CPU. Setting to ``'gpu'`` - when GPU is unavailable raises a RuntimeError. 
+ Defaults to ``'gpu'`` if rapids-singlecell is installed, otherwise ``'cpu'``. + Setting to ``'gpu'`` when GPU is unavailable raises a RuntimeError. """ - return _device_var.get() + value = _device_var.get() + if value is None: + return "gpu" if self.gpu_available else "cpu" + return value @device.setter def device(self, value: DeviceType) -> None: @@ -57,5 +63,26 @@ def device(self, value: DeviceType) -> None: raise RuntimeError(GPU_UNAVAILABLE_MSG) _device_var.set(value) + @contextmanager + def use_device(self, device: DeviceType) -> Generator[None, None, None]: + """Temporarily set the compute device within a context. + + Parameters + ---------- + device + The device to use. + + Examples + -------- + >>> with sq.settings.use_device("cpu"): + ... sq.gr.spatial_neighbors(adata) + """ + token: Token[DeviceType | None] = _device_var.set(_device_var.get()) + try: + self.device = device + yield + finally: + _device_var.reset(token) + settings = SqSettings() diff --git a/tests/test_settings.py b/tests/test_settings.py index 9bb4debee..373e17b68 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -2,46 +2,157 @@ from __future__ import annotations +import concurrent.futures +import sys +from pathlib import Path from unittest.mock import MagicMock, patch import pytest +# Ensure src is in path for imports +sys.path.insert(0, str(Path(__file__).parent.parent / "src")) + from squidpy._settings import gpu_dispatch, settings -from squidpy._settings._dispatch import _GPU_FUNC_CACHE +from squidpy._settings._dispatch import _get_gpu_func +from squidpy._settings._settings import _device_var @pytest.fixture(autouse=True) -def clear_gpu_cache(): - """Clear GPU function cache before each test.""" - _GPU_FUNC_CACHE.clear() +def reset_device(): + """Reset device state before and after each test.""" + _device_var.set(None) + _get_gpu_func.cache_clear() yield - _GPU_FUNC_CACHE.clear() + _device_var.set(None) + _get_gpu_func.cache_clear() class TestSettings: """Test the 
settings module.""" - def test_default_device(self): - """Test that default device is 'auto'.""" - settings.device = "auto" - assert settings.device == "auto" + def test_default_device_cpu_when_gpu_unavailable(self): + """Test that default device is 'cpu' when GPU unavailable.""" + if settings.gpu_available: + pytest.skip("GPU is available") + assert settings.device == "cpu" + + def test_default_device_gpu_when_available(self): + """Test that default device is 'gpu' when GPU available.""" + if not settings.gpu_available: + pytest.skip("GPU is not available") + assert settings.device == "gpu" def test_set_device_cpu(self): """Test setting device to 'cpu'.""" settings.device = "cpu" assert settings.device == "cpu" - settings.device = "auto" # reset def test_set_device_invalid(self): """Test that invalid device raises ValueError.""" with pytest.raises(ValueError, match="device must be one of"): settings.device = "invalid" + def test_set_device_auto_invalid(self): + """Test that 'auto' is no longer a valid device.""" + with pytest.raises(ValueError, match="device must be one of"): + settings.device = "auto" + def test_set_device_gpu_without_rsc(self): """Test that setting device to 'gpu' without rapids-singlecell raises RuntimeError.""" - if not settings.gpu_available: - with pytest.raises(RuntimeError, match="GPU unavailable"): - settings.device = "gpu" + if settings.gpu_available: + pytest.skip("GPU is available") + with pytest.raises(RuntimeError, match="GPU unavailable"): + settings.device = "gpu" + + +class TestUseDeviceContextManager: + """Test the use_device context manager.""" + + def test_use_device_temporarily_sets_cpu(self): + """Test that use_device temporarily sets the device.""" + if settings.gpu_available: + pytest.skip("GPU is available - can't test CPU default") + + original = settings.device + with settings.use_device("cpu"): + assert settings.device == "cpu" + assert settings.device == original + + def test_use_device_restores_on_exception(self): + 
"""Test that use_device restores device even on exception.""" + original = settings.device + with pytest.raises(ValueError, match="test error"): + with settings.use_device("cpu"): + assert settings.device == "cpu" + raise ValueError("test error") + assert settings.device == original + + def test_use_device_invalid_raises(self): + """Test that use_device raises on invalid device.""" + with pytest.raises(ValueError, match="device must be one of"): + with settings.use_device("invalid"): + pass + + def test_use_device_gpu_without_rsc_raises(self): + """Test that use_device('gpu') raises when GPU unavailable.""" + if settings.gpu_available: + pytest.skip("GPU is available") + with pytest.raises(RuntimeError, match="GPU unavailable"): + with settings.use_device("gpu"): + pass + + def test_nested_use_device(self): + """Test nested use_device contexts restore correctly.""" + if settings.gpu_available: + pytest.skip("GPU is available") + + original = settings.device + settings.device = "cpu" + + with settings.use_device("cpu"): + assert settings.device == "cpu" + with settings.use_device("cpu"): + assert settings.device == "cpu" + assert settings.device == "cpu" + assert settings.device == "cpu" + + def test_use_device_thread_isolation(self): + """Test that use_device is thread-safe with isolated contexts.""" + results = {} + + def thread_func(thread_id: int, device: str): + with settings.use_device(device): + # Small delay to increase chance of interleaving + import time + time.sleep(0.01) + results[thread_id] = settings.device + + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor: + # Both threads use CPU since GPU may not be available + f1 = executor.submit(thread_func, 1, "cpu") + f2 = executor.submit(thread_func, 2, "cpu") + f1.result() + f2.result() + + assert results[1] == "cpu" + assert results[2] == "cpu" + + +class TestGpuFuncCache: + """Test the GPU function caching.""" + + def test_cache_info_available(self): + """Test that cache_info is 
accessible.""" + info = _get_gpu_func.cache_info() + assert hasattr(info, "hits") + assert hasattr(info, "misses") + + def test_cache_clear_works(self): + """Test that cache_clear works.""" + _get_gpu_func.cache_clear() + info = _get_gpu_func.cache_info() + assert info.hits == 0 + assert info.misses == 0 class TestGpuDispatch: @@ -56,26 +167,9 @@ def my_func(x, y, *, n_jobs=1): calls.append((x, y, n_jobs)) return x + y - settings.device = "cpu" - assert my_func(1, 2) == 3 - assert calls == [(1, 2, 1)] - settings.device = "auto" # reset - - def test_auto_device_falls_back_to_cpu(self): - """Test auto device falls back to CPU when GPU unavailable.""" - if settings.gpu_available: - pytest.skip("GPU is available") - - calls = [] - - @gpu_dispatch() - def my_func(x): - calls.append(x) - return x * 2 - - settings.device = "auto" - assert my_func(5) == 10 - assert calls == [5] + with settings.use_device("cpu"): + assert my_func(1, 2) == 3 + assert calls == [(1, 2, 1)] def test_gpu_path(self): """Test GPU device dispatches to GPU module.""" @@ -90,11 +184,10 @@ def gpu_my_func(x): def my_func(x): return "cpu_result" - with ( - patch("squidpy._settings._dispatch._get_effective_device", return_value="gpu"), - patch("importlib.import_module", return_value=mock_module), - ): - assert my_func(42) == "gpu_result" + with patch.object(settings, "gpu_available", True): + with settings.use_device("gpu"): + with patch("importlib.import_module", return_value=mock_module): + assert my_func(42) == "gpu_result" def test_device_kwargs_passed_to_gpu(self): """Test device_kwargs are merged and passed to GPU function.""" @@ -111,13 +204,12 @@ def gpu_my_func(x, use_sparse=False): def my_func(x, device_kwargs=None): return "cpu_result" - with ( - patch("squidpy._settings._dispatch._get_effective_device", return_value="gpu"), - patch("importlib.import_module", return_value=mock_module), - ): - result = my_func(42, device_kwargs={"use_sparse": True}) - assert result == "gpu_result" - assert 
received_kwargs == {"x": 42, "use_sparse": True} + with patch.object(settings, "gpu_available", True): + with settings.use_device("gpu"): + with patch("importlib.import_module", return_value=mock_module): + result = my_func(42, device_kwargs={"use_sparse": True}) + assert result == "gpu_result" + assert received_kwargs == {"x": 42, "use_sparse": True} def test_device_kwargs_ignored_on_cpu(self): """Test device_kwargs are stripped on CPU path.""" @@ -128,11 +220,10 @@ def my_func(x, device_kwargs=None): calls.append(x) return x * 2 - settings.device = "cpu" - # device_kwargs should be stripped, not cause an error - assert my_func(5, device_kwargs={"use_sparse": True}) == 10 - assert calls == [5] - settings.device = "auto" # reset + with settings.use_device("cpu"): + # device_kwargs should be stripped, not cause an error + assert my_func(5, device_kwargs={"use_sparse": True}) == 10 + assert calls == [5] def test_validate_args_on_gpu(self): """Test validate_args runs validators before GPU dispatch.""" @@ -147,16 +238,15 @@ def validate_attr(value): def my_func(x, attr="X"): return "cpu_result" - with ( - patch("squidpy._settings._dispatch._get_effective_device", return_value="gpu"), - patch("importlib.import_module", return_value=mock_module), - ): - # Valid value should work - assert my_func(42, attr="X") == "gpu_result" + with patch.object(settings, "gpu_available", True): + with settings.use_device("gpu"): + with patch("importlib.import_module", return_value=mock_module): + # Valid value should work + assert my_func(42, attr="X") == "gpu_result" - # Invalid value should raise - with pytest.raises(ValueError, match="attr='obs' not supported on GPU"): - my_func(42, attr="obs") + # Invalid value should raise + with pytest.raises(ValueError, match="attr='obs' not supported on GPU"): + my_func(42, attr="obs") def test_preserves_function_metadata(self): """Test decorator preserves function name and injects GPU note.""" @@ -176,14 +266,31 @@ def documented_func(x): assert 
"Original docstring." in documented_func.__doc__ assert "GPU acceleration" in documented_func.__doc__ - def test_gpu_errors_when_unavailable(self): - """Test GPU raises error when unavailable.""" - if settings.gpu_available: - pytest.skip("GPU is available") + def test_gpu_import_error_propagates(self): + """Test ImportError propagates when GPU module not found.""" - # Should raise error when GPU requested but unavailable - with pytest.raises(RuntimeError, match="GPU unavailable"): - settings.device = "gpu" + @gpu_dispatch(gpu_module="nonexistent_module") + def my_func(x): + return "cpu_result" + + with patch.object(settings, "gpu_available", True): + with settings.use_device("gpu"): + with pytest.raises(ImportError): + my_func(42) + + def test_gpu_attribute_error_propagates(self): + """Test AttributeError propagates when function not in GPU module.""" + mock_module = MagicMock(spec=[]) # Empty spec, no attributes + + @gpu_dispatch(gpu_module="test_module") + def my_func(x): + return "cpu_result" + + with patch.object(settings, "gpu_available", True): + with settings.use_device("gpu"): + with patch("importlib.import_module", return_value=mock_module): + with pytest.raises(AttributeError): + my_func(42) def test_docstring_uses_custom_gpu_module(self): """Test that docstring GPU note uses the specified gpu_module.""" From ac5421e125e8b3b50e72da9bbd15b0b1075f3a3d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Feb 2026 09:03:11 +0000 Subject: [PATCH 51/68] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_settings.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_settings.py b/tests/test_settings.py index 373e17b68..c8c3307f0 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -124,6 +124,7 @@ def thread_func(thread_id: int, device: str): with settings.use_device(device): # Small delay to increase 
chance of interleaving import time + time.sleep(0.01) results[thread_id] = settings.device From 06c001a9c4d3f2d97e4760a6266d0ae0e7e60972 Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Wed, 4 Feb 2026 10:07:37 +0100 Subject: [PATCH 52/68] update --- tests/test_settings.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/test_settings.py b/tests/test_settings.py index 373e17b68..b64c61aed 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -3,15 +3,10 @@ from __future__ import annotations import concurrent.futures -import sys -from pathlib import Path from unittest.mock import MagicMock, patch import pytest -# Ensure src is in path for imports -sys.path.insert(0, str(Path(__file__).parent.parent / "src")) - from squidpy._settings import gpu_dispatch, settings from squidpy._settings._dispatch import _get_gpu_func from squidpy._settings._settings import _device_var From ba9e7e08c26680bee7dff5f17113a2e87c000581 Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Wed, 4 Feb 2026 10:09:10 +0100 Subject: [PATCH 53/68] remove redundant test --- tests/test_settings.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/test_settings.py b/tests/test_settings.py index ac4e56311..8092e4e5c 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -47,10 +47,6 @@ def test_set_device_invalid(self): with pytest.raises(ValueError, match="device must be one of"): settings.device = "invalid" - def test_set_device_auto_invalid(self): - """Test that 'auto' is no longer a valid device.""" - with pytest.raises(ValueError, match="device must be one of"): - settings.device = "auto" def test_set_device_gpu_without_rsc(self): """Test that setting device to 'gpu' without rapids-singlecell raises RuntimeError.""" From ae54591bb8ff8ae65683b340c770180559a82535 Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Wed, 4 Feb 2026 10:12:44 +0100 Subject: [PATCH 54/68] redundant docstrings --- src/squidpy/_docs.py | 11 +++++------ 
src/squidpy/gr/_ligrec.py | 4 +--- src/squidpy/gr/_ppatterns.py | 8 ++------ 3 files changed, 8 insertions(+), 15 deletions(-) diff --git a/src/squidpy/_docs.py b/src/squidpy/_docs.py index 32019b177..1d7cc0e33 100644 --- a/src/squidpy/_docs.py +++ b/src/squidpy/_docs.py @@ -130,11 +130,10 @@ def decorator2(obj: Any) -> Any: seed Random seed for reproducibility.{_CPU_ONLY} """ -_device = """\ -device - Device for computation: ``'cpu'``, ``'gpu'``, or ``None`` (use ``squidpy.settings.device``). - When ``'gpu'``, dispatches to :doc:`rapids_singlecell ` for GPU-accelerated computation. -""" +_device_kwargs = """\ +device_kwargs + Additional keyword arguments passed to the GPU implementation when ``squidpy.settings.device`` + is set to ``'gpu'``. Ignored on CPU.""" _channels = """\ channels Channels for this feature is computed. If `None`, use all channels.""" @@ -397,7 +396,7 @@ def decorator2(obj: Any) -> Any: parallelize=_parallelize, parallelize_device=_parallelize_device, seed_device=_seed_device, - device=_device, + device_kwargs=_device_kwargs, channels=_channels, segment_kwargs=_segment_kwargs, ligrec_test_returns=_ligrec_test_returns, diff --git a/src/squidpy/gr/_ligrec.py b/src/squidpy/gr/_ligrec.py index 88c10fe49..20bf68e91 100644 --- a/src/squidpy/gr/_ligrec.py +++ b/src/squidpy/gr/_ligrec.py @@ -674,9 +674,7 @@ def ligrec( %(PT_test.parameters)s gene_symbols Key in :attr:`anndata.AnnData.var` to use instead of :attr:`anndata.AnnData.var_names`. - device_kwargs - Additional keyword arguments passed to the GPU implementation when ``squidpy.settings.device`` - is set to ``'gpu'``. Ignored on CPU. 
+ %(device_kwargs)s Returns ------- diff --git a/src/squidpy/gr/_ppatterns.py b/src/squidpy/gr/_ppatterns.py index 212b9cdf2..b29aabbd6 100644 --- a/src/squidpy/gr/_ppatterns.py +++ b/src/squidpy/gr/_ppatterns.py @@ -119,9 +119,7 @@ def spatial_autocorr( %(seed_device)s %(copy)s %(parallelize_device)s - device_kwargs - Additional keyword arguments passed to the GPU implementation when ``squidpy.settings.device`` - is set to ``'gpu'``. Ignored on CPU. + %(device_kwargs)s Returns ------- @@ -383,9 +381,7 @@ def co_occurrence( of the given size will be used. %(copy)s %(parallelize_device)s - device_kwargs - Additional keyword arguments passed to the GPU implementation when ``squidpy.settings.device`` - is set to ``'gpu'``. Ignored on CPU. + %(device_kwargs)s Returns ------- From a13767305364c07e014d18b95e5c252d1215de6d Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Wed, 4 Feb 2026 10:15:27 +0100 Subject: [PATCH 55/68] remove redundancy --- src/squidpy/_settings/_dispatch.py | 1 - src/squidpy/_settings/_settings.py | 2 +- tests/test_gpu.py | 21 ++++++++++++++------- tests/test_settings.py | 1 - 4 files changed, 15 insertions(+), 10 deletions(-) diff --git a/src/squidpy/_settings/_dispatch.py b/src/squidpy/_settings/_dispatch.py index efff66f06..4a2567114 100644 --- a/src/squidpy/_settings/_dispatch.py +++ b/src/squidpy/_settings/_dispatch.py @@ -45,7 +45,6 @@ def _inject_gpu_note(doc: str | None, func_name: str, gpu_module: str) -> str | def _get_gpu_func(gpu_module: str, func_name: str) -> Callable[..., Any]: """Get GPU function from module, with caching. 
- Raises ------ ImportError diff --git a/src/squidpy/_settings/_settings.py b/src/squidpy/_settings/_settings.py index 900be5dc0..f4c444977 100644 --- a/src/squidpy/_settings/_settings.py +++ b/src/squidpy/_settings/_settings.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: from collections.abc import Generator -__all__ = ["settings", "DeviceType", "GPU_UNAVAILABLE_MSG"] +__all__ = ["settings", "DeviceType"] DeviceType = Literal["cpu", "gpu"] GPU_UNAVAILABLE_MSG = ( diff --git a/tests/test_gpu.py b/tests/test_gpu.py index a5e927b4c..5b1cfcf37 100644 --- a/tests/test_gpu.py +++ b/tests/test_gpu.py @@ -20,7 +20,8 @@ class TestGPUCoOccurrence: def test_co_occurrence_gpu(self, adata): """Test co_occurrence with GPU device.""" - result = sq.gr.co_occurrence(adata, cluster_key="leiden", copy=True, device="gpu") + with settings.use_device("gpu"): + result = sq.gr.co_occurrence(adata, cluster_key="leiden", copy=True) assert result is not None arr, interval = result @@ -29,8 +30,10 @@ def test_co_occurrence_gpu(self, adata): def test_co_occurrence_gpu_vs_cpu(self, adata): """Test that GPU and CPU results are approximately equal.""" - cpu_arr, cpu_interval = sq.gr.co_occurrence(adata, cluster_key="leiden", copy=True, device="cpu") - gpu_arr, gpu_interval = sq.gr.co_occurrence(adata, cluster_key="leiden", copy=True, device="gpu") + with settings.use_device("cpu"): + cpu_arr, cpu_interval = sq.gr.co_occurrence(adata, cluster_key="leiden", copy=True) + with settings.use_device("gpu"): + gpu_arr, gpu_interval = sq.gr.co_occurrence(adata, cluster_key="leiden", copy=True) np.testing.assert_allclose(cpu_interval, gpu_interval, rtol=1e-5) np.testing.assert_allclose(cpu_arr, gpu_arr, rtol=1e-5) @@ -42,7 +45,8 @@ class TestGPUSpatialAutocorr: def test_spatial_autocorr_gpu(self, adata): """Test spatial_autocorr with GPU device.""" sq.gr.spatial_neighbors(adata) - result = sq.gr.spatial_autocorr(adata, mode="moran", copy=True, device="gpu") + with settings.use_device("gpu"): + result = 
sq.gr.spatial_autocorr(adata, mode="moran", copy=True) assert result is not None assert "I" in result.columns @@ -51,8 +55,10 @@ def test_spatial_autocorr_gpu(self, adata): def test_spatial_autocorr_gpu_vs_cpu(self, adata): """Test that GPU and CPU results are approximately equal.""" sq.gr.spatial_neighbors(adata) - cpu_result = sq.gr.spatial_autocorr(adata, mode="moran", copy=True, device="cpu") - gpu_result = sq.gr.spatial_autocorr(adata, mode="moran", copy=True, device="gpu") + with settings.use_device("cpu"): + cpu_result = sq.gr.spatial_autocorr(adata, mode="moran", copy=True) + with settings.use_device("gpu"): + gpu_result = sq.gr.spatial_autocorr(adata, mode="moran", copy=True) np.testing.assert_allclose(cpu_result["I"].values, gpu_result["I"].values, rtol=1e-3, equal_nan=True) @@ -62,7 +68,8 @@ class TestGPULigrec: def test_ligrec_gpu(self, adata): """Test ligrec with GPU device.""" - result = sq.gr.ligrec(adata, cluster_key="leiden", copy=True, device="gpu") + with settings.use_device("gpu"): + result = sq.gr.ligrec(adata, cluster_key="leiden", copy=True) assert result is not None assert "means" in result diff --git a/tests/test_settings.py b/tests/test_settings.py index 8092e4e5c..a14b473ab 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -47,7 +47,6 @@ def test_set_device_invalid(self): with pytest.raises(ValueError, match="device must be one of"): settings.device = "invalid" - def test_set_device_gpu_without_rsc(self): """Test that setting device to 'gpu' without rapids-singlecell raises RuntimeError.""" if settings.gpu_available: From 0ebb219b83193d19e3347c40c5012c22415a1ed6 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Wed, 4 Feb 2026 10:19:40 +0100 Subject: [PATCH 56/68] add uv.lock to gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 4818e5c2e..b60d6aa4f 100644 --- a/.gitignore +++ b/.gitignore @@ -146,3 +146,4 @@ data pixi.lock _version.py +uv.lock From 
a9517a97ff68ed8c07133ceaf0522cddf21ed129 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Wed, 4 Feb 2026 10:28:06 +0100 Subject: [PATCH 57/68] remove device leftover --- src/squidpy/gr/_ppatterns.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/squidpy/gr/_ppatterns.py b/src/squidpy/gr/_ppatterns.py index b29aabbd6..497df3535 100644 --- a/src/squidpy/gr/_ppatterns.py +++ b/src/squidpy/gr/_ppatterns.py @@ -141,7 +141,7 @@ def spatial_autocorr( - :attr:`anndata.AnnData.uns` ``['moranI']`` - the above mentioned dataframe, if ``mode = {sp.MORAN.s!r}``. - :attr:`anndata.AnnData.uns` ``['gearyC']`` - the above mentioned dataframe, if ``mode = {sp.GEARY.s!r}``. """ - del device, use_sparse # device and use_sparse are handled by the gpu_dispatch decorator + del device_kwargs # device and use_sparse are handled by the gpu_dispatch decorator if isinstance(adata, SpatialData): adata = adata.table _assert_connectivity_key(adata, connectivity_key) From 4ebda67a5e66c3e27e079e93dfc5f880b569efa1 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Wed, 4 Feb 2026 10:39:37 +0100 Subject: [PATCH 58/68] spatial_neighbors doesnt match rsc much --- tests/test_gpu.py | 44 +++++++++++++++++++++++++++++--------------- 1 file changed, 29 insertions(+), 15 deletions(-) diff --git a/tests/test_gpu.py b/tests/test_gpu.py index 5b1cfcf37..565bff642 100644 --- a/tests/test_gpu.py +++ b/tests/test_gpu.py @@ -15,25 +15,39 @@ ) +@pytest.fixture +def adata_filtered(adata): + """Filter adata to only include genes with non-zero variance. + + This avoids NaN values in GPU spatial_autocorr due to constant genes. 
+ """ + # Calculate variance per gene and filter out zero/low variance genes + X = adata.X.toarray() if hasattr(adata.X, "toarray") else adata.X + gene_var = np.var(X, axis=0) + # Keep genes with variance > small threshold to avoid numerical issues + valid_genes = gene_var > 1e-6 + return adata[:, valid_genes].copy() + + class TestGPUCoOccurrence: """Test GPU-accelerated co_occurrence function.""" - def test_co_occurrence_gpu(self, adata): + def test_co_occurrence_gpu(self, adata_filtered): """Test co_occurrence with GPU device.""" with settings.use_device("gpu"): - result = sq.gr.co_occurrence(adata, cluster_key="leiden", copy=True) + result = sq.gr.co_occurrence(adata_filtered, cluster_key="leiden", copy=True) assert result is not None arr, interval = result assert arr.ndim == 3 - assert arr.shape[1] == arr.shape[0] == adata.obs["leiden"].nunique() + assert arr.shape[1] == arr.shape[0] == adata_filtered.obs["leiden"].nunique() - def test_co_occurrence_gpu_vs_cpu(self, adata): + def test_co_occurrence_gpu_vs_cpu(self, adata_filtered): """Test that GPU and CPU results are approximately equal.""" with settings.use_device("cpu"): - cpu_arr, cpu_interval = sq.gr.co_occurrence(adata, cluster_key="leiden", copy=True) + cpu_arr, cpu_interval = sq.gr.co_occurrence(adata_filtered, cluster_key="leiden", copy=True) with settings.use_device("gpu"): - gpu_arr, gpu_interval = sq.gr.co_occurrence(adata, cluster_key="leiden", copy=True) + gpu_arr, gpu_interval = sq.gr.co_occurrence(adata_filtered, cluster_key="leiden", copy=True) np.testing.assert_allclose(cpu_interval, gpu_interval, rtol=1e-5) np.testing.assert_allclose(cpu_arr, gpu_arr, rtol=1e-5) @@ -42,23 +56,23 @@ def test_co_occurrence_gpu_vs_cpu(self, adata): class TestGPUSpatialAutocorr: """Test GPU-accelerated spatial_autocorr function.""" - def test_spatial_autocorr_gpu(self, adata): + def test_spatial_autocorr_gpu(self, adata_filtered): """Test spatial_autocorr with GPU device.""" - sq.gr.spatial_neighbors(adata) + 
sq.gr.spatial_neighbors(adata_filtered) with settings.use_device("gpu"): - result = sq.gr.spatial_autocorr(adata, mode="moran", copy=True) + result = sq.gr.spatial_autocorr(adata_filtered, mode="moran", copy=True) assert result is not None assert "I" in result.columns assert "pval_norm" in result.columns - def test_spatial_autocorr_gpu_vs_cpu(self, adata): + def test_spatial_autocorr_gpu_vs_cpu(self, adata_filtered): """Test that GPU and CPU results are approximately equal.""" - sq.gr.spatial_neighbors(adata) + sq.gr.spatial_neighbors(adata_filtered) with settings.use_device("cpu"): - cpu_result = sq.gr.spatial_autocorr(adata, mode="moran", copy=True) + cpu_result = sq.gr.spatial_autocorr(adata_filtered, mode="moran", copy=True) with settings.use_device("gpu"): - gpu_result = sq.gr.spatial_autocorr(adata, mode="moran", copy=True) + gpu_result = sq.gr.spatial_autocorr(adata_filtered, mode="moran", copy=True) np.testing.assert_allclose(cpu_result["I"].values, gpu_result["I"].values, rtol=1e-3, equal_nan=True) @@ -66,10 +80,10 @@ def test_spatial_autocorr_gpu_vs_cpu(self, adata): class TestGPULigrec: """Test GPU-accelerated ligrec function.""" - def test_ligrec_gpu(self, adata): + def test_ligrec_gpu(self, adata_filtered): """Test ligrec with GPU device.""" with settings.use_device("gpu"): - result = sq.gr.ligrec(adata, cluster_key="leiden", copy=True) + result = sq.gr.ligrec(adata_filtered, cluster_key="leiden", copy=True) assert result is not None assert "means" in result From 93b06dc2f3e727f1d4b8feca6ce48e27c53bb0c8 Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Wed, 4 Feb 2026 11:04:50 +0100 Subject: [PATCH 59/68] remove redundant test --- tests/test_settings.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/tests/test_settings.py b/tests/test_settings.py index a14b473ab..44f3a917b 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -91,21 +91,6 @@ def test_use_device_gpu_without_rsc_raises(self): with 
settings.use_device("gpu"): pass - def test_nested_use_device(self): - """Test nested use_device contexts restore correctly.""" - if settings.gpu_available: - pytest.skip("GPU is available") - - original = settings.device - settings.device = "cpu" - - with settings.use_device("cpu"): - assert settings.device == "cpu" - with settings.use_device("cpu"): - assert settings.device == "cpu" - assert settings.device == "cpu" - assert settings.device == "cpu" - def test_use_device_thread_isolation(self): """Test that use_device is thread-safe with isolated contexts.""" results = {} From f19bdc703ef626417f449d0da9a5dc8cd2310cd1 Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Wed, 4 Feb 2026 11:09:25 +0100 Subject: [PATCH 60/68] change the order of filtering --- src/squidpy/_settings/_dispatch.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/squidpy/_settings/_dispatch.py b/src/squidpy/_settings/_dispatch.py index 4a2567114..1b8dd9ac6 100644 --- a/src/squidpy/_settings/_dispatch.py +++ b/src/squidpy/_settings/_dispatch.py @@ -102,13 +102,12 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: # get GPU function gpu_func = _get_gpu_func(gpu_module, func_name) + # filter out None values to let GPU function use its defaults + kwargs = {k: v for k, v in kwargs.items() if v is not None} # merge device_kwargs and call GPU function device_kwargs = kwargs.pop("device_kwargs", None) or {} kwargs.update(device_kwargs) - # filter out None values to let GPU function use its defaults - kwargs = {k: v for k, v in kwargs.items() if v is not None} - return gpu_func(*args, **kwargs) return wrapper # type: ignore[return-value] From c0a10cbe45a1ca744d8ef449a5cd4e22030dfd0e Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Wed, 4 Feb 2026 11:15:56 +0100 Subject: [PATCH 61/68] make _SqSettings private --- src/squidpy/_settings/_settings.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/squidpy/_settings/_settings.py 
b/src/squidpy/_settings/_settings.py index f4c444977..0dd33cc03 100644 --- a/src/squidpy/_settings/_settings.py +++ b/src/squidpy/_settings/_settings.py @@ -28,7 +28,7 @@ def _check_gpu_available() -> bool: return False -class SqSettings: +class _SqSettings: """Global configuration for squidpy. Attributes @@ -85,4 +85,4 @@ def use_device(self, device: DeviceType) -> Generator[None, None, None]: _device_var.reset(token) -settings = SqSettings() +settings = _SqSettings() From 0276e477cf30ccc6f1fab5088dcc09b21eef98bc Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Wed, 4 Feb 2026 11:29:23 +0100 Subject: [PATCH 62/68] error when device_args is given --- src/squidpy/_docs.py | 9 +++++---- src/squidpy/_settings/_dispatch.py | 13 +++++++++---- src/squidpy/gr/_ligrec.py | 8 +------- tests/test_settings.py | 31 ++++++++++++++++++++++++++---- 4 files changed, 42 insertions(+), 19 deletions(-) diff --git a/src/squidpy/_docs.py b/src/squidpy/_docs.py index 1d7cc0e33..ad629e6f6 100644 --- a/src/squidpy/_docs.py +++ b/src/squidpy/_docs.py @@ -109,7 +109,7 @@ def decorator2(obj: Any) -> Any: _n_jobs = """\ n_jobs - Number of parallel jobs to use. + Number of parallel jobs to use. If ``None``, use all available cores. For ``backend="loky"``, the number of cores used by numba for each job spawned by the backend will be set to 1 in order to overcome the oversubscription issue in case you run @@ -119,10 +119,11 @@ def decorator2(obj: Any) -> Any: ``NUMBA_NUM_THREADS`` before running the program.""" _backend = """\ backend - Parallelization backend to use. See :class:`joblib.Parallel` for available options.""" + Parallelization backend to use. If ``None``, defaults to ``'loky'``. + See :class:`joblib.Parallel` for available options.""" _show_progress_bar = """\ show_progress_bar - Whether to show the progress bar or not.""" + Whether to show the progress bar. 
If ``None``, uses ``scanpy.settings.verbosity``.""" _parallelize = f"{_n_jobs}\n{_backend}\n{_show_progress_bar}" _parallelize_device = f"{_n_jobs}{_CPU_ONLY}\n{_backend}{_CPU_ONLY}\n{_show_progress_bar}{_CPU_ONLY}" @@ -133,7 +134,7 @@ def decorator2(obj: Any) -> Any: _device_kwargs = """\ device_kwargs Additional keyword arguments passed to the GPU implementation when ``squidpy.settings.device`` - is set to ``'gpu'``. Ignored on CPU.""" + is set to ``'gpu'``. Must be ``None`` or empty when device is ``'cpu'``.""" _channels = """\ channels Channels for this feature is computed. If `None`, use all channels.""" diff --git a/src/squidpy/_settings/_dispatch.py b/src/squidpy/_settings/_dispatch.py index 1b8dd9ac6..3f8a1ca9f 100644 --- a/src/squidpy/_settings/_dispatch.py +++ b/src/squidpy/_settings/_dispatch.py @@ -86,10 +86,15 @@ def decorator(func: F) -> F: @functools.wraps(func) def wrapper(*args: Any, **kwargs: Any) -> Any: - resolved_device = settings.device - - if resolved_device == "cpu": - kwargs.pop("device_kwargs", None) + effective_device = settings.device + + if effective_device == "cpu": + device_kwargs = kwargs.pop("device_kwargs", None) + if device_kwargs is not None and len(device_kwargs) > 0: + raise ValueError( + "device_kwargs should not be provided when squidpy.settings.device='cpu'. " + "Set squidpy.settings.device='gpu' or use settings.use_device('gpu') context manager." + ) return func(*args, **kwargs) # GPU path diff --git a/src/squidpy/gr/_ligrec.py b/src/squidpy/gr/_ligrec.py index 20bf68e91..b32494941 100644 --- a/src/squidpy/gr/_ligrec.py +++ b/src/squidpy/gr/_ligrec.py @@ -743,15 +743,9 @@ def _analysis( Percentage threshold for removing lowly expressed genes in clusters. %(n_perms)s %(seed)s - n_jobs - Number of parallel jobs to launch. numba_parallel Whether to use :func:`numba.prange` or not. If `None`, it's determined automatically. - backend - Parallelization backend to use. - show_progress_bar - Whether to show the progress bar. 
- + %(parallelize)s Returns ------- Tuple of the following format: diff --git a/tests/test_settings.py b/tests/test_settings.py index 44f3a917b..4eb1723d0 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -187,8 +187,8 @@ def my_func(x, device_kwargs=None): assert result == "gpu_result" assert received_kwargs == {"x": 42, "use_sparse": True} - def test_device_kwargs_ignored_on_cpu(self): - """Test device_kwargs are stripped on CPU path.""" + def test_device_kwargs_none_on_cpu(self): + """Test device_kwargs=None is allowed on CPU path.""" calls = [] @gpu_dispatch() @@ -197,10 +197,33 @@ def my_func(x, device_kwargs=None): return x * 2 with settings.use_device("cpu"): - # device_kwargs should be stripped, not cause an error - assert my_func(5, device_kwargs={"use_sparse": True}) == 10 + assert my_func(5, device_kwargs=None) == 10 assert calls == [5] + def test_device_kwargs_empty_on_cpu(self): + """Test device_kwargs={} is allowed on CPU path.""" + calls = [] + + @gpu_dispatch() + def my_func(x, device_kwargs=None): + calls.append(x) + return x * 2 + + with settings.use_device("cpu"): + assert my_func(5, device_kwargs={}) == 10 + assert calls == [5] + + def test_device_kwargs_error_on_cpu(self): + """Test device_kwargs with values raises error on CPU path.""" + + @gpu_dispatch() + def my_func(x, device_kwargs=None): + return x * 2 + + with settings.use_device("cpu"): + with pytest.raises(ValueError, match="device_kwargs=.* is not supported when device='cpu'"): + my_func(5, device_kwargs={"use_sparse": True}) + def test_validate_args_on_gpu(self): """Test validate_args runs validators before GPU dispatch.""" mock_module = MagicMock() From 53650c0a3b53d8dcf2431b8ab04ca509738ba452 Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Wed, 4 Feb 2026 11:33:06 +0100 Subject: [PATCH 63/68] update tests --- tests/test_settings.py | 28 +--------------------------- 1 file changed, 1 insertion(+), 27 deletions(-) diff --git a/tests/test_settings.py 
b/tests/test_settings.py index 4eb1723d0..4d6cbd313 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -187,34 +187,8 @@ def my_func(x, device_kwargs=None): assert result == "gpu_result" assert received_kwargs == {"x": 42, "use_sparse": True} - def test_device_kwargs_none_on_cpu(self): - """Test device_kwargs=None is allowed on CPU path.""" - calls = [] - - @gpu_dispatch() - def my_func(x, device_kwargs=None): - calls.append(x) - return x * 2 - - with settings.use_device("cpu"): - assert my_func(5, device_kwargs=None) == 10 - assert calls == [5] - - def test_device_kwargs_empty_on_cpu(self): - """Test device_kwargs={} is allowed on CPU path.""" - calls = [] - - @gpu_dispatch() - def my_func(x, device_kwargs=None): - calls.append(x) - return x * 2 - - with settings.use_device("cpu"): - assert my_func(5, device_kwargs={}) == 10 - assert calls == [5] - def test_device_kwargs_error_on_cpu(self): - """Test device_kwargs with values raises error on CPU path.""" + """Test device_kwargs raises error on CPU path.""" @gpu_dispatch() def my_func(x, device_kwargs=None): From 49ed2fb88dfc227e13d8c52bd088d92919bc945f Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Wed, 4 Feb 2026 11:35:07 +0100 Subject: [PATCH 64/68] update regex patternf or test --- src/squidpy/gr/_ppatterns.py | 2 +- tests/test_settings.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/squidpy/gr/_ppatterns.py b/src/squidpy/gr/_ppatterns.py index 497df3535..4a5cc23f0 100644 --- a/src/squidpy/gr/_ppatterns.py +++ b/src/squidpy/gr/_ppatterns.py @@ -54,7 +54,7 @@ def _validate_attr_for_gpu(value: Any) -> None: @d.dedent @inject_docs(key=Key.obsp.spatial_conn(), sp=SpatialAutocorr) -@gpu_dispatch(validate_args={"attr": _validate_attr_for_gpu}) +@gpu_dispatch(gpu_module="rapids_singlecell.gr", validate_args={"attr": _validate_attr_for_gpu}) def spatial_autocorr( adata: AnnData | SpatialData, connectivity_key: str = Key.obsp.spatial_conn(), diff --git 
a/tests/test_settings.py b/tests/test_settings.py index 4d6cbd313..fe104b6cb 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -195,7 +195,7 @@ def my_func(x, device_kwargs=None): return x * 2 with settings.use_device("cpu"): - with pytest.raises(ValueError, match="device_kwargs=.* is not supported when device='cpu'"): + with pytest.raises(ValueError, match="device_kwargs should not be provided"): my_func(5, device_kwargs={"use_sparse": True}) def test_validate_args_on_gpu(self): From 64fc33c4f9457052721c98cc5fd46350fdf66ce5 Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Wed, 4 Feb 2026 11:38:00 +0100 Subject: [PATCH 65/68] update gpu tests --- tests/test_gpu.py | 71 +++++++++++++++-------------------------------- 1 file changed, 22 insertions(+), 49 deletions(-) diff --git a/tests/test_gpu.py b/tests/test_gpu.py index 565bff642..e9e243b88 100644 --- a/tests/test_gpu.py +++ b/tests/test_gpu.py @@ -1,4 +1,8 @@ -"""Tests for GPU functionality (skipped in CI without GPU).""" +"""Tests for GPU functionality (skipped in CI without GPU). + +These tests verify GPU results match CPU results. Structure/correctness +of CPU outputs is tested elsewhere, so we only test equivalence here. +""" from __future__ import annotations @@ -8,7 +12,6 @@ import squidpy as sq from squidpy._settings import settings -# Skip all tests in this module if GPU is not available pytestmark = pytest.mark.skipif( not settings.gpu_available, reason="GPU tests require rapids-singlecell to be installed", @@ -17,33 +20,17 @@ @pytest.fixture def adata_filtered(adata): - """Filter adata to only include genes with non-zero variance. - - This avoids NaN values in GPU spatial_autocorr due to constant genes. 
- """ - # Calculate variance per gene and filter out zero/low variance genes + """Filter adata to genes with non-zero variance (avoids NaN in GPU spatial_autocorr).""" X = adata.X.toarray() if hasattr(adata.X, "toarray") else adata.X gene_var = np.var(X, axis=0) - # Keep genes with variance > small threshold to avoid numerical issues - valid_genes = gene_var > 1e-6 - return adata[:, valid_genes].copy() + return adata[:, gene_var > 1e-6].copy() -class TestGPUCoOccurrence: - """Test GPU-accelerated co_occurrence function.""" - - def test_co_occurrence_gpu(self, adata_filtered): - """Test co_occurrence with GPU device.""" - with settings.use_device("gpu"): - result = sq.gr.co_occurrence(adata_filtered, cluster_key="leiden", copy=True) +class TestGPUvsCPU: + """Test that GPU and CPU produce equivalent results.""" - assert result is not None - arr, interval = result - assert arr.ndim == 3 - assert arr.shape[1] == arr.shape[0] == adata_filtered.obs["leiden"].nunique() - - def test_co_occurrence_gpu_vs_cpu(self, adata_filtered): - """Test that GPU and CPU results are approximately equal.""" + def test_co_occurrence(self, adata_filtered): + """Test co_occurrence GPU vs CPU equivalence.""" with settings.use_device("cpu"): cpu_arr, cpu_interval = sq.gr.co_occurrence(adata_filtered, cluster_key="leiden", copy=True) with settings.use_device("gpu"): @@ -52,23 +39,10 @@ def test_co_occurrence_gpu_vs_cpu(self, adata_filtered): np.testing.assert_allclose(cpu_interval, gpu_interval, rtol=1e-5) np.testing.assert_allclose(cpu_arr, gpu_arr, rtol=1e-5) - -class TestGPUSpatialAutocorr: - """Test GPU-accelerated spatial_autocorr function.""" - - def test_spatial_autocorr_gpu(self, adata_filtered): - """Test spatial_autocorr with GPU device.""" + def test_spatial_autocorr(self, adata_filtered): + """Test spatial_autocorr GPU vs CPU equivalence.""" sq.gr.spatial_neighbors(adata_filtered) - with settings.use_device("gpu"): - result = sq.gr.spatial_autocorr(adata_filtered, mode="moran", 
copy=True) - - assert result is not None - assert "I" in result.columns - assert "pval_norm" in result.columns - def test_spatial_autocorr_gpu_vs_cpu(self, adata_filtered): - """Test that GPU and CPU results are approximately equal.""" - sq.gr.spatial_neighbors(adata_filtered) with settings.use_device("cpu"): cpu_result = sq.gr.spatial_autocorr(adata_filtered, mode="moran", copy=True) with settings.use_device("gpu"): @@ -76,15 +50,14 @@ def test_spatial_autocorr_gpu_vs_cpu(self, adata_filtered): np.testing.assert_allclose(cpu_result["I"].values, gpu_result["I"].values, rtol=1e-3, equal_nan=True) - -class TestGPULigrec: - """Test GPU-accelerated ligrec function.""" - - def test_ligrec_gpu(self, adata_filtered): - """Test ligrec with GPU device.""" + def test_ligrec(self, adata_filtered): + """Test ligrec GPU vs CPU equivalence.""" + with settings.use_device("cpu"): + cpu_result = sq.gr.ligrec(adata_filtered, cluster_key="leiden", copy=True, n_perms=5) with settings.use_device("gpu"): - result = sq.gr.ligrec(adata_filtered, cluster_key="leiden", copy=True) + gpu_result = sq.gr.ligrec(adata_filtered, cluster_key="leiden", copy=True, n_perms=5) - assert result is not None - assert "means" in result - assert "pvalues" in result + # Compare means (deterministic) + np.testing.assert_allclose( + cpu_result["means"].values, gpu_result["means"].values, rtol=1e-5, equal_nan=True + ) From 4f540adfdf7914dd6b774cdbeb5da5392cd612e2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Feb 2026 10:38:11 +0000 Subject: [PATCH 66/68] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_gpu.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_gpu.py b/tests/test_gpu.py index e9e243b88..2ddbe4d42 100644 --- a/tests/test_gpu.py +++ b/tests/test_gpu.py @@ -58,6 +58,4 @@ def test_ligrec(self, adata_filtered): gpu_result = 
sq.gr.ligrec(adata_filtered, cluster_key="leiden", copy=True, n_perms=5) # Compare means (deterministic) - np.testing.assert_allclose( - cpu_result["means"].values, gpu_result["means"].values, rtol=1e-5, equal_nan=True - ) + np.testing.assert_allclose(cpu_result["means"].values, gpu_result["means"].values, rtol=1e-5, equal_nan=True) From e3b331e8b30991bbd248f2e279068abf8fca36dc Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Wed, 4 Feb 2026 11:49:34 +0100 Subject: [PATCH 67/68] clean up test settings --- tests/test_settings.py | 212 +++++++---------------------------------- 1 file changed, 32 insertions(+), 180 deletions(-) diff --git a/tests/test_settings.py b/tests/test_settings.py index fe104b6cb..90f25e20e 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -2,7 +2,6 @@ from __future__ import annotations -import concurrent.futures from unittest.mock import MagicMock, patch import pytest @@ -22,114 +21,26 @@ def reset_device(): _get_gpu_func.cache_clear() -class TestSettings: - """Test the settings module.""" +class TestDeviceSettings: + """Test device property and use_device context manager.""" - def test_default_device_cpu_when_gpu_unavailable(self): - """Test that default device is 'cpu' when GPU unavailable.""" - if settings.gpu_available: - pytest.skip("GPU is available") - assert settings.device == "cpu" - - def test_default_device_gpu_when_available(self): - """Test that default device is 'gpu' when GPU available.""" - if not settings.gpu_available: - pytest.skip("GPU is not available") - assert settings.device == "gpu" - - def test_set_device_cpu(self): - """Test setting device to 'cpu'.""" - settings.device = "cpu" - assert settings.device == "cpu" - - def test_set_device_invalid(self): - """Test that invalid device raises ValueError.""" + def test_invalid_device_raises(self): + """Test invalid device raises ValueError.""" with pytest.raises(ValueError, match="device must be one of"): settings.device = "invalid" - - def 
test_set_device_gpu_without_rsc(self): - """Test that setting device to 'gpu' without rapids-singlecell raises RuntimeError.""" - if settings.gpu_available: - pytest.skip("GPU is available") - with pytest.raises(RuntimeError, match="GPU unavailable"): - settings.device = "gpu" - - -class TestUseDeviceContextManager: - """Test the use_device context manager.""" - - def test_use_device_temporarily_sets_cpu(self): - """Test that use_device temporarily sets the device.""" - if settings.gpu_available: - pytest.skip("GPU is available - can't test CPU default") - - original = settings.device - with settings.use_device("cpu"): - assert settings.device == "cpu" - assert settings.device == original - - def test_use_device_restores_on_exception(self): - """Test that use_device restores device even on exception.""" - original = settings.device - with pytest.raises(ValueError, match="test error"): - with settings.use_device("cpu"): - assert settings.device == "cpu" - raise ValueError("test error") - assert settings.device == original - - def test_use_device_invalid_raises(self): - """Test that use_device raises on invalid device.""" with pytest.raises(ValueError, match="device must be one of"): with settings.use_device("invalid"): pass - def test_use_device_gpu_without_rsc_raises(self): - """Test that use_device('gpu') raises when GPU unavailable.""" - if settings.gpu_available: - pytest.skip("GPU is available") + @pytest.mark.skipif(settings.gpu_available, reason="GPU is available") + def test_gpu_without_rsc_raises(self): + """Test setting GPU without rapids-singlecell raises RuntimeError.""" + with pytest.raises(RuntimeError, match="GPU unavailable"): + settings.device = "gpu" with pytest.raises(RuntimeError, match="GPU unavailable"): with settings.use_device("gpu"): pass - def test_use_device_thread_isolation(self): - """Test that use_device is thread-safe with isolated contexts.""" - results = {} - - def thread_func(thread_id: int, device: str): - with 
settings.use_device(device): - # Small delay to increase chance of interleaving - import time - - time.sleep(0.01) - results[thread_id] = settings.device - - with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor: - # Both threads use CPU since GPU may not be available - f1 = executor.submit(thread_func, 1, "cpu") - f2 = executor.submit(thread_func, 2, "cpu") - f1.result() - f2.result() - - assert results[1] == "cpu" - assert results[2] == "cpu" - - -class TestGpuFuncCache: - """Test the GPU function caching.""" - - def test_cache_info_available(self): - """Test that cache_info is accessible.""" - info = _get_gpu_func.cache_info() - assert hasattr(info, "hits") - assert hasattr(info, "misses") - - def test_cache_clear_works(self): - """Test that cache_clear works.""" - _get_gpu_func.cache_clear() - info = _get_gpu_func.cache_info() - assert info.hits == 0 - assert info.misses == 0 - class TestGpuDispatch: """Test the gpu_dispatch decorator.""" @@ -139,39 +50,21 @@ def test_cpu_path(self): calls = [] @gpu_dispatch() - def my_func(x, y, *, n_jobs=1): - calls.append((x, y, n_jobs)) + def my_func(x, y): + calls.append((x, y)) return x + y with settings.use_device("cpu"): assert my_func(1, 2) == 3 - assert calls == [(1, 2, 1)] + assert calls == [(1, 2)] - def test_gpu_path(self): - """Test GPU device dispatches to GPU module.""" + def test_gpu_dispatch_and_device_kwargs(self): + """Test GPU dispatch with device_kwargs.""" mock_module = MagicMock() - - def gpu_my_func(x): - return "gpu_result" - - mock_module.my_func = gpu_my_func - - @gpu_dispatch(gpu_module="test_module") - def my_func(x): - return "cpu_result" - - with patch.object(settings, "gpu_available", True): - with settings.use_device("gpu"): - with patch("importlib.import_module", return_value=mock_module): - assert my_func(42) == "gpu_result" - - def test_device_kwargs_passed_to_gpu(self): - """Test device_kwargs are merged and passed to GPU function.""" - mock_module = MagicMock() - 
received_kwargs = {} + received = {} def gpu_my_func(x, use_sparse=False): - received_kwargs.update({"x": x, "use_sparse": use_sparse}) + received.update({"x": x, "use_sparse": use_sparse}) return "gpu_result" mock_module.my_func = gpu_my_func @@ -183,9 +76,11 @@ def my_func(x, device_kwargs=None): with patch.object(settings, "gpu_available", True): with settings.use_device("gpu"): with patch("importlib.import_module", return_value=mock_module): - result = my_func(42, device_kwargs={"use_sparse": True}) - assert result == "gpu_result" - assert received_kwargs == {"x": 42, "use_sparse": True} + # Basic dispatch + assert my_func(42) == "gpu_result" + # With device_kwargs + assert my_func(42, device_kwargs={"use_sparse": True}) == "gpu_result" + assert received["use_sparse"] is True def test_device_kwargs_error_on_cpu(self): """Test device_kwargs raises error on CPU path.""" @@ -198,33 +93,31 @@ def my_func(x, device_kwargs=None): with pytest.raises(ValueError, match="device_kwargs should not be provided"): my_func(5, device_kwargs={"use_sparse": True}) - def test_validate_args_on_gpu(self): + def test_validate_args(self): """Test validate_args runs validators before GPU dispatch.""" mock_module = MagicMock() mock_module.my_func = MagicMock(return_value="gpu_result") - def validate_attr(value): - if value != "X": - raise ValueError(f"attr={value!r} not supported on GPU") - - @gpu_dispatch(gpu_module="test_module", validate_args={"attr": validate_attr}) + @gpu_dispatch( + gpu_module="test_module", + validate_args={ + "attr": lambda v: (_ for _ in ()).throw(ValueError(f"attr={v!r} invalid")) if v != "X" else None + }, + ) def my_func(x, attr="X"): return "cpu_result" with patch.object(settings, "gpu_available", True): with settings.use_device("gpu"): with patch("importlib.import_module", return_value=mock_module): - # Valid value should work assert my_func(42, attr="X") == "gpu_result" - - # Invalid value should raise - with pytest.raises(ValueError, match="attr='obs' 
not supported on GPU"): + with pytest.raises(ValueError, match="attr='obs' invalid"): my_func(42, attr="obs") - def test_preserves_function_metadata(self): + def test_preserves_metadata_and_docstring(self): """Test decorator preserves function name and injects GPU note.""" - @gpu_dispatch() + @gpu_dispatch(gpu_module="custom.module") def documented_func(x): """Original docstring. @@ -238,45 +131,4 @@ def documented_func(x): assert documented_func.__name__ == "documented_func" assert "Original docstring." in documented_func.__doc__ assert "GPU acceleration" in documented_func.__doc__ - - def test_gpu_import_error_propagates(self): - """Test ImportError propagates when GPU module not found.""" - - @gpu_dispatch(gpu_module="nonexistent_module") - def my_func(x): - return "cpu_result" - - with patch.object(settings, "gpu_available", True): - with settings.use_device("gpu"): - with pytest.raises(ImportError): - my_func(42) - - def test_gpu_attribute_error_propagates(self): - """Test AttributeError propagates when function not in GPU module.""" - mock_module = MagicMock(spec=[]) # Empty spec, no attributes - - @gpu_dispatch(gpu_module="test_module") - def my_func(x): - return "cpu_result" - - with patch.object(settings, "gpu_available", True): - with settings.use_device("gpu"): - with patch("importlib.import_module", return_value=mock_module): - with pytest.raises(AttributeError): - my_func(42) - - def test_docstring_uses_custom_gpu_module(self): - """Test that docstring GPU note uses the specified gpu_module.""" - - @gpu_dispatch(gpu_module="custom.module.path") - def my_func(x): - """My function. - - Parameters - ---------- - x - Input. 
- """ - return x - - assert "custom.module.path.my_func" in my_func.__doc__ + assert "custom.module.documented_func" in documented_func.__doc__ From 50ece87a069342996c6bb9f6e0b32bda9a29a0a7 Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Wed, 4 Feb 2026 12:54:27 +0100 Subject: [PATCH 68/68] simplify --- src/squidpy/_settings/_dispatch.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/squidpy/_settings/_dispatch.py b/src/squidpy/_settings/_dispatch.py index 3f8a1ca9f..0a824a23b 100644 --- a/src/squidpy/_settings/_dispatch.py +++ b/src/squidpy/_settings/_dispatch.py @@ -86,9 +86,7 @@ def decorator(func: F) -> F: @functools.wraps(func) def wrapper(*args: Any, **kwargs: Any) -> Any: - effective_device = settings.device - - if effective_device == "cpu": + if settings.device == "cpu": device_kwargs = kwargs.pop("device_kwargs", None) if device_kwargs is not None and len(device_kwargs) > 0: raise ValueError(