def cpu_count() -> int:
    """Return the number of CPUs this process may actually use.

    The size of the scheduler affinity mask reported by
    :func:`os.sched_getaffinity` is preferred, so CPU pinning applied by
    a batch scheduler, a container runtime or ``taskset`` is honoured.
    Where that call is missing or fails (e.g. macOS), the machine-wide
    :func:`os.cpu_count` is used instead, with a floor of ``1`` when
    even that value is undetermined.
    """
    try:
        allowed = os.sched_getaffinity(0)
    except (AttributeError, OSError):
        # No usable affinity API: fall back to the total CPU count,
        # mapping ``None`` (unknown) or 0 to 1.
        total = os.cpu_count()
        return total if total else 1
    return len(allowed)
@dataclass(frozen=True)
class CellInfo:
    """Immutable geometry record for one label: centroid and bounding-box size."""

    label: int
    centroid_y: float  # centroid row coordinate
    centroid_x: float  # centroid column coordinate
    bbox_h: int  # bounding-box height (number of rows)
    bbox_w: int  # bounding-box width (number of columns)


@dataclass(frozen=True)
class TileSpec:
    """Immutable description of one tile of the label image.

    Attributes
    ----------
    base
        ``(y0, x0, y1, x1)`` of the non-overlapping grid cell. A label
        belongs to the tile whose ``base`` contains its centroid.
    crop
        ``(y0, x0, y1, x1)`` of the region actually sliced out of the
        image/labels, i.e. ``base`` enlarged by the overlap margin.
    owned_ids
        IDs of the labels owned by this tile. Every other label is
        zeroed in the tile's mask. Defaults to an empty frozenset.
    """

    base: tuple[int, int, int, int]
    crop: tuple[int, int, int, int]
    owned_ids: frozenset[int] = field(default_factory=frozenset)


# ---------------------------------------------------------------------------
# Centroid computation
# ---------------------------------------------------------------------------


def compute_cell_info(labels: np.ndarray) -> dict[int, CellInfo]:
    """Collect centroid and bounding-box size of every label in an array.

    Parameters
    ----------
    labels
        2-D integer label image where 0 is background.

    Returns
    -------
    Mapping from label ID to :class:`CellInfo`.
    """

    def _to_info(region) -> CellInfo:
        # ``bbox`` is (min_row, min_col, max_row, max_col) for a 2-D region.
        row_lo, col_lo, row_hi, col_hi = region.bbox
        return CellInfo(
            label=region.label,
            centroid_y=region.centroid[0],
            centroid_x=region.centroid[1],
            bbox_h=row_hi - row_lo,
            bbox_w=col_hi - col_lo,
        )

    return {region.label: _to_info(region) for region in regionprops(labels)}
+ """ + available = list(labels_node.keys()) + if not available: + return {} + + def _scale_idx(k: str) -> int: + num = "".join(c for c in k if c.isdigit()) + return int(num) if num else 0 + + coarsest = max(available, key=_scale_idx) + coarse_da = labels_node[coarsest].ds["image"] + coarse_labels = np.asarray(coarse_da.values).squeeze() + + if coarse_labels.ndim != 2: + raise ValueError(f"Expected 2-D labels at scale {coarsest}, got shape {coarse_labels.shape}") + + target_da = labels_node[target_scale].ds["image"] + target_h, target_w = target_da.sizes.get("y", target_da.shape[-2]), target_da.sizes.get("x", target_da.shape[-1]) + coarse_h, coarse_w = coarse_labels.shape + scale_y = target_h / coarse_h + scale_x = target_w / coarse_w + + props = regionprops(coarse_labels) + return { + p.label: CellInfo( + label=p.label, + centroid_y=p.centroid[0] * scale_y, + centroid_x=p.centroid[1] * scale_x, + bbox_h=int(np.ceil((p.bbox[2] - p.bbox[0]) * scale_y)), + bbox_w=int(np.ceil((p.bbox[3] - p.bbox[1]) * scale_x)), + ) + for p in props + } + + +def compute_cell_info_tiled( + labels_da: xr.DataArray, + chunk_size: int = 4096, +) -> dict[int, CellInfo]: + """Compute centroids by reading label tiles — never materializes the full array. + + For cells spanning multiple chunks, centroids are computed as + area-weighted means of per-chunk centroids. + + Parameters + ---------- + labels_da + 2-D (y, x) dask-backed xarray DataArray. + chunk_size + Size of chunks to read at a time. 
def _auto_margin(cell_info: dict[int, CellInfo]) -> int:
    """Compute an overlap margin that contains every owned cell in its crop.

    A cell is owned by the tile whose base region contains its centroid.
    The centroid always lies inside the cell's bounding box, but not
    necessarily at its centre: for an asymmetric cell (e.g. a compact
    body with a long thin protrusion) it can sit close to one side of
    the box, leaving almost the full bounding-box extent on the other
    side. Half the extent is therefore not enough to guarantee
    containment, so the full extent of the largest cell is used.

    Parameters
    ----------
    cell_info
        Mapping from label ID to an object exposing ``bbox_h`` and
        ``bbox_w``.

    Returns
    -------
    ``0`` when there are no cells, otherwise the largest bounding-box
    side length over all cells plus one pixel.
    """
    if not cell_info:
        return 0
    max_extent = max(max(c.bbox_h, c.bbox_w) for c in cell_info.values())
    # One extra pixel absorbs centroid truncation / rounding when cells
    # are assigned to tiles.
    return int(max_extent) + 1
# ---------------------------------------------------------------------------
# Tile extraction
# ---------------------------------------------------------------------------


def extract_tile(
    image: np.ndarray,
    labels: np.ndarray,
    spec: TileSpec,
) -> tuple[np.ndarray, np.ndarray]:
    """Cut one tile out of in-memory arrays, keeping only the owned labels.

    Parameters
    ----------
    image
        ``(C, H, W)`` numpy array.
    labels
        ``(H, W)`` numpy label array.
    spec
        Tile specification; ``spec.crop`` selects the window and
        ``spec.owned_ids`` the labels that survive.

    Returns
    -------
    tile_image, tile_labels
        The image window (a slice of ``image``, not copied) and a copy
        of the label window with non-owned labels set to 0.
    """
    top, left, bottom, right = spec.crop
    rows, cols = slice(top, bottom), slice(left, right)

    # Copy the labels so the in-place zeroing never touches the caller's array.
    owned_labels = labels[rows, cols].copy()
    _zero_non_owned(owned_labels, spec.owned_ids)

    return image[:, rows, cols], owned_labels
