diff --git a/docs/api.md b/docs/api.md index fc9922a31..e975967b4 100644 --- a/docs/api.md +++ b/docs/api.md @@ -18,6 +18,12 @@ import squidpy as sq :toctree: api gr.spatial_neighbors + gr.spatial_neighbors_from_builder + gr.spatial_neighbors_knn + gr.spatial_neighbors_radius + gr.spatial_neighbors_delaunay + gr.spatial_neighbors_grid + gr.GraphMatrixT gr.SpatialNeighborsResult gr.mask_graph gr.nhood_enrichment @@ -109,3 +115,27 @@ import squidpy as sq datasets.visium_hne_image_crop datasets.visium_fluo_image_crop ``` + + +## Extensibility + +See the {doc}`extensibility guide <extensibility>` for how to implement a custom graph builder. + +```{eval-rst} +.. module:: squidpy.gr.neighbors +.. currentmodule:: squidpy +.. autosummary:: + :toctree: api + + gr.neighbors.GraphBuilder + gr.neighbors.GraphBuilderCSR + gr.neighbors.GraphMatrixT + gr.neighbors.GraphPostprocessor + gr.neighbors.DistanceIntervalPostprocessor + gr.neighbors.PercentilePostprocessor + gr.neighbors.TransformPostprocessor + gr.neighbors.KNNBuilder + gr.neighbors.RadiusBuilder + gr.neighbors.DelaunayBuilder + gr.neighbors.GridBuilder +``` diff --git a/docs/extensibility.md b/docs/extensibility.md new file mode 100644 index 000000000..befd61b23 --- /dev/null +++ b/docs/extensibility.md @@ -0,0 +1,104 @@ +# Extensibility + +## Custom graph builders + +The `squidpy.gr.neighbors` module exposes two builder base classes: + +- {class}`~squidpy.gr.neighbors.GraphBuilder` is the generic builder pipeline. + Use it when you want to plug in a custom coordinate type or sparse-matrix backend. +- {class}`~squidpy.gr.neighbors.GraphBuilderCSR` is the CSR-specialized builder used + by the built-in graph construction strategies. Use it when your builder returns + {class}`~scipy.sparse.csr_matrix` objects and should reuse Squidpy's CSR-specific + postprocessors, sparse warning suppression, and multi-library combination. 
+- Reusable postprocessors such as + {class}`~squidpy.gr.neighbors.DistanceIntervalPostprocessor`, + {class}`~squidpy.gr.neighbors.PercentilePostprocessor`, and + {class}`~squidpy.gr.neighbors.TransformPostprocessor` are also exposed for + custom builder composition. + +### What to override + +| Base class | Method / property | Required | Purpose | +|---|---|---|---| +| {class}`~squidpy.gr.neighbors.GraphBuilder` | {meth}`~squidpy.gr.neighbors.GraphBuilder.build_graph` | yes | Construct and return ``(adj, dst)`` using the coordinate and matrix types of your custom backend. | +| {class}`~squidpy.gr.neighbors.GraphBuilder` | {meth}`~squidpy.gr.neighbors.GraphBuilder.postprocessors` | no | Return post-build processing steps for ``(adj, dst)``. You can either override this or pass ``postprocessors=...`` to ``super().__init__()``. | +| {class}`~squidpy.gr.neighbors.GraphBuilder` | {meth}`~squidpy.gr.neighbors.GraphBuilder.combine` | no | Combine per-library results when using ``library_key``. If you do not need ``library_key`` support, leaving this unimplemented is fine. | + + +The generic builder only defines the pipeline. The CSR-specialized builder adds +multi-library ``library_key`` combination and +{class}`~scipy.sparse.SparseEfficiencyWarning` suppression, while built-in and +custom CSR builders can compose the public reusable postprocessors for +distance-interval pruning, percentile filtering, and adjacency transforms. + +Here ``adj`` and ``dst`` are square sparse matrices of shape ``(n_obs, n_obs)`` +with matching sparsity structure: + +- ``adj`` is the connectivity / adjacency matrix. Non-zero entries mark edges in + the graph, and built-in builders typically use ``1.0`` for present edges. +- ``dst`` is the distance matrix for those same edges. For generic graphs this is + usually the Euclidean edge length. For grid builders it may instead encode + graph-distance semantics such as ring number. 
+- When subclassing {class}`~squidpy.gr.neighbors.GraphBuilderCSR`, both should be + returned as {class}`~scipy.sparse.csr_matrix`. +- For CSR-based builders, ``adj`` often behaves like a boolean or indicator + matrix describing whether an edge is present, even if it is stored with a + numeric dtype such as ``float32``. ``dst`` stores edge-associated values such + as distances and will often use a floating-point dtype. The exact dtype choice + is left to the builder implementation and may depend on performance, memory, + and numerical accuracy requirements. +- By convention, ``dst`` should have a zero diagonal, and ``adj`` should only + have a non-zero diagonal when ``set_diag=True``. + +### Example: fast radius search with SNN + +The built-in {class}`~squidpy.gr.neighbors.RadiusBuilder` uses scikit-learn's +``NearestNeighbors``. The [snnpy](https://github.com/nla-group/snn) library +provides a faster exact fixed-radius search based on PCA-based pruning. The +example below swaps the backend while keeping full compatibility with the rest +of the Squidpy graph pipeline: + +```python +import numpy as np +from scipy.sparse import csr_matrix +from snnpy import build_snn_model + +from squidpy.gr.neighbors import GraphBuilderCSR + + +class SNNRadiusBuilder(GraphBuilderCSR): + """Radius graph using the SNN fixed-radius search backend.""" + + def __init__(self, radius: float, **kwargs): + super().__init__(**kwargs) + self.radius = radius + + def build_graph(self, coords): + N = coords.shape[0] + model = build_snn_model(coords, verbose=0) + indices, dists = model.batch_query_radius( + coords, self.radius, return_distance=True, + ) + + row = np.repeat(np.arange(N), [len(idx) for idx in indices]) + col = np.concatenate(indices) + d = np.concatenate(dists).astype(np.float64) + + adj = csr_matrix( + (np.ones(len(row), dtype=np.float32), (row, col)), + shape=(N, N), + ) + dst = csr_matrix((d, (row, col)), shape=(N, N)) + + adj.setdiag(1.0 if self.set_diag else adj.diagonal()) + 
dst.setdiag(0.0) + return adj, dst +``` + +Use it like any other builder: + +```python +import squidpy as sq + +sq.gr.spatial_neighbors_from_builder(adata, SNNRadiusBuilder(radius=100.0)) +``` diff --git a/docs/index.md b/docs/index.md index e44bc2995..7ddaff4d7 100644 --- a/docs/index.md +++ b/docs/index.md @@ -63,6 +63,7 @@ We are happy about any contributions! Before you start, check out our [contribut installation api classes + extensibility release_notes references contributing diff --git a/src/squidpy/_docs.py b/src/squidpy/_docs.py index a4596df75..4580b34ef 100644 --- a/src/squidpy/_docs.py +++ b/src/squidpy/_docs.py @@ -359,6 +359,33 @@ def decorator2(obj: Any) -> Any: If multiple `library_id`, column in :attr:`anndata.AnnData.obs` which stores mapping between ``library_id`` and obs.""" +_sdata_params = """\ +elements_to_coordinate_systems + A dictionary mapping element names of the SpatialData object to coordinate systems. + The elements can be either Shapes or Labels. For compatibility, the spatialdata table must annotate + all regions keys. Must not be ``None`` if ``adata`` is a :class:`spatialdata.SpatialData`. +table_key + Key in :attr:`spatialdata.SpatialData.tables` where the spatialdata table is stored. Must not be ``None`` if + ``adata`` is a :class:`spatialdata.SpatialData`.""" +_graph_common_params = """\ +percentile + Percentile of the distances to use as threshold. +transform + Adjacency matrix transform (``'spectral'``, ``'cosine'``, or ``None``). +set_diag + Whether to set the diagonal of the connectivities to ``1.0``. +key_added + Key which controls where the results are saved if ``copy = False``.""" +_spatial_neighbors_returns = """\ +If ``copy = True``, returns a :class:`~squidpy.gr.SpatialNeighborsResult` with the +spatial connectivities and distances matrices. + +Otherwise, modifies the ``adata`` with the following keys: + + - :attr:`anndata.AnnData.obsp` ``['{key_added}_connectivities']`` - the spatial connectivities. 
+ - :attr:`anndata.AnnData.obsp` ``['{key_added}_distances']`` - the spatial distances. + - :attr:`anndata.AnnData.uns` ``['{key_added}']`` - :class:`dict` containing parameters.""" + d = DocstringProcessor( adata=_adata, img_container=_img_container, @@ -401,4 +428,7 @@ def decorator2(obj: Any) -> Any: groups=_groups, plotting_library_id=_plotting_library_id, library_key=_library_key, + sdata_params=_sdata_params, + graph_common_params=_graph_common_params, + spatial_neighbors_returns=_spatial_neighbors_returns, ) diff --git a/src/squidpy/gr/__init__.py b/src/squidpy/gr/__init__.py index 7bb9bb4a2..5a9a489fc 100644 --- a/src/squidpy/gr/__init__.py +++ b/src/squidpy/gr/__init__.py @@ -2,7 +2,18 @@ from __future__ import annotations -from squidpy.gr._build import SpatialNeighborsResult, mask_graph, spatial_neighbors +from squidpy.gr import neighbors +from squidpy.gr._build import ( + GraphMatrixT, + SpatialNeighborsResult, + mask_graph, + spatial_neighbors, + spatial_neighbors_delaunay, + spatial_neighbors_from_builder, + spatial_neighbors_grid, + spatial_neighbors_knn, + spatial_neighbors_radius, +) from squidpy.gr._ligrec import ligrec from squidpy.gr._nhood import ( NhoodEnrichmentResult, @@ -16,10 +27,17 @@ from squidpy.gr._sepal import sepal __all__ = [ + "GraphMatrixT", "SpatialNeighborsResult", "NhoodEnrichmentResult", + "neighbors", "mask_graph", "spatial_neighbors", + "spatial_neighbors_from_builder", + "spatial_neighbors_knn", + "spatial_neighbors_radius", + "spatial_neighbors_delaunay", + "spatial_neighbors_grid", "ligrec", "centrality_scores", "interaction_matrix", diff --git a/src/squidpy/gr/_build.py b/src/squidpy/gr/_build.py index cc844afff..2d9d1b3bd 100644 --- a/src/squidpy/gr/_build.py +++ b/src/squidpy/gr/_build.py @@ -3,30 +3,15 @@ from __future__ import annotations import warnings -from collections.abc import Iterable -from functools import partial -from itertools import chain -from typing import Any, NamedTuple, cast +from typing import Any, 
Generic, NamedTuple import geopandas as gpd import numpy as np import pandas as pd from anndata import AnnData from anndata.utils import make_index_unique -from fast_array_utils import stats as fau_stats -from numba import njit, prange -from scipy.sparse import ( - SparseEfficiencyWarning, - block_diag, - csr_array, - csr_matrix, - isspmatrix_csr, - spmatrix, -) -from scipy.spatial import Delaunay +from numba import njit from shapely import LineString, MultiPolygon, Polygon -from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances -from sklearn.neighbors import NearestNeighbors from spatialdata import SpatialData from spatialdata._core.centroids import get_centroids from spatialdata._core.query.relational_query import get_element_instances, match_element_to_table @@ -48,15 +33,156 @@ _assert_spatial_basis, _save_data, ) +from squidpy.gr.neighbors import ( + DelaunayBuilder, + GraphBuilder, + GraphMatrixT, + GridBuilder, + KNNBuilder, + RadiusBuilder, +) -__all__ = ["spatial_neighbors"] +__all__ = [ + "GraphMatrixT", + "SpatialNeighborsResult", + "spatial_neighbors", + "spatial_neighbors_from_builder", + "spatial_neighbors_knn", + "spatial_neighbors_radius", + "spatial_neighbors_delaunay", + "spatial_neighbors_grid", +] -class SpatialNeighborsResult(NamedTuple): +class SpatialNeighborsResult(NamedTuple, Generic[GraphMatrixT]): """Result of spatial_neighbors function.""" - connectivities: csr_matrix - distances: csr_matrix + connectivities: GraphMatrixT + distances: GraphMatrixT + + +def _resolve_graph_builder( + *, + coord_type: str | CoordType | None, + n_neighs: int | None, + radius: float | tuple[float, float] | None, + delaunay: bool | None, + n_rings: int | None, + percentile: float | None, + transform: str | Transform | None, + set_diag: bool | None, + has_spatial_uns: bool = False, +) -> GraphBuilder[Any, Any]: + n_neighs_was_set = n_neighs is not None + n_neighs = 6 if n_neighs is None else n_neighs + delaunay = False if delaunay is None 
else delaunay + n_rings = 1 if n_rings is None else n_rings + set_diag = False if set_diag is None else set_diag + + assert_positive(n_rings, name="n_rings") + assert_positive(n_neighs, name="n_neighs") + + transform = Transform.NONE if transform is None else Transform(transform) + if coord_type is None: + if radius is not None: + logg.warning( + "Graph creation with `radius` is only available for generic coordinates. " + f"Ignoring parameter `radius = {radius}`." + ) + coord_type = CoordType.GRID if has_spatial_uns else CoordType.GENERIC + else: + coord_type = CoordType(coord_type) + + common: dict[str, Any] = {"transform": transform, "set_diag": set_diag} + + if coord_type == CoordType.GRID: + if percentile is not None: + raise ValueError( + "`percentile` is not supported for grid coordinates. It only applies to generic (non-grid) graphs." + ) + return GridBuilder(n_neighs=n_neighs, **common, n_rings=n_rings, delaunay=delaunay) + if delaunay: + # TODO: below check should be removed once legacy mode spatial_neighbors is deprecated + if n_neighs_was_set: + warnings.warn( + "Parameter `n_neighs` is ignored when `delaunay=True` and will be removed in squidpy v2.0.0.", + FutureWarning, + stacklevel=3, + ) + return DelaunayBuilder(**common, radius=radius, percentile=percentile) + if radius is not None: + # TODO: below check should be removed once legacy mode spatial_neighbors is deprecated + if n_neighs_was_set: + warnings.warn( + "Parameter `n_neighs` is ignored when `radius` is set and will be removed in squidpy v2.0.0.", + FutureWarning, + stacklevel=3, + ) + return RadiusBuilder(**common, radius=radius, percentile=percentile) + return KNNBuilder(n_neighs=n_neighs, **common, percentile=percentile) + + +def _resolve_spatial_data( + data: AnnData | SpatialData, + *, + spatial_key: str, + elements_to_coordinate_systems: dict[str, str] | None, + table_key: str | None, + library_key: str | None, +) -> tuple[AnnData, str | None]: + """Resolve SpatialData to AnnData, 
returning (adata, library_key).""" + if isinstance(data, SpatialData): + sdata = data + assert elements_to_coordinate_systems is not None, ( + "Since `data` is a :class:`spatialdata.SpatialData`, `elements_to_coordinate_systems` must not be `None`." + ) + assert table_key is not None, ( + "Since `data` is a :class:`spatialdata.SpatialData`, `table_key` must not be `None`." + ) + elements, table = match_element_to_table(sdata, list(elements_to_coordinate_systems), table_key) + assert table.obs_names.equals(sdata.tables[table_key].obs_names), ( + "The spatialdata table must annotate all elements keys. Some elements are missing, please check the `elements_to_coordinate_systems` dictionary." + ) + regions, region_key, instance_key = get_table_keys(sdata.tables[table_key]) + regions = [regions] if isinstance(regions, str) else regions + ordered_regions_in_table = sdata.tables[table_key].obs[region_key].unique() + + # TODO: remove this after https://github.com/scverse/spatialdata/issues/614 + remove_centroids = {} + elem_instances = [] + for e in regions: + schema = get_model(elements[e]) + element_instances = get_element_instances(elements[e]).to_series() + if np.isin(0, element_instances.values) and (schema in (Labels2DModel, Labels3DModel)): + element_instances = element_instances.drop(index=0) + remove_centroids[e] = True + else: + remove_centroids[e] = False + elem_instances.append(element_instances) + + element_instances = pd.concat(elem_instances) + if (not np.all(element_instances.values == sdata.tables[table_key].obs[instance_key].values)) or ( + not np.all(ordered_regions_in_table == regions) + ): + raise ValueError( + "The spatialdata table must annotate all elements keys. Some elements are missing or not ordered correctly, please check the `elements_to_coordinate_systems` dictionary." 
+ ) + centroids = [] + for region_ in ordered_regions_in_table: + cs = elements_to_coordinate_systems[region_] + centroid = get_centroids(sdata[region_], coordinate_system=cs)[["x", "y"]].compute() + + # TODO: remove this after https://github.com/scverse/spatialdata/issues/614 + if remove_centroids[region_]: + centroid = centroid[1:].copy() + centroids.append(centroid) + + sdata.tables[table_key].obsm[spatial_key] = np.concatenate(centroids) + adata = sdata.tables[table_key] + library_key = region_key + else: + adata = data + return adata, library_key @d.dedent @@ -68,19 +194,29 @@ def spatial_neighbors( table_key: str | None = None, library_key: str | None = None, coord_type: str | CoordType | None = None, - n_neighs: int = 6, + n_neighs: int | None = None, radius: float | tuple[float, float] | None = None, - delaunay: bool = False, - n_rings: int = 1, + delaunay: bool | None = None, + n_rings: int | None = None, percentile: float | None = None, transform: str | Transform | None = None, - set_diag: bool = False, + set_diag: bool | None = None, key_added: str = "spatial", copy: bool = False, ) -> SpatialNeighborsResult | None: """ Create a graph from spatial coordinates. + .. deprecated:: 1.6.0 + ``spatial_neighbors`` is deprecated and will be removed in squidpy + v1.7.0. Use one of the mode-specific functions instead: + + - :func:`spatial_neighbors_knn` + - :func:`spatial_neighbors_radius` + - :func:`spatial_neighbors_delaunay` + - :func:`spatial_neighbors_grid` + - :func:`spatial_neighbors_from_builder` + Parameters ---------- %(adata)s @@ -107,29 +243,100 @@ def spatial_neighbors( - `{c.GRID.s!r}` - number of neighboring tiles. - `{c.GENERIC.s!r}` - number of neighborhoods for non-grid data. Only used when ``delaunay = False``. + + Defaults to ``6``. radius - Only available when ``coord_type = {c.GENERIC.s!r}``. Depending on the type: + Only available when ``coord_type = {c.GENERIC.s!r}``. 
+ Depending on the type: - :class:`float` - compute the graph based on neighborhood radius. - :class:`tuple` - prune the final graph to only contain edges in interval `[min(radius), max(radius)]`. delaunay Whether to compute the graph from Delaunay triangulation. Only used when ``coord_type = {c.GENERIC.s!r}``. + Defaults to ``False``. n_rings Number of rings of neighbors for grid data. Only used when ``coord_type = {c.GRID.s!r}``. + Defaults to ``1``. percentile Percentile of the distances to use as threshold. Only used when ``coord_type = {c.GENERIC.s!r}``. transform - Type of adjacency matrix transform. Valid options are: + Type of adjacency matrix transform. + Valid options are: - `{t.SPECTRAL.s!r}` - spectral transformation of the adjacency matrix. - `{t.COSINE.s!r}` - cosine transformation of the adjacency matrix. - `{t.NONE.v}` - no transformation of the adjacency matrix. set_diag Whether to set the diagonal of the spatial connectivities to `1.0`. + Defaults to ``False``. key_added Key which controls where the results are saved if ``copy = False``. %(copy)s + Notes + ----- + ``spatial_neighbors`` has 4 graph-construction modes: + + - Grid mode: + ``coord_type='grid'``. Uses ``n_neighs`` and ``n_rings``. + ``radius`` is ignored. ``delaunay`` is forwarded to the + underlying grid connectivity builder. This is the mode used + for Visium-like grid coordinates. + - Generic k-nearest-neighbor mode: + ``coord_type='generic'``, ``delaunay=False``, ``radius=None``. + Uses ``n_neighs``. + - Generic radius mode: + ``coord_type='generic'``, ``delaunay=False``, ``radius`` set. + Uses ``radius`` and builds a radius-based neighbor graph. + ``n_neighs`` is ignored and passing it is deprecated. + If ``radius`` is a tuple, the graph is built with the maximum + radius and then pruned to the interval + ``[min(radius), max(radius)]``. + - Generic Delaunay mode: + ``coord_type='generic'``, ``delaunay=True``. + Builds a Delaunay triangulation graph. 
``n_neighs`` is + ignored by the triangulation and passing it is deprecated. + If ``radius`` is a tuple, it is used only as a + post-construction pruning interval. + + Across these modes: + + - ``percentile`` only affects generic graphs. + - ``transform`` and ``set_diag`` apply to all modes. + - By default, observations are not treated as their own + neighbors. The distance matrix always has a zero diagonal. + The connectivity matrix only gets a nonzero diagonal when + ``set_diag=True``. + + Argument precedence + ------------------- + The mode is resolved as follows: + + - If ``coord_type`` resolves to ``'grid'``, grid mode is used. + In that case ``radius`` is ignored. + - Otherwise, if ``delaunay=True``, Delaunay mode is used. + ``n_neighs`` is ignored (deprecated). + A tuple ``radius`` is only used afterward as a pruning + interval. A scalar ``radius`` is ignored. + - Otherwise, if ``radius`` is set, radius mode is used. + In this mode ``n_neighs`` is ignored (deprecated). + - Otherwise, k-nearest-neighbor mode is used. + + Grid-specific behavior + ---------------------- + Grid mode currently does not validate ``n_neighs`` to a fixed set + such as ``{{4, 6}}``. Internally it first queries the + ``n_neighs`` nearest candidates and then applies a distance-based + correction tuned for grid-like coordinates. As a result: + + - values such as ``n_neighs=4`` and ``n_neighs=6`` are the + intended square-grid and hex-grid choices, respectively; + - other values are accepted for backward compatibility, but + their geometric interpretation is not guaranteed to match a + continuous ring on the grid; + - no clockwise or other within-ring ordering is part of the + public API. + Returns ------- If ``copy = True``, returns a :class:`~squidpy.gr.SpatialNeighborsResult` with the spatial connectivities and distances matrices. @@ -140,366 +347,472 @@ def spatial_neighbors( - :attr:`anndata.AnnData.obsp` ``['{{key_added}}_distances']`` - the spatial distances. 
- :attr:`anndata.AnnData.uns` ``['{{key_added}}']`` - :class:`dict` containing parameters. """ - if isinstance(adata, SpatialData): - assert elements_to_coordinate_systems is not None, ( - "Since `adata` is a :class:`spatialdata.SpatialData`, `elements_to_coordinate_systems` must not be `None`." - ) - assert table_key is not None, ( - "Since `adata` is a :class:`spatialdata.SpatialData`, `table_key` must not be `None`." - ) - elements, table = match_element_to_table(adata, list(elements_to_coordinate_systems), table_key) - assert table.obs_names.equals(adata.tables[table_key].obs_names), ( - "The spatialdata table must annotate all elements keys. Some elements are missing, please check the `elements_to_coordinate_systems` dictionary." - ) - regions, region_key, instance_key = get_table_keys(adata.tables[table_key]) - regions = [regions] if isinstance(regions, str) else regions - ordered_regions_in_table = adata.tables[table_key].obs[region_key].unique() - - # TODO: remove this after https://github.com/scverse/spatialdata/issues/614 - remove_centroids = {} - elem_instances = [] - for e in regions: - schema = get_model(elements[e]) - element_instances = get_element_instances(elements[e]).to_series() - if np.isin(0, element_instances.values) and (schema in (Labels2DModel, Labels3DModel)): - element_instances = element_instances.drop(index=0) - remove_centroids[e] = True - else: - remove_centroids[e] = False - elem_instances.append(element_instances) - - element_instances = pd.concat(elem_instances) - if (not np.all(element_instances.values == adata.tables[table_key].obs[instance_key].values)) or ( - not np.all(ordered_regions_in_table == regions) - ): - raise ValueError( - "The spatialdata table must annotate all elements keys. Some elements are missing or not ordered correctly, please check the `elements_to_coordinate_systems` dictionary." 
- ) - centroids = [] - for region_ in ordered_regions_in_table: - cs = elements_to_coordinate_systems[region_] - centroid = get_centroids(adata[region_], coordinate_system=cs)[["x", "y"]].compute() - - # TODO: remove this after https://github.com/scverse/spatialdata/issues/614 - if remove_centroids[region_]: - centroid = centroid[1:].copy() - centroids.append(centroid) - - adata.tables[table_key].obsm[spatial_key] = np.concatenate(centroids) - adata = adata.tables[table_key] - library_key = region_key - - assert_positive(n_rings, name="n_rings") - assert_positive(n_neighs, name="n_neighs") - _assert_spatial_basis(adata, spatial_key) - - transform = Transform.NONE if transform is None else Transform(transform) - if coord_type is None: - if radius is not None: - logg.warning( - f"Graph creation with `radius` is only available when `coord_type = {CoordType.GENERIC!r}` specified. " - f"Ignoring parameter `radius = {radius}`." - ) - coord_type = CoordType.GRID if Key.uns.spatial in adata.uns else CoordType.GENERIC - else: - coord_type = CoordType(coord_type) - - if library_key is not None: - _assert_categorical_obs(adata, key=library_key) - libs = adata.obs[library_key].cat.categories - make_index_unique(adata.obs_names) - else: - libs = [None] - - start = logg.info( - f"Creating graph using `{coord_type}` coordinates and `{transform}` transform and `{len(libs)}` libraries." + warnings.warn( + "Calling `spatial_neighbors` is deprecated and will be removed in squidpy " + "v1.7.0. 
Use `spatial_neighbors_knn`, `spatial_neighbors_radius`, " + "`spatial_neighbors_delaunay`, `spatial_neighbors_grid`, or " + "`spatial_neighbors_from_builder` instead.", + FutureWarning, + stacklevel=2, ) - _build_fun = partial( - _spatial_neighbor, + adata, library_key = _prepare_spatial_neighbors_input( + adata, spatial_key=spatial_key, + elements_to_coordinate_systems=elements_to_coordinate_systems, + table_key=table_key, + library_key=library_key, + ) + builder = _resolve_graph_builder( coord_type=coord_type, n_neighs=n_neighs, radius=radius, delaunay=delaunay, n_rings=n_rings, + percentile=percentile, transform=transform, set_diag=set_diag, - percentile=percentile, + has_spatial_uns=Key.uns.spatial in adata.uns, ) - if library_key is not None: - mats: list[tuple[spmatrix, spmatrix]] = [] - ixs: list[int] = [] - for lib in libs: - ixs.extend(np.where(adata.obs[library_key] == lib)[0]) - mats.append(_build_fun(adata[adata.obs[library_key] == lib])) - ixs = cast(list[int], np.argsort(ixs).tolist()) - Adj = block_diag([m[0] for m in mats], format="csr")[ixs, :][:, ixs] - Dst = block_diag([m[1] for m in mats], format="csr")[ixs, :][:, ixs] - else: - Adj, Dst = _build_fun(adata) + return _run_spatial_neighbors( + adata, builder, spatial_key=spatial_key, library_key=library_key, key_added=key_added, copy=copy + ) - neighs_key = Key.uns.spatial_neighs(key_added) - conns_key = Key.obsp.spatial_conn(key_added) - dists_key = Key.obsp.spatial_dist(key_added) - neighbors_dict = { - "connectivities_key": conns_key, - "distances_key": dists_key, - "params": { - "n_neighbors": n_neighs, - "coord_type": coord_type.v, - "radius": radius, - "transform": transform.v, - }, - } +@d.dedent +def spatial_neighbors_from_builder( + data: AnnData | SpatialData, + builder: GraphBuilder[Any, Any], + *, + spatial_key: str = Key.obsm.spatial, + elements_to_coordinate_systems: dict[str, str] | None = None, + table_key: str | None = None, + library_key: str | None = None, + key_added: str = 
"spatial", + copy: bool = False, +) -> SpatialNeighborsResult | None: + """Create a graph from spatial coordinates using an explicit builder instance. - if copy: - return SpatialNeighborsResult(connectivities=Adj, distances=Dst) + This function is the bridge between the high-level API (e.g., + :func:`spatial_neighbors_knn`, :func:`spatial_neighbors_radius`) and advanced + customization via builder classes. Use this when you need to: - _save_data(adata, attr="obsp", key=conns_key, data=Adj) - _save_data(adata, attr="obsp", key=dists_key, data=Dst, prefix=False) - _save_data(adata, attr="uns", key=neighs_key, data=neighbors_dict, prefix=False, time=start) + - Stack or chain builder behaviors + - Pass pre-configured builder instances multiple times + - Implement custom builders (see :doc:`/extensibility`) + Parameters + ---------- + %(adata)s + builder + Graph construction strategy to execute. Built-in builders subclass + {{class}}`~squidpy.gr.neighbors.GraphBuilderCSR`, while custom backends + can implement the more generic + {{class}}`~squidpy.gr.neighbors.GraphBuilder` interface directly. + Reusable post-build operations are also exposed via + :class:`~squidpy.gr.neighbors.DistanceIntervalPostprocessor`, + :class:`~squidpy.gr.neighbors.PercentilePostprocessor`, and + :class:`~squidpy.gr.neighbors.TransformPostprocessor`. + Custom builders only need to implement multi-library support when using + ``library_key``; otherwise leaving + :meth:`~squidpy.gr.neighbors.GraphBuilder.combine` unimplemented is fine. + %(spatial_key)s + %(sdata_params)s + %(library_key)s + key_added + Key which controls where the results are saved if ``copy = False``. + %(copy)s -def _spatial_neighbor( - adata: AnnData, + Returns + ------- + %(spatial_neighbors_returns)s + + See Also + -------- + spatial_neighbors_knn : k-nearest-neighbor graphs (wraps :class:`~squidpy.gr.neighbors.KNNBuilder`). 
+ spatial_neighbors_radius : radius-based graphs (wraps :class:`~squidpy.gr.neighbors.RadiusBuilder`). + spatial_neighbors_delaunay : Delaunay triangulation graphs (wraps :class:`~squidpy.gr.neighbors.DelaunayBuilder`). + spatial_neighbors_grid : grid-based graphs (wraps :class:`~squidpy.gr.neighbors.GridBuilder`). + squidpy.gr.neighbors.GraphBuilder : Base builder interface. Inherit from this or :class:`~squidpy.gr.neighbors.GraphBuilderCSR` to implement custom graph construction. + """ + adata, library_key = _prepare_spatial_neighbors_input( + data, + spatial_key=spatial_key, + elements_to_coordinate_systems=elements_to_coordinate_systems, + table_key=table_key, + library_key=library_key, + ) + return _run_spatial_neighbors( + adata, + builder, + spatial_key=spatial_key, + library_key=library_key, + key_added=key_added, + copy=copy, + ) + + +def _prepare_spatial_neighbors_input( + data: AnnData | SpatialData, + *, + spatial_key: str, + elements_to_coordinate_systems: dict[str, str] | None, + table_key: str | None, + library_key: str | None, +) -> tuple[AnnData, str | None]: + """Resolve input data and validate the requested spatial basis.""" + adata, library_key = _resolve_spatial_data( + data, + spatial_key=spatial_key, + elements_to_coordinate_systems=elements_to_coordinate_systems, + table_key=table_key, + library_key=library_key, + ) + _assert_spatial_basis(adata, spatial_key) + return adata, library_key + + +@d.dedent +def spatial_neighbors_knn( + data: AnnData | SpatialData, + *, spatial_key: str = Key.obsm.spatial, - coord_type: str | CoordType | None = None, + elements_to_coordinate_systems: dict[str, str] | None = None, + table_key: str | None = None, + library_key: str | None = None, n_neighs: int = 6, - radius: float | tuple[float, float] | None = None, - delaunay: bool = False, - n_rings: int = 1, + percentile: float | None = None, transform: str | Transform | None = None, set_diag: bool = False, - percentile: float | None = None, -) -> 
tuple[csr_matrix, csr_matrix]: - coords = adata.obsm[spatial_key] - with warnings.catch_warnings(): - warnings.simplefilter("ignore", SparseEfficiencyWarning) - if coord_type == CoordType.GRID: - Adj, Dst = _build_grid( - coords, - n_neighs=n_neighs, - n_rings=n_rings, - delaunay=delaunay, - set_diag=set_diag, - ) - elif coord_type == CoordType.GENERIC: - Adj, Dst = _build_connectivity( - coords, - n_neighs=n_neighs, - radius=radius, - delaunay=delaunay, - return_distance=True, - set_diag=set_diag, - ) - else: - raise NotImplementedError(f"Coordinate type `{coord_type}` is not yet implemented.") - - if coord_type == CoordType.GENERIC and isinstance(radius, Iterable): - minn, maxx = sorted(radius)[:2] - mask = (Dst.data < minn) | (Dst.data > maxx) - a_diag = Adj.diagonal() - - Dst.data[mask] = 0.0 - Adj.data[mask] = 0.0 - Adj.setdiag(a_diag) - - if percentile is not None and coord_type == CoordType.GENERIC: - threshold = np.percentile(Dst.data, percentile) - Adj[Dst > threshold] = 0.0 - Dst[Dst > threshold] = 0.0 - - Adj.eliminate_zeros() - Dst.eliminate_zeros() - - # check transform - if transform == Transform.SPECTRAL: - Adj = _transform_a_spectral(Adj) - elif transform == Transform.COSINE: - Adj = _transform_a_cosine(Adj) - elif transform == Transform.NONE: - pass - else: - raise NotImplementedError(f"Transform `{transform}` is not yet implemented.") - - return Adj, Dst + key_added: str = "spatial", + copy: bool = False, +) -> SpatialNeighborsResult | None: + """Create a k-nearest-neighbor graph from spatial coordinates. + Each observation is connected to its ``n_neighs`` nearest observations in + Euclidean space. This mode is typically most useful for continuous + coordinates, where you want to control neighborhood size directly. 
-def _build_grid( - coords: NDArrayA, - n_neighs: int, - n_rings: int, - delaunay: bool = False, - set_diag: bool = False, -) -> tuple[csr_matrix, csr_matrix]: - if n_rings > 1: - Adj: csr_matrix = _build_connectivity( - coords, - n_neighs=n_neighs, - neigh_correct=True, - set_diag=True, - delaunay=delaunay, - return_distance=False, - ) - Res, Walk = Adj, Adj - for i in range(n_rings - 1): - Walk = Walk @ Adj - Walk[Res.nonzero()] = 0.0 - Walk.eliminate_zeros() - Walk.data[:] = i + 2.0 - Res = Res + Walk - Adj = Res - Adj.setdiag(float(set_diag)) - Adj.eliminate_zeros() - - Dst = Adj.copy() - Adj.data[:] = 1.0 - else: - Adj = _build_connectivity( - coords, - n_neighs=n_neighs, - neigh_correct=True, - delaunay=delaunay, - set_diag=set_diag, - ) - Dst = Adj.copy() + Parameters + ---------- + %(adata)s + %(spatial_key)s + %(sdata_params)s + %(library_key)s + n_neighs + Number of nearest neighbors. Defaults to ``6``. Smaller values produce a + sparser, more local graph; larger values connect broader neighborhoods. + %(graph_common_params)s + %(copy)s - Dst.setdiag(0.0) + Returns + ------- + %(spatial_neighbors_returns)s - return Adj, Dst + See Also + -------- + spatial_neighbors_from_builder : Use :class:`~squidpy.gr.neighbors.KNNBuilder` directly for advanced customization. + squidpy.gr.neighbors.KNNBuilder : k-nearest-neighbor builder class. 
+ """ + transform_enum = Transform.NONE if transform is None else Transform(transform) + builder = KNNBuilder( + n_neighs=n_neighs, + percentile=percentile, + transform=transform_enum, + set_diag=set_diag, + ) + adata, library_key = _prepare_spatial_neighbors_input( + data, + spatial_key=spatial_key, + elements_to_coordinate_systems=elements_to_coordinate_systems, + table_key=table_key, + library_key=library_key, + ) + return _run_spatial_neighbors( + adata, + builder, + spatial_key=spatial_key, + library_key=library_key, + key_added=key_added, + copy=copy, + ) -def _build_connectivity( - coords: NDArrayA, - n_neighs: int, - radius: float | tuple[float, float] | None = None, - delaunay: bool = False, - neigh_correct: bool = False, +@d.dedent +def spatial_neighbors_radius( + data: AnnData | SpatialData, + *, + spatial_key: str = Key.obsm.spatial, + elements_to_coordinate_systems: dict[str, str] | None = None, + table_key: str | None = None, + library_key: str | None = None, + radius: float | tuple[float, float] = 1.0, + percentile: float | None = None, + transform: str | Transform | None = None, set_diag: bool = False, - return_distance: bool = False, -) -> csr_matrix | tuple[csr_matrix, csr_matrix]: - N = coords.shape[0] - if delaunay: - tri = Delaunay(coords) - indptr, indices = tri.vertex_neighbor_vertices - Adj = csr_matrix((np.ones_like(indices, dtype=np.float32), indices, indptr), shape=(N, N)) - - if return_distance: - # fmt: off - dists = np.array(list(chain(*( - euclidean_distances(coords[indices[indptr[i] : indptr[i + 1]], :], coords[np.newaxis, i, :]) - for i in range(N) - if len(indices[indptr[i] : indptr[i + 1]]) - )))).squeeze() - Dst = csr_matrix((dists, indices, indptr), shape=(N, N)) - # fmt: on - else: - r = 1 if radius is None else radius if isinstance(radius, int | float) else max(radius) - tree = NearestNeighbors(n_neighbors=n_neighs, radius=r, metric="euclidean") - tree.fit(coords) - - if radius is None: - dists, col_indices = tree.kneighbors() 
- dists, col_indices = dists.reshape(-1), col_indices.reshape(-1) - row_indices = np.repeat(np.arange(N), n_neighs) - if neigh_correct: - dist_cutoff = np.median(dists) * 1.3 # there's a small amount of sway - mask = dists < dist_cutoff - row_indices, col_indices, dists = ( - row_indices[mask], - col_indices[mask], - dists[mask], - ) - else: - dists, col_indices = tree.radius_neighbors() - row_indices = np.repeat(np.arange(N), [len(x) for x in col_indices]) - dists = np.concatenate(dists) - col_indices = np.concatenate(col_indices) - - Adj = csr_matrix( - (np.ones_like(row_indices, dtype=np.float32), (row_indices, col_indices)), - shape=(N, N), - ) - if return_distance: - Dst = csr_matrix((dists, (row_indices, col_indices)), shape=(N, N)) + key_added: str = "spatial", + copy: bool = False, +) -> SpatialNeighborsResult | None: + """Create a radius-based graph from spatial coordinates. - # radius-based filtering needs same indices/indptr: do not remove 0s - Adj.setdiag(1.0 if set_diag else Adj.diagonal()) - if return_distance: - Dst.setdiag(0.0) - return Adj, Dst + Two observations are connected when their Euclidean distance falls within the + requested radius. This mode is useful when a physical interaction scale is + more meaningful than a fixed number of neighbors. - return Adj + Parameters + ---------- + %(adata)s + %(spatial_key)s + %(sdata_params)s + %(library_key)s + radius + Neighborhood radius. If a :class:`tuple`, the graph is built with the + maximum radius and then pruned to the interval ``[min(radius), max(radius)]``. + In practice, a single value defines a disk around each observation, + whereas a tuple defines an annulus by keeping only edges within the + specified distance interval. 
+ %(graph_common_params)s + %(copy)s + Returns + ------- + %(spatial_neighbors_returns)s -@njit -def _csr_bilateral_diag_scale_helper( - mat: csr_array | csr_matrix, - degrees: NDArrayA, -) -> NDArrayA: + See Also + -------- + spatial_neighbors_from_builder : Use :class:`~squidpy.gr.neighbors.RadiusBuilder` directly for advanced customization. + squidpy.gr.neighbors.RadiusBuilder : radius-based builder class. """ - Return an array F aligned with CSR non-zeros such that - F[k] = d[i] * data[k] * d[j] for the k-th non-zero (i, j) in CSR order. + transform_enum = Transform.NONE if transform is None else Transform(transform) + builder = RadiusBuilder( + radius=radius, + percentile=percentile, + transform=transform_enum, + set_diag=set_diag, + ) + adata, library_key = _prepare_spatial_neighbors_input( + data, + spatial_key=spatial_key, + elements_to_coordinate_systems=elements_to_coordinate_systems, + table_key=table_key, + library_key=library_key, + ) + return _run_spatial_neighbors( + adata, + builder, + spatial_key=spatial_key, + library_key=library_key, + key_added=key_added, + copy=copy, + ) + + +@d.dedent +def spatial_neighbors_delaunay( + data: AnnData | SpatialData, + *, + spatial_key: str = Key.obsm.spatial, + elements_to_coordinate_systems: dict[str, str] | None = None, + table_key: str | None = None, + library_key: str | None = None, + radius: tuple[float, float] | None = None, + percentile: float | None = None, + transform: str | Transform | None = None, + set_diag: bool = False, + key_added: str = "spatial", + copy: bool = False, +) -> SpatialNeighborsResult | None: + """Create a Delaunay triangulation graph from spatial coordinates. + + Delaunay triangulation connects observations into triangles such that no + other observation lies inside the circumcircle of each triangle. In + practice, this yields an adaptive geometry-driven graph rather than one + based on a fixed ``k`` or radius, and ``dst`` stores Euclidean edge lengths. 
Parameters ---------- - - data : array of float - CSR `data` (non-zero values). - indices : array of int - CSR `indices` (column indices). - indptr : array of int - CSR `indptr` (row pointer). - degrees : array of float, shape (n,) - Diagonal scaling vector. + %(adata)s + %(spatial_key)s + %(sdata_params)s + %(library_key)s + radius + If a :class:`tuple`, used as a post-construction pruning interval + ``[min(radius), max(radius)]``. This does not change the triangulation + itself; it only removes Delaunay edges whose Euclidean lengths fall + outside the interval. + %(graph_common_params)s + %(copy)s Returns ------- - array of float - Length equals len(data). Entry-wise factors d_i * d_j * data[k] - """ - - res = np.empty_like(mat.data, dtype=np.float32) - for i in prange(len(mat.indptr) - 1): - ixs = mat.indices[mat.indptr[i] : mat.indptr[i + 1]] - res[mat.indptr[i] : mat.indptr[i + 1]] = degrees[i] * degrees[ixs] * mat.data[mat.indptr[i] : mat.indptr[i + 1]] + %(spatial_neighbors_returns)s - return res + See Also + -------- + spatial_neighbors_from_builder : Use :class:`~squidpy.gr.neighbors.DelaunayBuilder` directly for advanced customization. + squidpy.gr.neighbors.DelaunayBuilder : Delaunay triangulation builder class. + """ + transform_enum = Transform.NONE if transform is None else Transform(transform) + builder = DelaunayBuilder( + radius=radius, + percentile=percentile, + transform=transform_enum, + set_diag=set_diag, + ) + adata, library_key = _prepare_spatial_neighbors_input( + data, + spatial_key=spatial_key, + elements_to_coordinate_systems=elements_to_coordinate_systems, + table_key=table_key, + library_key=library_key, + ) + return _run_spatial_neighbors( + adata, + builder, + spatial_key=spatial_key, + library_key=library_key, + key_added=key_added, + copy=copy, + ) -def symmetric_normalize_csr(adj: spmatrix) -> csr_matrix: - """ - Return D^{-1/2} * A * D^{-1/2}, where D = diag(degrees(A)) and A = adj. 
+@d.dedent +def spatial_neighbors_grid( + data: AnnData | SpatialData, + *, + spatial_key: str = Key.obsm.spatial, + elements_to_coordinate_systems: dict[str, str] | None = None, + table_key: str | None = None, + library_key: str | None = None, + n_neighs: int = 6, + n_rings: int = 1, + delaunay: bool = False, + transform: str | Transform | None = None, + set_diag: bool = False, + key_added: str = "spatial", + copy: bool = False, +) -> SpatialNeighborsResult | None: + """Create a grid-based graph from spatial coordinates. + This is the mode used for Visium-like grid coordinates. + It assumes observations lie on an approximately regular lattice, so it is + usually not appropriate for continuous coordinates such as Xenium point + clouds. On irregular coordinates, the resulting graph and ring distances may + not have a meaningful grid interpretation. Parameters ---------- - adj : scipy.sparse.csr_matrix + %(adata)s + %(spatial_key)s + %(sdata_params)s + %(library_key)s + n_neighs + Number of neighboring tiles used to form the base grid connectivity. + Defaults to ``6``. On a Visium-like hexagonal grid, ``6`` corresponds to + the immediate surrounding spots, while smaller values such as ``3`` make + the first-ring graph deliberately sparser. + n_rings + Number of rings of neighbors. Defaults to ``1``. ``n_rings=1`` keeps + only immediate neighbors; larger values add progressively more distant + shells and encode the shell number in ``dst``. For example, + ``n_neighs=3`` with ``n_rings=2`` on a Visium-like grid starts from a + sparse three-neighbor base graph and then adds a second graph-distance + ring relative to that base connectivity. + delaunay + Whether to derive the base grid connectivity from a Delaunay triangulation. + This is still grid mode: unlike :func:`spatial_neighbors_delaunay`, the + resulting distance matrix encodes grid or ring distances rather than + Euclidean edge lengths. 
In practice, this changes how the first-ring + connectivity is inferred, but not the meaning of the resulting + distances. + %(graph_common_params)s + %(copy)s Returns ------- - scipy.sparse.csr_matrix + %(spatial_neighbors_returns)s + + See Also + -------- + spatial_neighbors_from_builder : Use :class:`~squidpy.gr.neighbors.GridBuilder` directly for advanced customization. + squidpy.gr.neighbors.GridBuilder : grid-based builder class. """ - degrees = np.squeeze(np.array(np.sqrt(1.0 / fau_stats.sum(adj, axis=0)))) - if adj.shape[0] != len(degrees): - raise ValueError("len(degrees) must equal number of rows of adj") - res_data = _csr_bilateral_diag_scale_helper(adj, degrees) - return csr_matrix((res_data, adj.indices, adj.indptr), shape=adj.shape) + assert_positive(n_rings, name="n_rings") + assert_positive(n_neighs, name="n_neighs") + transform_enum = Transform.NONE if transform is None else Transform(transform) + builder = GridBuilder( + n_neighs=n_neighs, + n_rings=n_rings, + delaunay=delaunay, + transform=transform_enum, + set_diag=set_diag, + ) + adata, library_key = _prepare_spatial_neighbors_input( + data, + spatial_key=spatial_key, + elements_to_coordinate_systems=elements_to_coordinate_systems, + table_key=table_key, + library_key=library_key, + ) + return _run_spatial_neighbors( + adata, + builder, + spatial_key=spatial_key, + library_key=library_key, + key_added=key_added, + copy=copy, + ) -def _transform_a_spectral(a: spmatrix) -> spmatrix: - if not isspmatrix_csr(a): - a = a.tocsr() - if not a.nnz: - return a +def _run_spatial_neighbors( + adata: AnnData, + builder: GraphBuilder[Any, Any], + *, + spatial_key: str = Key.obsm.spatial, + library_key: str | None = None, + key_added: str = "spatial", + copy: bool = False, +) -> SpatialNeighborsResult | None: + """Shared core: build the graph from a resolved builder and save results.""" + if library_key is not None: + _assert_categorical_obs(adata, key=library_key) + libs = 
adata.obs[library_key].cat.categories + make_index_unique(adata.obs_names) + else: + libs = [None] - return symmetric_normalize_csr(a) + start = logg.info(f"Creating graph using `{builder.transform}` transform and `{len(libs)}` libraries.") + if library_key is not None: + mats: list[tuple[Any, Any]] = [] + ixs: list[int] = [] + for lib in libs: + ixs.extend(np.where(adata.obs[library_key] == lib)[0]) + mats.append(builder.build(adata[adata.obs[library_key] == lib].obsm[spatial_key])) + adj, dst = builder.combine(mats, ixs) + else: + adj, dst = builder.build(adata.obsm[spatial_key]) + neighs_key = Key.uns.spatial_neighs(key_added) + conns_key = Key.obsp.spatial_conn(key_added) + dists_key = Key.obsp.spatial_dist(key_added) -def _transform_a_cosine(a: spmatrix) -> spmatrix: - return cosine_similarity(a, dense_output=False) + neighbors_dict = { + "connectivities_key": conns_key, + "distances_key": dists_key, + "params": { + "n_neighbors": getattr(builder, "n_neighs", 6), + "radius": getattr(builder, "radius", None), + "transform": builder.transform.v, + }, + } + + if copy: + return SpatialNeighborsResult(connectivities=adj, distances=dst) + + _save_data(adata, attr="obsp", key=conns_key, data=adj) + _save_data(adata, attr="obsp", key=dists_key, data=dst, prefix=False) + _save_data(adata, attr="uns", key=neighs_key, data=neighbors_dict, prefix=False, time=start) + return None @d.dedent @@ -560,11 +873,11 @@ def mask_graph( # get elements table = sdata.tables[table_key] coords = table.obsm[spatial_key] - Adj = table.obsp[conns_key] - Dst = table.obsp[dists_key] + adj = table.obsp[conns_key] + dst = table.obsp[dists_key] # convert edges to lines - lines_coords, idx_out = _get_lines_coords(Adj.indices, Adj.indptr, coords) + lines_coords, idx_out = _get_lines_coords(adj.indices, adj.indptr, coords) lines_coords, idx_out = np.array(lines_coords), np.array(idx_out) lines_df = gpd.GeoDataFrame(geometry=list(map(LineString, lines_coords))) @@ -578,12 +891,12 @@ def mask_graph( 
filt_idx_out = idx_out[filt_lines] # filter connectivities - Adj[filt_idx_out[:, 0], filt_idx_out[:, 1]] = 0 - Adj.eliminate_zeros() + adj[filt_idx_out[:, 0], filt_idx_out[:, 1]] = 0 + adj.eliminate_zeros() # filter_distances - Dst[filt_idx_out[:, 0], filt_idx_out[:, 1]] = 0 - Dst.eliminate_zeros() + dst[filt_idx_out[:, 0], filt_idx_out[:, 1]] = 0 + dst.eliminate_zeros() mask_conns_key = f"{key_added}_{conns_key}" mask_dists_key = f"{key_added}_{dists_key}" @@ -600,11 +913,11 @@ def mask_graph( } if copy: - return Adj, Dst + return adj, dst # save back to spatialdata - _save_data(table, attr="obsp", key=mask_conns_key, data=Adj) - _save_data(table, attr="obsp", key=mask_dists_key, data=Dst, prefix=False) + _save_data(table, attr="obsp", key=mask_conns_key, data=adj) + _save_data(table, attr="obsp", key=mask_dists_key, data=dst, prefix=False) _save_data(table, attr="uns", key=mask_neighs_key, data=neighbors_dict, prefix=False) diff --git a/src/squidpy/gr/neighbors.py b/src/squidpy/gr/neighbors.py new file mode 100644 index 000000000..422fbf77a --- /dev/null +++ b/src/squidpy/gr/neighbors.py @@ -0,0 +1,494 @@ +"""Graph construction strategies for spatial neighbor graphs. + +See the :doc:`/extensibility` guide for how to implement a custom builder. 
+""" + +from __future__ import annotations + +import warnings +from abc import ABC, abstractmethod +from collections.abc import Callable, Sequence +from dataclasses import dataclass +from itertools import chain +from typing import Generic, TypeVar, cast + +import numpy as np +from fast_array_utils import stats as fau_stats +from numba import njit, prange +from scipy.sparse import ( + SparseEfficiencyWarning, + block_diag, + csr_array, + csr_matrix, + isspmatrix_csr, + spmatrix, +) +from scipy.spatial import Delaunay +from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances +from sklearn.neighbors import NearestNeighbors + +from squidpy._constants._constants import Transform +from squidpy._utils import NDArrayA +from squidpy._validators import assert_positive + +__all__ = [ + "GraphMatrixT", + "GraphBuilder", + "GraphBuilderCSR", + "GraphPostprocessor", + "DistanceIntervalPostprocessor", + "PercentilePostprocessor", + "TransformPostprocessor", + "KNNBuilder", + "RadiusBuilder", + "DelaunayBuilder", + "GridBuilder", +] + + +CoordT = TypeVar("CoordT") +GraphMatrixT = TypeVar("GraphMatrixT") +GraphPostprocessor = Callable[[GraphMatrixT, GraphMatrixT], tuple[GraphMatrixT, GraphMatrixT]] + + +class GraphBuilder(ABC, Generic[CoordT, GraphMatrixT]): + """Base class for spatial graph construction strategies. + + Custom builders must implement :meth:`build_graph`. Overriding + :meth:`postprocessors` and :meth:`combine` is optional. Postprocessors can + be provided directly via ``__init__`` or by overriding + :meth:`postprocessors`. 
+ """ + + def __init__( + self, + transform: str | Transform | None = None, + set_diag: bool = False, + percentile: float | None = None, + postprocessors: Sequence[GraphPostprocessor[GraphMatrixT]] = (), + ) -> None: + self.transform = Transform.NONE if transform is None else Transform(transform) + self.set_diag = set_diag + self.percentile = percentile + self._postprocessors: list[GraphPostprocessor[GraphMatrixT]] = list(postprocessors) + + def build(self, coords: CoordT) -> tuple[GraphMatrixT, GraphMatrixT]: + adj, dst = self.build_graph(coords) + for postprocessor in self.postprocessors(): + adj, dst = postprocessor(adj, dst) + return adj, dst + + @abstractmethod + def build_graph(self, coords: CoordT) -> tuple[GraphMatrixT, GraphMatrixT]: + """Construct raw adjacency and distance matrices.""" + + def postprocessors(self) -> Sequence[GraphPostprocessor[GraphMatrixT]]: + """Return post-build processing steps for ``(adj, dst)``.""" + return self._postprocessors + + def combine( + self, + mats: Sequence[tuple[GraphMatrixT, GraphMatrixT]], + ixs: Sequence[int], + ) -> tuple[GraphMatrixT, GraphMatrixT]: + """Combine per-library results into a single graph. + + Override this only if the builder should support multi-library graph + construction via ``library_key``. + """ + raise NotImplementedError("Using `library_key` with this graph builder is not implemented yet.") + + +class GraphBuilderCSR(GraphBuilder[NDArrayA, csr_matrix], ABC): + """CSR-based graph construction strategy. + + Specializes :class:`GraphBuilder` for sparse CSR matrix output. Adds + SparseEfficiencyWarning suppression and multi-library ``library_key`` + combination. Built-in concrete builders + (:class:`KNNBuilder`, :class:`RadiusBuilder`, :class:`DelaunayBuilder`, :class:`GridBuilder`) + inherit from this class and declare their postprocessors explicitly in + ``__init__`` using the reusable public postprocessor classes. 
+ + Subclass this (not the generic :class:`GraphBuilder`) when implementing a builder + that returns CSR matrices. + + See Also + -------- + GraphBuilder : Generic builder interface for custom coordinate/matrix types. + KNNBuilder : Example of a concrete CSR-based builder. + """ + + def build(self, coords: NDArrayA) -> tuple[csr_matrix, csr_matrix]: + with warnings.catch_warnings(): + warnings.simplefilter("ignore", SparseEfficiencyWarning) + return super().build(coords) + + @abstractmethod + def build_graph(self, coords: NDArrayA) -> tuple[csr_matrix, csr_matrix]: + """Construct raw adjacency and distance matrices.""" + + def combine( + self, + mats: Sequence[tuple[csr_matrix, csr_matrix]], + ixs: Sequence[int], + ) -> tuple[csr_matrix, csr_matrix]: + order = cast(list[int], np.argsort(ixs).tolist()) + adj = block_diag([m[0] for m in mats], format="csr")[order, :][:, order] + dst = block_diag([m[1] for m in mats], format="csr")[order, :][:, order] + return cast(csr_matrix, adj), cast(csr_matrix, dst) + + +class KNNBuilder(GraphBuilderCSR): + """Build a generic k-nearest-neighbor spatial graph. + + Each observation is connected to its k nearest neighbors. See + :func:`~squidpy.gr.spatial_neighbors_knn` for the user-facing API or + :func:`~squidpy.gr.spatial_neighbors_from_builder` for direct builder usage. 
+ """ + + def __init__( + self, + n_neighs: int = 6, + transform: str | Transform | None = None, + set_diag: bool = False, + percentile: float | None = None, + ) -> None: + assert_positive(n_neighs, name="n_neighs") + postprocessors: list[GraphPostprocessor[csr_matrix]] = [] + if percentile is not None: + postprocessors.append(PercentilePostprocessor(percentile)) + postprocessors.append(TransformPostprocessor(Transform.NONE if transform is None else Transform(transform))) + super().__init__( + transform=transform, + set_diag=set_diag, + percentile=percentile, + postprocessors=postprocessors, + ) + self.n_neighs = n_neighs + + def build_graph(self, coords: NDArrayA) -> tuple[csr_matrix, csr_matrix]: + N = coords.shape[0] + tree = NearestNeighbors(n_neighbors=self.n_neighs, radius=1, metric="euclidean") + tree.fit(coords) + + dists, col_indices = tree.kneighbors() + dists, col_indices = dists.reshape(-1), col_indices.reshape(-1) + row_indices = np.repeat(np.arange(N), self.n_neighs) + + adj = csr_matrix( + (np.ones_like(row_indices, dtype=np.float32), (row_indices, col_indices)), + shape=(N, N), + ) + dst = csr_matrix((dists, (row_indices, col_indices)), shape=(N, N)) + + adj.setdiag(1.0 if self.set_diag else adj.diagonal()) + dst.setdiag(0.0) + return adj, dst + + +class RadiusBuilder(GraphBuilderCSR): + """Build a generic radius-based spatial graph. + + Two observations are connected when their Euclidean distance falls within + the specified radius. See :func:`~squidpy.gr.spatial_neighbors_radius` for the + user-facing API or :func:`~squidpy.gr.spatial_neighbors_from_builder` for + direct builder usage. 
+ """ + + def __init__( + self, + radius: float | tuple[float, float], + transform: str | Transform | None = None, + set_diag: bool = False, + percentile: float | None = None, + ) -> None: + postprocessors: list[GraphPostprocessor[csr_matrix]] = [] + if isinstance(radius, tuple): + postprocessors.append(DistanceIntervalPostprocessor(tuple(sorted(radius)))) + if percentile is not None: + postprocessors.append(PercentilePostprocessor(percentile)) + postprocessors.append(TransformPostprocessor(Transform.NONE if transform is None else Transform(transform))) + super().__init__( + transform=transform, + set_diag=set_diag, + percentile=percentile, + postprocessors=postprocessors, + ) + self.radius = radius + + def build_graph(self, coords: NDArrayA) -> tuple[csr_matrix, csr_matrix]: + N = coords.shape[0] + r = self.radius if isinstance(self.radius, int | float) else max(self.radius) + tree = NearestNeighbors(radius=r, metric="euclidean") + tree.fit(coords) + + dists, col_indices = tree.radius_neighbors() + row_indices = np.repeat(np.arange(N), [len(x) for x in col_indices]) + dists = np.concatenate(dists) + col_indices = np.concatenate(col_indices) + + adj = csr_matrix( + (np.ones_like(row_indices, dtype=np.float32), (row_indices, col_indices)), + shape=(N, N), + ) + dst = csr_matrix((dists, (row_indices, col_indices)), shape=(N, N)) + + adj.setdiag(1.0 if self.set_diag else adj.diagonal()) + dst.setdiag(0.0) + return adj, dst + + +class DelaunayBuilder(GraphBuilderCSR): + """Build a generic point-cloud graph from a Delaunay triangulation. + + Delaunay triangulation connects observations into triangles such that no + other observation lies inside the circumcircle of each triangle. Unlike + ``GridBuilder(delaunay=True)``, this builder uses geometry-based connectivity + and stores real Euclidean edge lengths. + + See :func:`~squidpy.gr.spatial_neighbors_delaunay` for the user-facing API or + :func:`~squidpy.gr.spatial_neighbors_from_builder` for direct builder usage. 
+ """ + + def __init__( + self, + radius: float | tuple[float, float] | None = None, + transform: str | Transform | None = None, + set_diag: bool = False, + percentile: float | None = None, + ) -> None: + postprocessors: list[GraphPostprocessor[csr_matrix]] = [] + if isinstance(radius, tuple): + postprocessors.append(DistanceIntervalPostprocessor(tuple(sorted(radius)))) + if percentile is not None: + postprocessors.append(PercentilePostprocessor(percentile)) + postprocessors.append(TransformPostprocessor(Transform.NONE if transform is None else Transform(transform))) + super().__init__( + transform=transform, + set_diag=set_diag, + percentile=percentile, + postprocessors=postprocessors, + ) + self.radius = radius + + def build_graph(self, coords: NDArrayA) -> tuple[csr_matrix, csr_matrix]: + N = coords.shape[0] + tri = Delaunay(coords) + indptr, indices = tri.vertex_neighbor_vertices + adj = csr_matrix((np.ones_like(indices, dtype=np.float32), indices, indptr), shape=(N, N)) + + # fmt: off + dists = np.array(list(chain(*( + euclidean_distances(coords[indices[indptr[i] : indptr[i + 1]], :], coords[np.newaxis, i, :]) + for i in range(N) + if len(indices[indptr[i] : indptr[i + 1]]) + )))).squeeze() + # fmt: on + dst = csr_matrix((dists, indices, indptr), shape=(N, N)) + + adj.setdiag(1.0 if self.set_diag else adj.diagonal()) + dst.setdiag(0.0) + return adj, dst + + +class GridBuilder(GraphBuilderCSR): + """Build a grid-based spatial graph. + + Assumes observations lie on an approximately regular lattice (e.g., Visium). + When ``delaunay=True``, Delaunay triangulation is used only to derive the + base connectivity; the distance matrix still encodes grid/ring distances, + not Euclidean lengths. + + See :func:`~squidpy.gr.spatial_neighbors_grid` for the user-facing API or + :func:`~squidpy.gr.spatial_neighbors_from_builder` for direct builder usage. 
+ """ + + def __init__( + self, + n_neighs: int = 6, + n_rings: int = 1, + delaunay: bool = False, + transform: str | Transform | None = None, + set_diag: bool = False, + ) -> None: + assert_positive(n_neighs, name="n_neighs") + assert_positive(n_rings, name="n_rings") + postprocessors = [TransformPostprocessor(Transform.NONE if transform is None else Transform(transform))] + super().__init__(transform=transform, set_diag=set_diag, percentile=None, postprocessors=postprocessors) + self.n_neighs = n_neighs + self.n_rings = n_rings + self.delaunay = delaunay + + def build_graph(self, coords: NDArrayA) -> tuple[csr_matrix, csr_matrix]: + if self.n_rings > 1: + adj = self._base_adjacency(coords, set_diag=True) + res, walk = adj, adj + for i in range(self.n_rings - 1): + walk = walk @ adj + walk[res.nonzero()] = 0.0 + walk.eliminate_zeros() + walk.data[:] = i + 2.0 + res = res + walk + adj = res + adj.setdiag(float(self.set_diag)) + adj.eliminate_zeros() + + dst = adj.copy() + adj.data[:] = 1.0 + else: + adj = self._base_adjacency(coords, set_diag=self.set_diag) + dst = adj.copy() + + dst.setdiag(0.0) + return adj, dst + + def _base_adjacency(self, coords: NDArrayA, *, set_diag: bool) -> csr_matrix: + """KNN adjacency with median-distance correction for grid coordinates.""" + N = coords.shape[0] + if self.delaunay: + tri = Delaunay(coords) + indptr, indices = tri.vertex_neighbor_vertices + adj = csr_matrix((np.ones_like(indices, dtype=np.float32), indices, indptr), shape=(N, N)) + else: + tree = NearestNeighbors(n_neighbors=self.n_neighs, radius=1, metric="euclidean") + tree.fit(coords) + dists, col_indices = tree.kneighbors() + dists, col_indices = dists.reshape(-1), col_indices.reshape(-1) + row_indices = np.repeat(np.arange(N), self.n_neighs) + + dist_cutoff = np.median(dists) * 1.3 + mask = dists < dist_cutoff + row_indices, col_indices = row_indices[mask], col_indices[mask] + + adj = csr_matrix( + (np.ones_like(row_indices, dtype=np.float32), (row_indices, 
col_indices)), + shape=(N, N), + ) + + adj.setdiag(1.0 if set_diag else adj.diagonal()) + return adj + + +# --------------------------------------------------------------------------- +# Private helpers used by the builder classes +# --------------------------------------------------------------------------- + + +def _filter_by_radius_interval( + adj: csr_matrix, + dst: csr_matrix, + radius: tuple[float, float], +) -> None: + minn, maxx = radius + mask = (dst.data < minn) | (dst.data > maxx) + a_diag = adj.diagonal() + + dst.data[mask] = 0.0 + adj.data[mask] = 0.0 + adj.setdiag(a_diag) + + +@dataclass(frozen=True) +class DistanceIntervalPostprocessor: + interval: tuple[float, float] + + def __call__(self, adj: csr_matrix, dst: csr_matrix) -> tuple[csr_matrix, csr_matrix]: + _filter_by_radius_interval(adj, dst, self.interval) + return adj, dst + + +@dataclass(frozen=True) +class PercentilePostprocessor: + percentile: float + + def __call__(self, adj: csr_matrix, dst: csr_matrix) -> tuple[csr_matrix, csr_matrix]: + threshold = np.percentile(dst.data, self.percentile) + adj[dst > threshold] = 0.0 + dst[dst > threshold] = 0.0 + return adj, dst + + +@dataclass(frozen=True) +class TransformPostprocessor: + transform: Transform + + def __call__(self, adj: csr_matrix, dst: csr_matrix) -> tuple[csr_matrix, csr_matrix]: + adj.eliminate_zeros() + dst.eliminate_zeros() + + if self.transform == Transform.SPECTRAL: + return cast(csr_matrix, _transform_a_spectral(adj)), dst + if self.transform == Transform.COSINE: + return cast(csr_matrix, _transform_a_cosine(adj)), dst + if self.transform == Transform.NONE: + return adj, dst + + raise NotImplementedError(f"Transform `{self.transform}` is not yet implemented.") + + +@njit +def _csr_bilateral_diag_scale_helper( + mat: csr_array | csr_matrix, + degrees: NDArrayA, +) -> NDArrayA: + """ + Return an array F aligned with CSR non-zeros such that + F[k] = d[i] * data[k] * d[j] for the k-th non-zero (i, j) in CSR order. 
+ + Parameters + ---------- + + data : array of float + CSR `data` (non-zero values). + indices : array of int + CSR `indices` (column indices). + indptr : array of int + CSR `indptr` (row pointer). + degrees : array of float, shape (n,) + Diagonal scaling vector. + + Returns + ------- + array of float + Length equals len(data). Entry-wise factors d_i * d_j * data[k] + """ + + res = np.empty_like(mat.data, dtype=np.float32) + for i in prange(len(mat.indptr) - 1): + ixs = mat.indices[mat.indptr[i] : mat.indptr[i + 1]] + res[mat.indptr[i] : mat.indptr[i + 1]] = degrees[i] * degrees[ixs] * mat.data[mat.indptr[i] : mat.indptr[i + 1]] + + return res + + +def symmetric_normalize_csr(adj: spmatrix) -> csr_matrix: + """ + Return D^{-1/2} * A * D^{-1/2}, where D = diag(degrees(A)) and A = adj. + + + Parameters + ---------- + adj : scipy.sparse.csr_matrix + + Returns + ------- + scipy.sparse.csr_matrix + """ + degrees = np.squeeze(np.array(np.sqrt(1.0 / fau_stats.sum(adj, axis=0)))) + if adj.shape[0] != len(degrees): + raise ValueError("len(degrees) must equal number of rows of adj") + res_data = _csr_bilateral_diag_scale_helper(adj, degrees) + return csr_matrix((res_data, adj.indices, adj.indptr), shape=adj.shape) + + +def _transform_a_spectral(a: spmatrix) -> spmatrix: + if not isspmatrix_csr(a): + a = a.tocsr() + if not a.nnz: + return a + + return symmetric_normalize_csr(a) + + +def _transform_a_cosine(a: spmatrix) -> spmatrix: + return cosine_similarity(a, dense_output=False) diff --git a/tests/graph/test_spatial_neighbors.py b/tests/graph/test_spatial_neighbors.py index db638132b..150d211bb 100644 --- a/tests/graph/test_spatial_neighbors.py +++ b/tests/graph/test_spatial_neighbors.py @@ -10,9 +10,18 @@ from shapely import Point from spatialdata.datasets import blobs +from squidpy._constants._constants import Transform from squidpy._constants._pkg_constants import Key -from squidpy.gr import mask_graph, spatial_neighbors -from squidpy.gr._build import 
_build_connectivity +from squidpy.gr import mask_graph, spatial_neighbors, spatial_neighbors_from_builder +from squidpy.gr.neighbors import ( + DelaunayBuilder, + GridBuilder, + KNNBuilder, + RadiusBuilder, +) +from squidpy.gr.neighbors import ( + KNNBuilder as PublicKNNBuilder, +) class TestSpatialNeighbors: @@ -46,6 +55,22 @@ def _adata_concat(adata1, adata2): ) return adata_concat, batch1, batch2 + @staticmethod + def _assert_library_key_block_diagonal(adata, **neighbor_kwargs): + adata2 = adata.copy() + adata_concat, batch1, batch2 = TestSpatialNeighbors._adata_concat(adata, adata2) + spatial_neighbors(adata2, **neighbor_kwargs) + spatial_neighbors(adata_concat, library_key="library_id", **neighbor_kwargs) + np.testing.assert_array_equal( + adata_concat[adata_concat.obs["library_id"] == batch1].obsp[Key.obsp.spatial_conn()].toarray(), + adata.obsp[Key.obsp.spatial_conn()].toarray(), + ) + np.testing.assert_array_equal( + adata_concat[adata_concat.obs["library_id"] == batch2].obsp[Key.obsp.spatial_conn()].toarray(), + adata2.obsp[Key.obsp.spatial_conn()].toarray(), + ) + return adata_concat + # TODO: add edge cases # TODO(giovp): test with reshuffling @pytest.mark.parametrize(("n_rings", "n_neigh", "sum_dist"), [(1, 6, 0), (2, 18, 30), (3, 36, 84)]) @@ -65,20 +90,8 @@ def test_spatial_neighbors_visium( if n_rings > 1: assert visium_adata.obsp[Key.obsp.spatial_dist()][0].sum() == sum_dist - # test for library_key - visium_adata2 = visium_adata.copy() - adata_concat, batch1, batch2 = TestSpatialNeighbors._adata_concat(visium_adata, visium_adata2) - spatial_neighbors(visium_adata2, n_rings=n_rings) - spatial_neighbors(adata_concat, library_key="library_id", n_rings=n_rings) + adata_concat = self._assert_library_key_block_diagonal(visium_adata, n_rings=n_rings) assert adata_concat.obsp[Key.obsp.spatial_conn()][0].sum() == n_neigh - np.testing.assert_array_equal( - adata_concat[adata_concat.obs["library_id"] == batch1].obsp[Key.obsp.spatial_conn()].toarray(), - 
visium_adata.obsp[Key.obsp.spatial_conn()].toarray(), - ) - np.testing.assert_array_equal( - adata_concat[adata_concat.obs["library_id"] == batch2].obsp[Key.obsp.spatial_conn()].toarray(), - visium_adata2.obsp[Key.obsp.spatial_conn()].toarray(), - ) @pytest.mark.parametrize(("n_rings", "n_neigh", "sum_neigh"), [(1, 4, 4), (2, 4, 12), (3, 4, 24)]) def test_spatial_neighbors_squaregrid(self, adata_squaregrid: AnnData, n_rings: int, n_neigh: int, sum_neigh: int): @@ -90,26 +103,13 @@ def test_spatial_neighbors_squaregrid(self, adata_squaregrid: AnnData, n_rings: assert np.diff(adata.obsp[Key.obsp.spatial_conn()].indptr).max() == sum_neigh assert adata.uns[Key.uns.spatial_neighs()]["distances_key"] == Key.obsp.spatial_dist() - # test for library_key - adata2 = adata.copy() - adata_concat, batch1, batch2 = TestSpatialNeighbors._adata_concat(adata, adata2) - spatial_neighbors(adata2, n_neighs=n_neigh, n_rings=n_rings, coord_type="grid") - spatial_neighbors( - adata_concat, - library_key="library_id", + adata_concat = self._assert_library_key_block_diagonal( + adata, n_neighs=n_neigh, n_rings=n_rings, coord_type="grid", ) assert np.diff(adata_concat.obsp[Key.obsp.spatial_conn()].indptr).max() == sum_neigh - np.testing.assert_array_equal( - adata_concat[adata_concat.obs["library_id"] == batch1].obsp[Key.obsp.spatial_conn()].toarray(), - adata.obsp[Key.obsp.spatial_conn()].toarray(), - ) - np.testing.assert_array_equal( - adata_concat[adata_concat.obs["library_id"] == batch2].obsp[Key.obsp.spatial_conn()].toarray(), - adata2.obsp[Key.obsp.spatial_conn()].toarray(), - ) @pytest.mark.parametrize("type_rings", [("grid", 1), ("grid", 6), ("generic", 1)]) @pytest.mark.parametrize("set_diag", [False, True]) @@ -161,20 +161,7 @@ def test_spatial_neighbors_non_visium(self, non_visium_adata: AnnData): np.testing.assert_array_equal(spatial_graph, self._gt_dgraph) np.testing.assert_allclose(spatial_dist, self._gt_ddist) - # test for library_key - non_visium_adata2 = 
non_visium_adata.copy() - adata_concat, batch1, batch2 = TestSpatialNeighbors._adata_concat(non_visium_adata, non_visium_adata2) - spatial_neighbors(adata_concat, library_key="library_id", delaunay=True, coord_type=None) - spatial_neighbors(non_visium_adata2, delaunay=True, coord_type=None) - - np.testing.assert_array_equal( - adata_concat[adata_concat.obs["library_id"] == batch1].obsp[Key.obsp.spatial_conn()].toarray(), - non_visium_adata.obsp[Key.obsp.spatial_conn()].toarray(), - ) - np.testing.assert_array_equal( - adata_concat[adata_concat.obs["library_id"] == batch2].obsp[Key.obsp.spatial_conn()].toarray(), - non_visium_adata2.obsp[Key.obsp.spatial_conn()].toarray(), - ) + self._assert_library_key_block_diagonal(non_visium_adata, delaunay=True, coord_type=None) @pytest.mark.parametrize("set_diag", [False, True]) @pytest.mark.parametrize("radius", [(0, np.inf), (2.0, 4.0), (-42, -420), (100, 200)]) @@ -213,6 +200,78 @@ def test_copy(self, non_visium_adata: AnnData): np.testing.assert_allclose(result.distances.toarray(), self._gt_ddist) np.testing.assert_allclose(result.connectivities.toarray(), self._gt_dgraph) + def test_builder_module_export(self): + assert PublicKNNBuilder is KNNBuilder + + @pytest.mark.parametrize( + ("legacy_kwargs", "builder"), + [ + ({"n_neighs": 3, "coord_type": "generic"}, KNNBuilder(n_neighs=3)), + ({"radius": 5.0, "coord_type": "generic"}, RadiusBuilder(radius=5.0)), + ({"delaunay": True, "coord_type": "generic"}, DelaunayBuilder()), + ], + ids=["knn", "radius", "delaunay"], + ) + def test_generic_builder_matches_legacy(self, non_visium_adata: AnnData, legacy_kwargs: dict, builder: object): + legacy = spatial_neighbors(non_visium_adata, **legacy_kwargs, copy=True) + result = spatial_neighbors_from_builder(non_visium_adata, builder=builder, copy=True) + + np.testing.assert_array_equal(legacy.connectivities.toarray(), result.connectivities.toarray()) + np.testing.assert_allclose(legacy.distances.toarray(), result.distances.toarray()) + 
+ @pytest.mark.parametrize( + ("legacy_kwargs", "builder"), + [ + ({"n_neighs": 4, "n_rings": 2, "coord_type": "grid"}, GridBuilder(n_neighs=4, n_rings=2)), + ({"n_neighs": 6, "n_rings": 1, "coord_type": "grid"}, GridBuilder(n_neighs=6, n_rings=1)), + ], + ids=["4neighs_2rings", "6neighs_1ring"], + ) + def test_grid_builder_matches_legacy(self, adata_squaregrid: AnnData, legacy_kwargs: dict, builder: object): + legacy = spatial_neighbors(adata_squaregrid, **legacy_kwargs, copy=True) + result = spatial_neighbors_from_builder(adata_squaregrid, builder=builder, copy=True) + + np.testing.assert_array_equal(legacy.connectivities.toarray(), result.connectivities.toarray()) + np.testing.assert_allclose(legacy.distances.toarray(), result.distances.toarray()) + + def test_builder_explicit_entry_point(self, non_visium_adata: AnnData): + builder = KNNBuilder(n_neighs=3, transform=Transform.NONE) + + baseline = spatial_neighbors_from_builder(non_visium_adata, builder=builder, copy=True) + matched = spatial_neighbors_from_builder(non_visium_adata, builder=builder, copy=True) + + np.testing.assert_array_equal(baseline.connectivities.toarray(), matched.connectivities.toarray()) + np.testing.assert_allclose(baseline.distances.toarray(), matched.distances.toarray()) + + def test_grid_mode_ignores_radius(self, adata_squaregrid: AnnData): + default = spatial_neighbors(adata_squaregrid, coord_type="grid", n_neighs=4, n_rings=2, copy=True) + ignored = spatial_neighbors( + adata_squaregrid, + coord_type="grid", + n_neighs=4, + n_rings=2, + radius=(0.1, 0.2), + copy=True, + ) + + np.testing.assert_array_equal(default.connectivities.toarray(), ignored.connectivities.toarray()) + np.testing.assert_allclose(default.distances.toarray(), ignored.distances.toarray()) + + def test_delaunay_mode_ignores_scalar_radius(self, non_visium_adata: AnnData): + default = spatial_neighbors(non_visium_adata, coord_type="generic", delaunay=True, copy=True) + ignored = spatial_neighbors(non_visium_adata, 
coord_type="generic", delaunay=True, radius=5.0, copy=True) + + np.testing.assert_array_equal(default.connectivities.toarray(), ignored.connectivities.toarray()) + np.testing.assert_allclose(default.distances.toarray(), ignored.distances.toarray()) + + def test_delaunay_mode_warns_on_n_neighs(self, non_visium_adata: AnnData): + with pytest.warns(FutureWarning, match=r"Parameter `n_neighs` is ignored when `delaunay=True`"): + spatial_neighbors(non_visium_adata, coord_type="generic", delaunay=True, n_neighs=3, copy=True) + + def test_radius_mode_warns_on_n_neighs(self, non_visium_adata: AnnData): + with pytest.warns(FutureWarning, match=r"Parameter `n_neighs` is ignored when `radius` is set"): + spatial_neighbors(non_visium_adata, coord_type="generic", radius=5.0, n_neighs=3, copy=True) + @pytest.mark.parametrize("percentile", [99.0, 95.0]) def test_percentile_filtering(self, adata_hne: AnnData, percentile: float, coord_type="generic"): result = spatial_neighbors(adata_hne, coord_type=coord_type, copy=True) @@ -222,7 +281,7 @@ def test_percentile_filtering(self, adata_hne: AnnData, percentile: float, coord assert not ((result.connectivities != result_filtered.connectivities).nnz == 0) assert result.distances.max() > result_filtered.distances.max() - Adj, Dst = _build_connectivity(adata_hne.obsm["spatial"], n_neighs=6, return_distance=True, set_diag=False) + Adj, Dst = KNNBuilder(n_neighs=6, set_diag=False).build_graph(adata_hne.obsm["spatial"]) threshold = np.percentile(Dst.data, percentile) Adj[Dst > threshold] = 0.0 Dst[Dst > threshold] = 0.0