def extract_labels_tile_lazy(
    labels_da: xr.DataArray,
    spec: TileSpec,
) -> np.ndarray:
    """Read only the label window of a tile from a lazy DataArray.

    Same as :func:`extract_tile_lazy` without the image: only the crop
    region is materialized.

    Parameters
    ----------
    labels_da
        ``(y, x)`` dask-backed DataArray.
    spec
        Tile specification.

    Returns
    -------
    ``(crop_h, crop_w)`` numpy array with non-owned cells zeroed.
    """
    top, left, bottom, right = spec.crop
    window = labels_da.isel(y=slice(top, bottom), x=slice(left, right))
    owned_labels = window.values.copy()
    if owned_labels.ndim > 2:
        # Drop singleton axes so a 2-D mask is returned.
        owned_labels = owned_labels.squeeze()
    _zero_non_owned(owned_labels, spec.owned_ids)
    return owned_labels


def _zero_non_owned(tile_labels: np.ndarray, owned_ids: frozenset[int]) -> None:
    """Set every label that is not in *owned_ids* to 0, modifying the array in place."""
    keep = np.array(list(owned_ids), dtype=tile_labels.dtype)
    foreign = np.isin(tile_labels, keep, invert=True)
    foreign &= tile_labels != 0
    tile_labels[foreign] = 0


# ---------------------------------------------------------------------------
# Coverage verification
# ---------------------------------------------------------------------------


def verify_coverage(
    all_label_ids: set[int],
    specs: list[TileSpec],
) -> None:
    """Check that the tile specs own every label exactly once.

    Parameters
    ----------
    all_label_ids
        Set of all nonzero label IDs expected in the image.
    specs
        Tile specifications to verify.

    Raises
    ------
    ValueError
        If a label is owned by more than one tile, is owned by no tile,
        or is owned by a tile but absent from ``all_label_ids``.
    """
    seen: set[int] = set()
    for tile in specs:
        clash = seen & tile.owned_ids
        if clash:
            raise ValueError(f"Cells {clash} assigned to multiple tiles")
        seen.update(tile.owned_ids)

    unassigned = all_label_ids - seen
    if unassigned:
        raise ValueError(f"Cells {unassigned} not assigned to any tile")

    unknown = seen - all_label_ids
    if unknown:
        raise ValueError(f"Tile specs reference non-existent labels {unknown}")
+ score_col + Which ``.obs`` column to colour by. One of + ``"nhood_outlier_fraction"``, ``"smoothed_cut_score"``, + ``"cut_score"``, ``"max_straight_edge_ratio"``, + ``"cardinal_alignment_score"``. + cmap + Matplotlib colormap name. + figsize + Figure size passed to :meth:`spatialdata.SpatialData.pl.show`. + """ + table_key = qc_key if qc_key is not None else f"{labels_key}_qc" + if table_key not in sdata.tables: + raise ValueError( + f"QC table '{table_key}' not found in sdata.tables. " + f"Run calculate_tiling_qc(sdata, labels_key='{labels_key}') first." + ) + + adata = sdata.tables[table_key] + if score_col not in adata.obs.columns: + raise ValueError( + f"Score column '{score_col}' not found in .obs. " + f"Available: {[c for c in adata.obs.columns if c not in ('region', 'label_id')]}" + ) + + import spatialdata_plot # noqa: F401 - registers accessor + + _TITLES = { + "nhood_outlier_fraction": "Neighborhood outlier fraction", + "smoothed_cut_score": "Smoothed cut score", + "cut_score": "Cut score", + "is_outlier": "Outlier flag", + "max_straight_edge_ratio": "Max straight edge ratio", + "cardinal_alignment_score": "Cardinal alignment score", + } + + show_kwargs: dict[str, object] = {"title": _TITLES.get(score_col, score_col)} + if figsize is not None: + show_kwargs["figsize"] = figsize + + sdata.pl.render_labels( + element=labels_key, + color=score_col, + table_name=table_key, + cmap=cmap, + colorbar=True, + ).pl.show(**show_kwargs) diff --git a/src/squidpy/experimental/tl/__init__.py b/src/squidpy/experimental/tl/__init__.py new file mode 100644 index 00000000..e4e52977 --- /dev/null +++ b/src/squidpy/experimental/tl/__init__.py @@ -0,0 +1,5 @@ +from __future__ import annotations + +from ._tiling_qc import calculate_tiling_qc + +__all__ = ["calculate_tiling_qc"] diff --git a/src/squidpy/experimental/tl/_tiling_qc.py b/src/squidpy/experimental/tl/_tiling_qc.py new file mode 100644 index 00000000..10dedcc4 --- /dev/null +++ 
b/src/squidpy/experimental/tl/_tiling_qc.py @@ -0,0 +1,657 @@ +"""QC metrics for detecting tile-boundary segmentation artifacts. + +Cells cut by tile borders during segmentation have characteristic +straight edges that natural cell boundaries never produce. This module +computes per-cell metrics that quantify this artifact: + +- **max_straight_edge_ratio**: length of the longest straight contour + segment normalised by the cell's equivalent diameter. +- **cardinal_alignment_score**: how closely that segment aligns with + 0° or 90° (axis-aligned tile borders). +- **cut_score**: product of the two, combining evidence from shape and + orientation. +- **smoothed_cut_score**: cut_score multiplied by the mean cut_score of + k=10 nearest spatial neighbors - amplifies boundary cells while + suppressing isolated high-scorers. +- **is_outlier**: boolean flag gated on per-cell cut_score and/or + spatially smoothed score exceeding their respective MAD thresholds. +- **nhood_outlier_fraction**: fraction of k=10 nearest neighbors that are + smoothed-score outliers (MAD-based). Bounded [0, 1]; high values + precisely trace the FOV tile grid. + +All heavy computation is done per-tile via the tiling infrastructure +in :mod:`squidpy.experimental.im._tiling`, so this scales to +100k x 100k images without materialising the full array. 
+""" + +from __future__ import annotations + +import math +from typing import TYPE_CHECKING, Literal + +import anndata as ad +import dask +import numpy as np +import pandas as pd +import spatialdata as sd +import xarray as xr +from dask.diagnostics import ProgressBar +from numba import njit +from skimage.measure import find_contours, regionprops +from sklearn.neighbors import BallTree +from spatialdata._logging import logger as logg +from spatialdata.models import TableModel + +if TYPE_CHECKING: + from dask.distributed import Client + +from squidpy._utils import cpu_count +from squidpy.experimental.im._tiling import ( + build_tile_specs, + compute_cell_info, + compute_cell_info_multiscale, + compute_cell_info_tiled, + extract_labels_tile_lazy, +) + +__all__ = ["calculate_tiling_qc"] + +# Minimum cell area in pixels - smaller cells produce noisy contours +_MIN_CELL_AREA = 20 + +# Default perpendicular distance tolerance for collinearity (pixels). +# Points within this distance of the start→end line are considered +# part of the same straight segment. 0.75 px works well for +# sub-pixel contours from marching squares. +_DEFAULT_DISTANCE_TOL = 0.75 + +# Maximum contour points to analyse. Longer contours are resampled +# to this length via equidistant arc-length interpolation to bound +# the O(n²) two-pointer scan. 
_MAX_CONTOUR_POINTS = 500

# Score columns produced per tile, columns added afterwards from the
# combined results, and their concatenation.
_TILE_SCORE_COLUMNS = ["max_straight_edge_ratio", "cardinal_alignment_score", "cut_score"]
_POST_SCORE_COLUMNS = ["smoothed_cut_score", "is_outlier", "nhood_outlier_fraction"]
_SCORE_COLUMNS = _TILE_SCORE_COLUMNS + _POST_SCORE_COLUMNS
# Template row (all per-tile scores NaN) for cells that cannot be scored.
_NAN_TILE_SCORES = dict.fromkeys(_TILE_SCORE_COLUMNS, np.nan)


# ---------------------------------------------------------------------------
# Core geometry
# ---------------------------------------------------------------------------


@njit(cache=True, nogil=True)
def _collinear_scan(
    contour: np.ndarray,
    cum_arc: np.ndarray,
    total_arc: float,
    distance_tol: float,
) -> tuple[float, float]:
    """Numba-compiled search for the longest straight chord of a contour.

    For each start index the end index is advanced while every point
    strictly between them stays within ``distance_tol`` of the
    start→end chord. The scan for a given start stops at the first end
    whose chord violates the tolerance.

    Parameters
    ----------
    contour
        ``(N, 2)`` array; columns 0 and 1 are read as the two point
        coordinates.
    cum_arc
        Per-point value subtracted from ``total_arc`` to bound what is
        still reachable from a start index. Presumably the cumulative
        arc length up to each point (verify against the caller).
    total_arc
        Upper bound paired with ``cum_arc``; presumably the total arc
        length of ``contour``.
    distance_tol
        Maximum allowed perpendicular distance from the chord.

    Returns
    -------
    ``(best_length, best_angle)``: Euclidean length of the longest
    accepted chord and ``atan2(d0, d1)`` of its direction, or
    ``(0.0, 0.0)`` if no chord was accepted.
    """
    n = contour.shape[0]
    best_len = 0.0
    best_angle = 0.0

    # A chord needs at least one interior point, hence end >= start + 2.
    for start in range(n - 2):
        # Early exit: if what remains cannot beat the current best, stop.
        # NOTE(review): valid as a bound only if cum_arc/total_arc are true
        # arc lengths of `contour` (a chord is never longer than its arc).
        remaining_arc = total_arc - cum_arc[start]
        if remaining_arc <= best_len:
            break

        for end in range(start + 2, n):
            d0 = contour[end, 0] - contour[start, 0]
            d1 = contour[end, 1] - contour[start, 1]
            seg_len = math.sqrt(d0 * d0 + d1 * d1)
            # Degenerate chord (end coincides with start): no direction, skip.
            if seg_len < 1e-12:
                continue

            # Largest perpendicular distance of the interior points from the
            # chord: |cross(chord, point - start)| / |chord|.
            max_perp = 0.0
            for k in range(start + 1, end):
                r0 = contour[k, 0] - contour[start, 0]
                r1 = contour[k, 1] - contour[start, 1]
                perp = abs(d0 * r1 - d1 * r0) / seg_len
                if perp > max_perp:
                    max_perp = perp
                # One violating point is enough to reject this chord.
                if perp > distance_tol:
                    break

            # Chord rejected: stop extending from this start index.
            if max_perp > distance_tol:
                break

            if seg_len > best_len:
                best_len = seg_len
                best_angle = math.atan2(d0, d1)

    return best_len, best_angle
def _longest_collinear_segment(
    contour: np.ndarray,
    distance_tol: float = _DEFAULT_DISTANCE_TOL,
) -> tuple[float, float]:
    """Find the longest collinear run of contour points.

    Long contours are first resampled to at most
    :data:`_MAX_CONTOUR_POINTS` points. Closed contours are then scanned
    at three rotations of the point order so that a straight run
    crossing the closure point is not split; open contours are scanned
    once.

    Parameters
    ----------
    contour
        ``(N, 2)`` array of ``(row, col)`` contour coordinates.
    distance_tol
        Maximum perpendicular distance (pixels) from the start→end
        line for a point to be considered part of the straight segment.

    Returns
    -------
    run_length
        Euclidean length of the longest straight segment (pixels).
        ``0.0`` for contours with fewer than three points.
    run_angle
        Angle (radians) of that segment as returned by the scan.
    """
    n = len(contour)
    if n < 3:
        return 0.0, 0.0

    pts = np.asarray(contour, dtype=np.float64)
    pts = _resample_contour(pts, _MAX_CONTOUR_POINTS)
    n = len(pts)

    # A contour counts as closed when its first and last points are
    # less than one pixel apart.
    closed = np.sqrt(((pts[0] - pts[-1]) ** 2).sum()) < 1.0

    if closed and n > 6:
        # Drop the duplicated closing point and measure the full ring:
        # ring_lens[i] = |core[i + 1] - core[i]| with wrap-around, so the
        # last entry is the closing segment core[-1] -> core[0].
        core = pts[:-1]
        ring_steps = np.roll(core, -1, axis=0) - core
        ring_lens = np.sqrt((ring_steps**2).sum(axis=1))
        rotations = [0, len(core) // 3, 2 * len(core) // 3]
    else:
        # Open chain: only the len(core) - 1 consecutive segments exist.
        core = pts
        core_diffs = np.diff(core, axis=0)
        ring_lens = np.sqrt((core_diffs**2).sum(axis=1))
        rotations = [0]

    n_segments = len(core) - 1
    best_len = 0.0
    best_angle = 0.0

    for shift in rotations:
        if shift == 0:
            rotated = core
            sl = ring_lens[:n_segments]
        else:
            rotated = np.roll(core, -shift, axis=0)
            # Rolling the ring of len(core) lengths keeps every entry
            # aligned with its rotated start point and brings the
            # closing segment into the chain. Rolling only the open
            # chain's lengths would drop it and shift every length after
            # the wrap by one, corrupting the arc lengths that the scan
            # uses for early termination.
            sl = np.roll(ring_lens, -shift)[:n_segments]

        cum_arc = np.empty(len(rotated), dtype=np.float64)
        cum_arc[0] = 0.0
        cum_arc[1:] = np.cumsum(sl)

        length, angle = _collinear_scan(rotated, cum_arc, cum_arc[-1], distance_tol)
        if length > best_len:
            best_len = length
            best_angle = angle

    return best_len, best_angle


def _cardinal_alignment(angle: float) -> float:
    """Score how close an angle is to a cardinal direction (0° or 90°).

    The angle is folded into ``[0, π)`` and its distance to the nearest
    of ``0``, ``π/2`` and ``π`` is mapped linearly from ``[0, π/4]`` to
    ``[1, 0]``.

    Returns a value in ``[0, 1]`` where 1 means perfectly axis-aligned
    and 0 means maximally diagonal (45°).
    """
    a = abs(angle) % np.pi
    dist = min(a, abs(a - np.pi / 2), abs(a - np.pi))

    # Map [0, π/4] → [1, 0]
    return float(1.0 - dist / (np.pi / 4))
# ---------------------------------------------------------------------------
# Per-tile scoring
# ---------------------------------------------------------------------------


def _score_tile(
    tile_labels: np.ndarray,
    distance_tol: float = _DEFAULT_DISTANCE_TOL,
    min_area: int = _MIN_CELL_AREA,
    downsample: int = 1,
) -> pd.DataFrame:
    """Compute tiling QC metrics for all cells in a numpy label tile.

    Parameters
    ----------
    tile_labels
        ``(H, W)`` label array (background = 0, owned cells only).
    distance_tol
        Perpendicular distance tolerance for collinearity (pixels).
    min_area
        Cells smaller than this (in pixels at analysis resolution)
        are skipped and get NaN values.
    downsample
        Stride used to subsample each cell's bounding-box mask before
        contour extraction. ``1`` = full resolution, ``2`` = every
        second pixel, etc.

    Returns
    -------
    DataFrame with columns ``max_straight_edge_ratio``,
    ``cardinal_alignment_score``, ``cut_score``, indexed by cell label.
    Cells that are too small or yield no contour get NaN in every
    column. An empty frame with those columns is returned when the
    tile contains no labels.
    """
    regions = regionprops(tile_labels)
    if not regions:
        return pd.DataFrame(columns=_TILE_SCORE_COLUMNS, dtype=float)

    rows: dict[int, dict[str, float]] = {}

    for region in regions:
        lid = region.label
        area = region.area

        # ``min_area`` is expressed at analysis (downsampled) resolution.
        if area < min_area * (downsample**2):
            rows[lid] = dict(_NAN_TILE_SCORES)
            continue

        min_row, min_col, max_row, max_col = region.bbox
        mask = tile_labels[min_row:max_row, min_col:max_col] == lid

        # Downsample first, pad afterwards. Striding an already padded
        # crop can skip its trailing zero row/column, leaving the mask
        # in contact with the array edge; find_contours leaves such
        # contours open, so the edge being measured would be lost.
        if downsample > 1:
            mask = mask[::downsample, ::downsample]

        # 1 px zero border so contours of cells that fill their
        # bounding box are traced as closed curves.
        crop = np.pad(mask.astype(np.float32), 1, mode="constant", constant_values=0)

        contours = find_contours(crop, 0.5)
        if not contours:
            rows[lid] = dict(_NAN_TILE_SCORES)
            continue

        # Use the contour with the most points (the outer boundary in
        # the usual case).
        contour = max(contours, key=len)
        analysis_area = area / (downsample**2) if downsample > 1 else area
        ser, cas, cs = _straight_edge_metrics(contour, analysis_area, distance_tol)

        rows[lid] = {
            "max_straight_edge_ratio": ser,
            "cardinal_alignment_score": cas,
            "cut_score": cs,
        }

    return pd.DataFrame.from_dict(rows, orient="index")
def calculate_tiling_qc(
    sdata: sd.SpatialData,
    labels_key: str,
    scale: str | None = None,
    tile_size: int = 2048,
    overlap_margin: int | Literal["auto"] = "auto",
    distance_tol: float = _DEFAULT_DISTANCE_TOL,
    min_area: int = _MIN_CELL_AREA,
    downsample: int = 1,
    outlier_use_cut: bool = True,
    outlier_use_smoothed: bool = True,
    nmads_cut: float = 1.5,
    nmads_smoothed: float = 3,
    n_jobs: int = -1,
    client: Client | None = None,
    adata_key_added: str | None = None,
    inplace: bool = True,
) -> ad.AnnData | None:
    """Score cells for tile-boundary segmentation artifacts.

    Computes per-cell metrics that detect artificially straight edges
    caused by tiled segmentation.  Large images are processed via the
    same tiling infrastructure as
    :func:`~squidpy.experimental.im.calculate_image_features`.

    Results are stored in a QC table (default
    ``sdata.tables["{labels_key}_qc"]``).  Scores live in ``.obs``;
    the ``.X`` matrix is empty.  Algorithm parameters are recorded in
    ``.uns["tiling_qc"]``.

    Parameters
    ----------
    sdata
        SpatialData object.
    labels_key
        Key in ``sdata.labels`` with segmentation masks.
    scale
        Scale level for multi-scale labels.
    tile_size
        Side length of the tiling grid (pixels).
    overlap_margin
        Overlap around each tile.  ``"auto"`` computes the minimum from
        the largest cell's bounding box.
    distance_tol
        Maximum perpendicular distance (pixels) from the fitted line
        for a contour point to be considered part of a straight
        segment.  Default 0.75 px.
    min_area
        Cells smaller than this (pixels) are skipped (NaN scores).
    downsample
        Factor by which to downsample each cell's bounding-box crop
        before contour extraction.  Straightness is scale-invariant,
        so ``2``--``4`` is safe and much faster on large cells.
    outlier_use_cut
        Gate ``is_outlier`` on the per-cell ``cut_score`` exceeding
        its own MAD threshold.  Requires the cell itself to have a
        straight cardinal-aligned edge.
    outlier_use_smoothed
        Gate ``is_outlier`` on the spatially smoothed score
        (``smoothed_cut_score``) exceeding its MAD threshold.
        Requires the cell to be in a spatial cluster of high-scorers.
    nmads_cut
        Number of MADs for the ``cut_score`` outlier gate.
        Threshold is ``median + nmads_cut x MAD x 1.4826``.
    nmads_smoothed
        Number of MADs for the ``smoothed_cut_score`` outlier gate.
        Threshold is ``median + nmads_smoothed x MAD x 1.4826``.
    n_jobs
        Number of threads for tile processing.  ``-1`` (default) uses
        all available CPUs.  Ignored when ``client`` is provided.
    client
        A :class:`dask.distributed.Client` for distributed execution.
        When provided, tile processing is submitted to this client,
        ``n_jobs`` is ignored, and progress is reported via the dask
        dashboard.  Workers must have access to the underlying data
        store (e.g. shared filesystem or cloud storage for zarr).
    adata_key_added
        Key under which to store the result in ``sdata.tables``.
        Defaults to ``"{labels_key}_qc"``.
    inplace
        If ``True``, store result in ``sdata.tables``.  Otherwise
        return the AnnData directly.

    Returns
    -------
    :class:`~anndata.AnnData` when ``inplace=False``, otherwise ``None``.
    The AnnData ``.obs`` contains six scores per cell:

    - ``max_straight_edge_ratio``: longest collinear boundary segment /
      equivalent diameter.
    - ``cardinal_alignment_score``: axis-alignment of that segment
      (1 = cardinal, 0 = diagonal).
    - ``cut_score``: product of the two.
    - ``smoothed_cut_score``: ``cut_score x mean(neighbor cut_scores)``
      over k=10 nearest spatial neighbors.  Amplifies cells on FOV
      boundaries while suppressing isolated high-scorers.
    - ``is_outlier``: boolean, ``True`` when the enabled outlier
      gates are satisfied (``cut_score`` and/or ``smoothed_cut_score``
      exceeding their respective MAD thresholds).
    - ``nhood_outlier_fraction``: fraction of k=10 nearest neighbors
      that are smoothed-score outliers (MAD-based).  Bounded [0, 1];
      high values trace the tile grid.

    Raises
    ------
    ValueError
        If ``labels_key`` is missing, ``scale`` is omitted for
        multi-scale labels, both outlier gates are disabled, an enabled
        gate has a non-positive ``nmads_*``, the labels contain no
        cells, or no cell could be scored.
    RuntimeError
        If the same label ID is returned by more than one tile.

    Notes
    -----
    Tile processing is parallelised via :func:`dask.compute`.  By
    default a threaded scheduler with ``n_jobs`` workers is used.
    Pass a :class:`~dask.distributed.Client` to use a distributed
    cluster instead.
    """
    if labels_key not in sdata.labels:
        raise ValueError(f"Labels key '{labels_key}' not found, valid keys: {list(sdata.labels.keys())}")

    labels_node = sdata.labels[labels_key]
    if isinstance(labels_node, xr.DataTree):
        if scale is None:
            raise ValueError("When using multi-scale labels, please specify the scale.")
        labels_da = labels_node[scale].ds["image"]
    else:
        labels_da = labels_node

    # --- Validation ---
    # Checked before any centroid / tile computation so that a bad argument
    # fails immediately instead of after the whole image has been processed.
    if not outlier_use_cut and not outlier_use_smoothed:
        raise ValueError("At least one outlier gate must be enabled (outlier_use_cut or outlier_use_smoothed).")
    if outlier_use_cut and nmads_cut <= 0:
        raise ValueError(f"nmads_cut must be positive, got {nmads_cut}.")
    if outlier_use_smoothed and nmads_smoothed <= 0:
        raise ValueError(f"nmads_smoothed must be positive, got {nmads_smoothed}.")

    cell_info = _compute_centroids_for_labels(sdata, labels_key, labels_da, scale)
    if not cell_info:
        raise ValueError("No cells found in labels (all zeros).")

    H = int(labels_da.sizes.get("y", labels_da.shape[-2]))
    W = int(labels_da.sizes.get("x", labels_da.shape[-1]))

    specs = build_tile_specs((H, W), cell_info, tile_size=tile_size, overlap_margin=overlap_margin)
    logg.info(
        f"Tiling QC: {len(specs)} tiles ({tile_size}x{tile_size}, margin={overlap_margin}, downsample={downsample}x)."
    )

    @dask.delayed
    def _process_one(spec):
        # Each task extracts its own tile lazily and scores it.
        tile_lbl = extract_labels_tile_lazy(labels_da, spec)
        return _score_tile(tile_lbl, distance_tol=distance_tol, min_area=min_area, downsample=downsample)

    tasks = [_process_one(spec) for spec in specs]

    if client is not None:
        if n_jobs != -1:
            logg.warning("`n_jobs` is ignored when a `client` is provided. Parallelism is controlled by the client.")
        results = dask.compute(*tasks, scheduler=client)
    else:
        num_workers = cpu_count() if n_jobs == -1 else n_jobs
        with ProgressBar():
            results = dask.compute(*tasks, scheduler="threads", num_workers=num_workers)

    tile_dfs = [df for df in results if not df.empty]

    if not tile_dfs:
        raise ValueError("No cells scored - labels may be empty or all below min_area.")

    combined = pd.concat(tile_dfs, axis=0).sort_index()

    # A label appearing in two tile results means it was scored twice.
    if combined.index.duplicated().any():
        dups = combined.index[combined.index.duplicated()].unique().tolist()
        raise RuntimeError(f"Duplicate cell IDs across tiles - tile ownership may be broken. Duplicates: {dups}")

    # --- Spatial context post-processing ---
    n_cells = len(combined)
    k = 10

    centroid_y = np.array([cell_info[lid].centroid_y for lid in combined.index])
    centroid_x = np.array([cell_info[lid].centroid_x for lid in combined.index])
    centroids = np.column_stack([centroid_y, centroid_x])

    if n_cells <= 1:
        # No neighbours to smooth over: pass the raw score through.
        combined["smoothed_cut_score"] = combined["cut_score"]
        combined["is_outlier"] = False
        combined["nhood_outlier_fraction"] = 0.0
    else:
        effective_k = min(k, n_cells - 1)
        tree = BallTree(centroids)
        _, indices = tree.query(centroids, k=effective_k + 1)  # +1 because query includes self
        neighbor_idx = indices[:, 1:]

        # NaN scores are treated as 0 for smoothing and thresholding.
        cut_scores = combined["cut_score"].values.copy()
        cut_scores = np.where(np.isnan(cut_scores), 0.0, cut_scores)
        neighbor_mean = cut_scores[neighbor_idx].mean(axis=1)
        smoothed = cut_scores * neighbor_mean
        combined["smoothed_cut_score"] = smoothed

        # Build is_outlier from enabled gates (AND when both active)
        is_outlier = np.ones(n_cells, dtype=bool)

        if outlier_use_cut:
            median_c = np.median(cut_scores)
            mad_c = np.median(np.abs(cut_scores - median_c))
            if mad_c < 1e-12:
                # Degenerate spread: no meaningful threshold, flag nothing.
                is_outlier[:] = False
            else:
                is_outlier &= cut_scores >= median_c + nmads_cut * mad_c * 1.4826

        if outlier_use_smoothed:
            median_s = np.median(smoothed)
            mad_s = np.median(np.abs(smoothed - median_s))
            if mad_s < 1e-12:
                is_outlier[:] = False
            else:
                is_outlier &= smoothed >= median_s + nmads_smoothed * mad_s * 1.4826

        combined["is_outlier"] = is_outlier

        neighbor_outlier_frac = combined["is_outlier"].values[neighbor_idx].mean(axis=1)
        combined["nhood_outlier_fraction"] = neighbor_outlier_frac

    adata = ad.AnnData(
        X=np.empty((n_cells, 0), dtype=np.float32),
    )
    adata.obs_names = [f"cell_{i}" for i in combined.index]

    adata.obs["region"] = pd.Categorical([labels_key] * n_cells)
    adata.obs["label_id"] = combined.index.values
    adata.uns["spatialdata_attrs"] = {
        "region": labels_key,
        "region_key": "region",
        "instance_key": "label_id",
    }

    # TODO: migrate tiling QC scores to .obsm once spatialdata-plot
    # supports rendering labels colored by obsm keys.
    # See scverse/spatialdata-plot#587.
    for col in combined.columns:
        adata.obs[col] = combined[col].values

    adata.obs["centroid_y"] = centroid_y
    adata.obs["centroid_x"] = centroid_x

    adata.uns[_METHOD_KEY] = {
        "scale": scale,
        "tile_size": tile_size,
        "overlap_margin": overlap_margin,
        "distance_tol": distance_tol,
        "min_area": min_area,
        "downsample": downsample,
        "outlier_use_cut": outlier_use_cut,
        "outlier_use_smoothed": outlier_use_smoothed,
        "nmads_cut": nmads_cut,
        "nmads_smoothed": nmads_smoothed,
        "nhood_k": k,
    }

    if inplace:
        table_key = adata_key_added if adata_key_added is not None else f"{labels_key}_qc"
        sdata.tables[table_key] = TableModel.parse(adata)
        return None
    return adata
@dataclass
class TileBoundaryGroundTruth:
    """Ground-truth metadata for the tile-boundary fixture.

    Filled in by :func:`make_tile_boundary_sdata` so tests can compare
    QC scores against the known set of cut and intact cells.
    """

    # Fragment label IDs originating from a cell that was split by a cut.
    cut_cell_ids: frozenset[int] = field(default_factory=frozenset)
    # Fragment label IDs originating from a cell that was not split.
    intact_cell_ids: frozenset[int] = field(default_factory=frozenset)
    # Number of distinct non-zero labels before the cuts were applied.
    original_n_cells: int = 0
    # Row / column coordinates of the simulated tile borders.
    tile_borders_y: tuple[int, ...] = _TILE_BORDERS
    tile_borders_x: tuple[int, ...] = _TILE_BORDERS
= _TILE_BORDERS + + +def _place_ellipsoids_grid( + shape: tuple[int, int], + semi_range: tuple[int, int], + grid_step: int, + rng: np.random.Generator, +) -> np.ndarray: + """Place non-overlapping ellipsoids on a jittered grid. + + Cell centers are placed on a regular grid with spacing ``grid_step``, + then jittered by a small random offset. Each cell gets random + semi-axes and rotation. The grid guarantees no overlaps as long as + ``grid_step >= 2 * semi_range[1] + margin``, so no collision + checking is needed. + + Returns an ``(H, W)`` int32 label array with IDs 1..N. + """ + H, W = shape + labels = np.zeros(shape, dtype=np.int32) + margin = semi_range[1] + 1 + max_jitter = (grid_step - 2 * semi_range[1]) // 2 + + # Build grid centers + ys = np.arange(margin, H - margin, grid_step) + xs = np.arange(margin, W - margin, grid_step) + + cell_id = 0 + for y0 in ys: + for x0 in xs: + cy = y0 + rng.integers(-max_jitter, max_jitter + 1) + cx = x0 + rng.integers(-max_jitter, max_jitter + 1) + r_radius = rng.integers(semi_range[0], semi_range[1] + 1) + c_radius = rng.integers(semi_range[0], semi_range[1] + 1) + angle = rng.uniform(0, np.pi) + + rr, cc = ellipse(cy, cx, r_radius, c_radius, shape=shape, rotation=angle) + if len(rr) == 0: + continue + + cell_id += 1 + labels[rr, cc] = cell_id + + return labels + + +def _apply_tile_cuts( + labels: np.ndarray, + borders_y: tuple[int, ...], + borders_x: tuple[int, ...], + gap: int, +) -> np.ndarray: + """Zero out pixels along tile borders to simulate segmentation seams. + + For each border coordinate, a stripe of width ``gap`` centred on the + border is erased. 
+ """ + out = labels.copy() + half = gap // 2 + + for by in borders_y: + out[by - half : by - half + gap, :] = 0 + for bx in borders_x: + out[:, bx - half : bx - half + gap] = 0 + + return out + + +def _relabel_and_track( + original: np.ndarray, + cut: np.ndarray, +) -> tuple[np.ndarray, frozenset[int], frozenset[int]]: + """Relabel connected components after cutting and classify fragments. + + Returns + ------- + relabelled + New label array with unique IDs for each fragment. + cut_ids + Fragment IDs that came from an original cell that was split. + intact_ids + Fragment IDs from cells that remained whole. + """ + relabelled, n_fragments = ndimage.label(cut > 0) + + cut_ids: set[int] = set() + intact_ids: set[int] = set() + + orig_to_fragments: dict[int, set[int]] = {} + for frag_id in range(1, n_fragments + 1): + frag_mask = relabelled == frag_id + orig_ids_in_frag = set(np.unique(original[frag_mask])) - {0} + + for oid in orig_ids_in_frag: + orig_to_fragments.setdefault(oid, set()).add(frag_id) + + # Classify: if an original cell maps to >1 fragment, all its fragments are "cut" + for _orig_id, frag_set in orig_to_fragments.items(): + if len(frag_set) > 1: + cut_ids.update(frag_set) + else: + intact_ids.update(frag_set) + + return relabelled, frozenset(cut_ids), frozenset(intact_ids) + + +def make_tile_boundary_sdata() -> tuple[SpatialData, TileBoundaryGroundTruth]: + """Build a 400x400 SpatialData with ellipsoid cells cut by a 3x3 tile grid. + + Returns a tuple of ``(sdata, ground_truth)`` where ``ground_truth`` + contains the sets of cut and intact cell IDs for test assertions. + + The labels are dask-backed to exercise lazy codepaths. 
def _assemble_sdata(labels: np.ndarray, rng: np.random.Generator) -> SpatialData:
    """Wrap ``labels`` (dask-backed) and a random RGB image into a SpatialData."""
    dask_labels = da.from_array(labels, chunks=(200, 200))
    labels_xr = xr.DataArray(dask_labels, dims=["y", "x"])

    image_data = rng.integers(0, 255, (3, _IMAGE_SIZE, _IMAGE_SIZE), dtype=np.uint8)
    image_xr = xr.DataArray(image_data, dims=["c", "y", "x"], coords={"c": ["R", "G", "B"]})

    return SpatialData(
        images={"image": Image2DModel.parse(image_xr)},
        labels={"labels": Labels2DModel.parse(labels_xr)},
    )


def make_tile_boundary_sdata() -> tuple[SpatialData, TileBoundaryGroundTruth]:
    """Build a square SpatialData with ellipsoid cells cut by a 3x3 tile grid.

    The image side length is ``_IMAGE_SIZE`` and cuts are made at
    ``_TILE_BORDERS`` along both axes.

    Returns a tuple of ``(sdata, ground_truth)`` where ``ground_truth``
    contains the sets of cut and intact cell IDs for test assertions.

    The labels are dask-backed to exercise lazy codepaths.
    """
    rng = np.random.default_rng(42)

    original_labels = _place_ellipsoids_grid(
        shape=(_IMAGE_SIZE, _IMAGE_SIZE),
        semi_range=_SEMI_AXIS_RANGE,
        grid_step=_GRID_STEP,
        rng=rng,
    )
    # Background (0) is one of the unique values, hence the -1.
    n_original = len(np.unique(original_labels)) - 1

    cut_labels = _apply_tile_cuts(
        original_labels,
        borders_y=_TILE_BORDERS,
        borders_x=_TILE_BORDERS,
        gap=_BORDER_GAP,
    )

    relabelled, cut_ids, intact_ids = _relabel_and_track(original_labels, cut_labels)

    # The image is drawn from the same generator after the labels, so the
    # random stream is consumed in a fixed order.
    sdata = _assemble_sdata(relabelled, rng)

    ground_truth = TileBoundaryGroundTruth(
        cut_cell_ids=cut_ids,
        intact_cell_ids=intact_ids,
        original_n_cells=n_original,
    )

    return sdata, ground_truth


@pytest.fixture()
def sdata_tile_boundary() -> tuple[SpatialData, TileBoundaryGroundTruth]:
    """Fixture wrapper around :func:`make_tile_boundary_sdata`."""
    return make_tile_boundary_sdata()


def make_clean_sdata() -> SpatialData:
    """Build a SpatialData with natural ellipsoid cells and NO tile cuts.

    This is the negative control: no tiling artifacts exist, so the
    spatial post-processing should flag zero outliers.
    """
    rng = np.random.default_rng(123)
    labels = _place_ellipsoids_grid(
        shape=(_IMAGE_SIZE, _IMAGE_SIZE),
        semi_range=_SEMI_AXIS_RANGE,
        grid_step=_GRID_STEP,
        rng=rng,
    )
    return _assemble_sdata(labels, rng)


@pytest.fixture()
def sdata_clean() -> SpatialData:
    """Fixture wrapper around :func:`make_clean_sdata`."""
    return make_clean_sdata()
class TestCalculateTilingQC:
    """Behavioural checks for sq.experimental.tl.calculate_tiling_qc on synthetic label fixtures."""

    @staticmethod
    def _qc(sdata, *, tile_size=200, **kwargs):
        """Run the QC on the ``"labels"`` element with ``inplace=False`` and return the AnnData."""
        return sq.experimental.tl.calculate_tiling_qc(
            sdata,
            labels_key="labels",
            tile_size=tile_size,
            inplace=False,
            **kwargs,
        )

    def test_returns_anndata_with_scores(self, sdata_tile_boundary):
        """One observation per fragment, an empty X, and every score column present."""
        sdata, gt = sdata_tile_boundary
        result = self._qc(sdata)

        n_expected = len(gt.cut_cell_ids) + len(gt.intact_cell_ids)
        assert result.n_obs == n_expected
        assert result.n_vars == 0

        expected_columns = (
            "max_straight_edge_ratio",
            "cardinal_alignment_score",
            "cut_score",
            "smoothed_cut_score",
            "is_outlier",
            "nhood_outlier_fraction",
        )
        for name in expected_columns:
            assert name in result.obs.columns

    def test_cut_cells_score_higher_than_intact(self, sdata_tile_boundary):
        """Cut fragments have a larger mean straight-edge ratio than intact cells."""
        sdata, gt = sdata_tile_boundary
        table = self._qc(sdata).obs

        is_cut = table["label_id"].isin(gt.cut_cell_ids)
        is_intact = table["label_id"].isin(gt.intact_cell_ids)
        cut_ratio = table[is_cut]["max_straight_edge_ratio"].dropna()
        intact_ratio = table[is_intact]["max_straight_edge_ratio"].dropna()

        assert cut_ratio.mean() > intact_ratio.mean()

    def test_tiled_vs_single_tile(self, sdata_tile_boundary):
        """Tiling must not change results — scores should be identical."""
        sdata, _ = sdata_tile_boundary
        tiled = self._qc(sdata, tile_size=200).obs.set_index("label_id").sort_index()
        single = self._qc(sdata, tile_size=2000).obs.set_index("label_id").sort_index()

        assert set(tiled.index) == set(single.index)

        float_columns = (
            "max_straight_edge_ratio",
            "cardinal_alignment_score",
            "cut_score",
            "smoothed_cut_score",
            "nhood_outlier_fraction",
        )
        for name in float_columns:
            np.testing.assert_allclose(
                tiled[name].values,
                single[name].values,
                atol=1e-10,
                equal_nan=True,
            )
        np.testing.assert_array_equal(tiled["is_outlier"].values, single["is_outlier"].values)

    def test_spatial_postprocessing_columns(self, sdata_tile_boundary):
        """Spatial post-processing produces correct dtypes and value ranges."""
        sdata, _ = sdata_tile_boundary
        result = self._qc(sdata)
        table = result.obs

        # smoothed_cut_score is non-negative (product of non-negatives)
        assert (table["smoothed_cut_score"] >= 0).all()

        # is_outlier is boolean
        assert table["is_outlier"].dtype == bool

        # nhood_outlier_fraction is bounded [0, 1]
        fraction = table["nhood_outlier_fraction"]
        assert (fraction >= 0).all()
        assert (fraction <= 1).all()

        # nhood_k stored in uns
        assert result.uns["tiling_qc"]["nhood_k"] == 10

    def test_outlier_fraction_consistent_with_is_outlier(self, sdata_tile_boundary):
        """Some cells have a zero neighbourhood fraction; with no outliers at all, every fraction is zero."""
        sdata, _ = sdata_tile_boundary
        table = self._qc(sdata).obs
        zero_fraction = table["nhood_outlier_fraction"] == 0

        # Cells with nhood_outlier_fraction == 0 should exist (most cells are not outliers)
        assert zero_fraction.any()
        # If no cell is an outlier, all fractions must be 0
        if not table["is_outlier"].any():
            assert (table["nhood_outlier_fraction"] == 0).all()

    def test_invalid_labels_key(self, sdata_tile_boundary):
        """An unknown labels key is rejected with a ValueError."""
        sdata, _ = sdata_tile_boundary
        with pytest.raises(ValueError, match="not found"):
            sq.experimental.tl.calculate_tiling_qc(sdata, labels_key="nonexistent", inplace=False)

    def test_clean_dataset_no_outliers(self, sdata_clean):
        """No tiling artifacts → MAD-based outlier detection should flag zero cells."""
        table = self._qc(sdata_clean).obs
        assert not table["is_outlier"].any(), f"Expected 0 outliers on clean data, got {table['is_outlier'].sum()}"
        assert (table["nhood_outlier_fraction"] == 0).all()

    def test_few_cells_below_k(self):
        """Fewer cells than k=10 should not crash."""
        import dask.array as da
        import xarray as xr
        from spatialdata import SpatialData
        from spatialdata.models import Image2DModel, Labels2DModel

        # 3 well-separated circles
        canvas = np.zeros((100, 100), dtype=np.int32)
        centres = [(20, 20), (50, 50), (80, 80)]
        for label, (cy, cx) in enumerate(centres, start=1):
            yy, xx = np.ogrid[-cy : 100 - cy, -cx : 100 - cx]
            canvas[yy**2 + xx**2 <= 64] = label

        blank_image = xr.DataArray(np.zeros((3, 100, 100), dtype=np.uint8), dims=["c", "y", "x"])
        lazy_labels = xr.DataArray(da.from_array(canvas, chunks=(100, 100)), dims=["y", "x"])
        sdata = SpatialData(
            images={"image": Image2DModel.parse(blank_image)},
            labels={"labels": Labels2DModel.parse(lazy_labels)},
        )

        result = self._qc(sdata)
        assert result.n_obs == 3
        for name in ("smoothed_cut_score", "is_outlier", "nhood_outlier_fraction"):
            assert name in result.obs.columns

    def test_both_gates_disabled_raises(self, sdata_tile_boundary):
        """Disabling both outlier gates is rejected."""
        sdata, _ = sdata_tile_boundary
        with pytest.raises(ValueError, match="At least one outlier gate"):
            self._qc(sdata, outlier_use_cut=False, outlier_use_smoothed=False)

    def test_invalid_nmads_raises(self, sdata_tile_boundary):
        """Non-positive MAD multipliers are rejected for each gate."""
        sdata, _ = sdata_tile_boundary
        with pytest.raises(ValueError, match="nmads_cut must be positive"):
            self._qc(sdata, nmads_cut=0)
        with pytest.raises(ValueError, match="nmads_smoothed must be positive"):
            self._qc(sdata, nmads_smoothed=-1)

    def test_cut_only_gate(self, sdata_tile_boundary):
        """Using only cut_score gate should still produce valid output."""
        sdata, _ = sdata_tile_boundary
        result = self._qc(sdata, outlier_use_cut=True, outlier_use_smoothed=False)

        assert result.obs["is_outlier"].dtype == bool
        params = result.uns["tiling_qc"]
        assert params["outlier_use_cut"] is True
        assert params["outlier_use_smoothed"] is False

    def test_smoothed_only_gate(self, sdata_tile_boundary):
        """Using only smoothed gate should still produce valid output."""
        sdata, _ = sdata_tile_boundary
        result = self._qc(sdata, outlier_use_cut=False, outlier_use_smoothed=True)
        assert result.obs["is_outlier"].dtype == bool
@pytest.fixture()
def sdata_with_qc(sdata_tile_boundary):
    """SpatialData with tiling QC already computed.

    Runs the QC in place with ``tile_size=200`` on the tile-boundary
    fixture and returns only the SpatialData (ground truth is dropped).
    """
    sdata, _ = sdata_tile_boundary
    sq.experimental.tl.calculate_tiling_qc(sdata, labels_key="labels", tile_size=200, inplace=True)
    return sdata


class TestTilingQCVisual(PlotTester, metaclass=PlotTesterMeta):
    """Visual regression tests for ``sq.experimental.pl.tiling_qc``.

    Each test only draws the plot for one score column.
    NOTE(review): the comparison against the reference images under
    ``tests/_images`` is presumably performed by ``PlotTester`` /
    ``PlotTesterMeta`` based on the test name - confirm before renaming
    any method.
    """

    def test_plot_tiling_qc_cut_score(self, sdata_with_qc):
        """Visual: labels coloured by cut_score."""
        sq.experimental.pl.tiling_qc(sdata_with_qc, labels_key="labels", score_col="cut_score")

    def test_plot_tiling_qc_cardinal_alignment(self, sdata_with_qc):
        """Visual: labels coloured by cardinal_alignment_score."""
        sq.experimental.pl.tiling_qc(
            sdata_with_qc,
            labels_key="labels",
            score_col="cardinal_alignment_score",
        )

    def test_plot_tiling_qc_straight_edge_ratio(self, sdata_with_qc):
        """Visual: labels coloured by max_straight_edge_ratio."""
        sq.experimental.pl.tiling_qc(
            sdata_with_qc,
            labels_key="labels",
            score_col="max_straight_edge_ratio",
        )

    def test_plot_tiling_qc_nhood_outlier_fraction(self, sdata_with_qc):
        """Visual: default plot (nhood_outlier_fraction, RdYlGn_r, colorbar)."""
        # No score_col is passed: this exercises the plot's default column.
        sq.experimental.pl.tiling_qc(sdata_with_qc, labels_key="labels")

    def test_plot_tiling_qc_is_outlier(self, sdata_with_qc):
        """Visual: labels coloured by is_outlier (boolean)."""
        sq.experimental.pl.tiling_qc(sdata_with_qc, labels_key="labels", score_col="is_outlier")

    def test_plot_tiling_qc_smoothed_cut_score(self, sdata_with_qc):
        """Visual: labels coloured by smoothed_cut_score."""
        sq.experimental.pl.tiling_qc(
            sdata_with_qc,
            labels_key="labels",
            score_col="smoothed_cut_score",
        )