From 42d74ab78452e72006324dbfcbf4cbef9680c742 Mon Sep 17 00:00:00 2001 From: Tim Treis Date: Thu, 16 Apr 2026 12:26:43 +0200 Subject: [PATCH 1/6] Add alignment skeleton under sq.experimental.tl Introduces a backend-agnostic alignment API (align_obs, align_images, align_by_landmarks) with STalign and landmark backends, lazy JAX imports, and e2e tests. Co-Authored-By: Claude Opus 4.6 (1M context) --- pyproject.toml | 3 + src/squidpy/experimental/__init__.py | 4 +- src/squidpy/experimental/tl/__init__.py | 56 ++ .../experimental/tl/_align/__init__.py | 21 + src/squidpy/experimental/tl/_align/_api.py | 251 +++++++++ .../tl/_align/_backends/__init__.py | 26 + .../experimental/tl/_align/_backends/_base.py | 38 ++ .../tl/_align/_backends/_landmark.py | 106 ++++ .../tl/_align/_backends/_moscot.py | 45 ++ .../tl/_align/_backends/_stalign.py | 102 ++++ .../tl/_align/_backends/_stalign_core.py | 367 +++++++++++++ .../tl/_align/_backends/_stalign_helpers.py | 187 +++++++ .../tl/_align/_backends/_stalign_tools.py | 261 ++++++++++ src/squidpy/experimental/tl/_align/_io.py | 309 +++++++++++ src/squidpy/experimental/tl/_align/_jax.py | 56 ++ src/squidpy/experimental/tl/_align/_types.py | 120 +++++ .../experimental/tl/_align/_validation.py | 143 +++++ tests/experimental/tl/__init__.py | 0 tests/experimental/tl/test_align_blobs_e2e.py | 216 ++++++++ tests/experimental/tl/test_align_skeleton.py | 489 ++++++++++++++++++ .../tl/test_align_stalign_integration.py | 132 +++++ 21 files changed, 2930 insertions(+), 2 deletions(-) create mode 100644 src/squidpy/experimental/tl/__init__.py create mode 100644 src/squidpy/experimental/tl/_align/__init__.py create mode 100644 src/squidpy/experimental/tl/_align/_api.py create mode 100644 src/squidpy/experimental/tl/_align/_backends/__init__.py create mode 100644 src/squidpy/experimental/tl/_align/_backends/_base.py create mode 100644 src/squidpy/experimental/tl/_align/_backends/_landmark.py create mode 100644 src/squidpy/experimental/tl/_align/_backends/_moscot.py create mode 100644 src/squidpy/experimental/tl/_align/_backends/_stalign.py create mode 100644 src/squidpy/experimental/tl/_align/_backends/_stalign_core.py create mode 100644 src/squidpy/experimental/tl/_align/_backends/_stalign_helpers.py create mode 100644 src/squidpy/experimental/tl/_align/_backends/_stalign_tools.py create mode 100644 src/squidpy/experimental/tl/_align/_io.py create mode 100644 src/squidpy/experimental/tl/_align/_jax.py create mode 100644 src/squidpy/experimental/tl/_align/_types.py create mode 100644 src/squidpy/experimental/tl/_align/_validation.py create mode 100644 tests/experimental/tl/__init__.py create mode 100644 tests/experimental/tl/test_align_blobs_e2e.py create mode 100644 tests/experimental/tl/test_align_skeleton.py create mode 100644 tests/experimental/tl/test_align_stalign_integration.py diff --git a/pyproject.toml b/pyproject.toml index 06e9dfc5a..8b936d749 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -102,6 +102,9 @@ optional-dependencies.docs = [ "sphinxcontrib-bibtex>=2.3", "sphinxcontrib-spelling>=7.6.2", ] +optional-dependencies.jax = [ + "jax", +] optional-dependencies.leiden = [ "leidenalg", "spatialleiden>=0.4", diff --git a/src/squidpy/experimental/__init__.py b/src/squidpy/experimental/__init__.py index 435cd0098..5f4c695ab 100644 --- a/src/squidpy/experimental/__init__.py +++ b/src/squidpy/experimental/__init__.py @@ -6,6 +6,6 @@ from __future__ import annotations -from . import im, pl +from . import im, pl, tl -__all__ = ["im", "pl"] +__all__ = ["im", "pl", "tl"] diff --git a/src/squidpy/experimental/tl/__init__.py b/src/squidpy/experimental/tl/__init__.py new file mode 100644 index 000000000..2f6fbd0b5 --- /dev/null +++ b/src/squidpy/experimental/tl/__init__.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from squidpy.experimental.tl._align import align_by_landmarks, align_images, align_obs + +if TYPE_CHECKING: + from squidpy.experimental.tl._align._backends._stalign_tools import ( + STalignConfig, + STalignPreprocessConfig, + STalignPreprocessResult, + STalignRegistrationConfig, + STalignResult, + ) + +__all__ = [ + "STalignConfig", + "STalignPreprocessConfig", + "STalignPreprocessResult", + "STalignRegistrationConfig", + "STalignResult", + "align_by_landmarks", + "align_images", + "align_obs", +] + +_STALIGN_REEXPORTS = frozenset( + { + "STalignConfig", + "STalignPreprocessConfig", + "STalignPreprocessResult", + "STalignRegistrationConfig", + "STalignResult", + } +) + + +def __getattr__(name: str) -> Any: + """Lazy access to the JAX-only STalign config dataclasses. + + Importing :mod:`squidpy.experimental.tl._align._backends._stalign_tools` pulls in + :mod:`jax` at module-load time, so we defer the import until the first + attribute access. This preserves the lazy-import contract pinned by + ``test_optional_deps_not_imported_at_import_time``. + """ + if name in _STALIGN_REEXPORTS: + try: + from squidpy.experimental.tl._align._backends import _stalign_tools as _tools + except ModuleNotFoundError as e: + if e.name == "jax": + raise ImportError( + 'STalign requires the optional dependency `jax`. Install it with `pip install "squidpy[jax]"`.' + ) from e + raise + return getattr(_tools, name) + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/squidpy/experimental/tl/_align/__init__.py b/src/squidpy/experimental/tl/_align/__init__.py new file mode 100644 index 000000000..b070b99c5 --- /dev/null +++ b/src/squidpy/experimental/tl/_align/__init__.py @@ -0,0 +1,21 @@ +"""Alignment skeleton under :mod:`squidpy.experimental.tl`. + +Public surface: + +- :func:`align_obs` — align two ``obs``-level point clouds (cells / spots). +- :func:`align_images` — align two raster images in :class:`spatialdata.SpatialData`. +- :func:`align_by_landmarks` — closed-form fit from user-provided landmarks. + +Optional backends (``stalign``, ``moscot``) and JAX are imported lazily — only +the function call that needs them pulls them in. +""" + +from __future__ import annotations + +from squidpy.experimental.tl._align._api import ( + align_by_landmarks, + align_images, + align_obs, +) + +__all__ = ["align_by_landmarks", "align_images", "align_obs"] diff --git a/src/squidpy/experimental/tl/_align/_api.py b/src/squidpy/experimental/tl/_align/_api.py new file mode 100644 index 000000000..4b08b7c69 --- /dev/null +++ b/src/squidpy/experimental/tl/_align/_api.py @@ -0,0 +1,251 @@ +"""Public ``align_*`` orchestrators. + +Each function is intentionally thin: resolve inputs, validate, dispatch to a +backend, write the result back. All branching on argument shape lives in +:mod:`._io`; all backend selection lives in :mod:`._backends`; all validation +of "passed-but-unneeded" combinations lives in :mod:`._validation`. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Literal + +from squidpy.experimental.tl._align._backends import get_backend +from squidpy.experimental.tl._align._io import ( + apply_affine_to_cs, + materialise_obs, + resolve_element_pair, + resolve_image_pair, + resolve_obs_pair, +) +from squidpy.experimental.tl._align._types import AffineTransform, AlignPair, AlignResult +from squidpy.experimental.tl._align._validation import ( + ALLOWED_FLAVOURS_IMAGES, + ALLOWED_FLAVOURS_OBS, + ALLOWED_OUTPUT_MODES_NONOBS, + ALLOWED_OUTPUT_MODES_OBS, + validate_flavour, + validate_key_added, + validate_landmark_model, + validate_landmarks, + validate_output_mode, + validate_required, +) + +if TYPE_CHECKING: + from anndata import AnnData + from spatialdata import SpatialData + + +def align_obs( + data_ref: AnnData | SpatialData, + data_query: AnnData | SpatialData | None = None, + adata_ref_name: str | None = None, + adata_query_name: str | None = None, + flavour: Literal["stalign", "moscot"] = "stalign", + *, + output_mode: Literal["affine", "obs", "return"] = "affine", + key_added: str | None = None, + device: Literal["cpu", "gpu"] | None = None, + inplace: bool = True, + **flavour_kwargs: Any, +) -> AnnData | SpatialData | AlignResult | None: + """Align two ``obs``-level point clouds (cells / spots). + + Parameters + ---------- + data_ref, data_query + Either both :class:`anndata.AnnData`, both :class:`spatialdata.SpatialData`, + or ``data_ref`` a SpatialData and ``data_query=None`` (in which case + ``adata_ref_name`` and ``adata_query_name`` select two different + tables of the same SpatialData object). + adata_ref_name, adata_query_name + Required only when SpatialData inputs are used. Passing them with + AnnData inputs raises an educational :class:`ValueError`. + flavour + Backend to use. ``'stalign'`` is the default LDDMM-based fit; + ``'moscot'`` is OT-based. + output_mode + How to deliver the result: + + - ``'affine'`` — register the fitted affine on the query element via + :func:`spatialdata.transformations.set_transformation`, so every + element in the query coordinate system inherits the alignment. + Requires the backend to produce an affine transform. + - ``'obs'`` — bake the (possibly non-affine) fit into a new AnnData + whose ``obsm['spatial']`` already lives in the reference coordinate + system; for SpatialData inputs the new table is stored under + ``key_added``. + - ``'return'`` — return the raw :class:`AlignResult`; no writeback. + key_added + Required when ``output_mode='obs'`` and inputs are SpatialData. + Rejected with any other ``output_mode``. + device + ``'cpu'``/``'gpu'`` to force a JAX device, or ``None`` to let JAX + pick the default. Only consulted by JAX-backed flavours. + inplace + If ``True``, mutate the query container; otherwise return a copy. + **flavour_kwargs + Backend-specific knobs forwarded as-is to the chosen backend. + """ + validate_flavour(flavour, allowed=ALLOWED_FLAVOURS_OBS, op="align_obs") + validate_output_mode(output_mode, allowed=ALLOWED_OUTPUT_MODES_OBS, op="align_obs") + validate_key_added(key_added, output_mode) + + pair = resolve_obs_pair(data_ref, data_query, adata_ref_name, adata_query_name) + backend = get_backend(flavour) + result = backend.align_obs(pair, device=device, **flavour_kwargs) + + return _writeback(pair, result, output_mode=output_mode, key_added=key_added, inplace=inplace) + + +def align_images( + sdata_ref: SpatialData, + sdata_query: SpatialData | None = None, + img_ref_name: str | None = None, + img_query_name: str | None = None, + flavour: Literal["stalign"] = "stalign", + *, + scale_ref: str | Literal["auto"] = "auto", + scale_query: str | Literal["auto"] = "auto", + output_mode: Literal["affine", "return"] = "affine", + device: Literal["cpu", "gpu"] | None = None, + inplace: bool = True, + **flavour_kwargs: Any, +) -> SpatialData | AlignResult | None: + """Align two raster images living inside :class:`spatialdata.SpatialData`. + + Parameters + ---------- + sdata_ref, sdata_query + SpatialData containers. Pass ``sdata_query=None`` to align two + images of the same SpatialData against each other. + img_ref_name, img_query_name + Image element keys. + flavour + Only ``'stalign'`` is currently supported. + scale_ref, scale_query + Scale level for multi-scale image elements. ``'auto'`` picks the + coarsest level. Single-scale images ignore this parameter. + output_mode + ``'affine'`` registers the fit on the query image element so all of + its scales inherit the transformation; ``'return'`` returns the raw + :class:`AlignResult`. + device, inplace, flavour_kwargs + See :func:`align_obs`. + """ + validate_required(name="img_ref_name", value=img_ref_name, when="calling `align_images`") + validate_required(name="img_query_name", value=img_query_name, when="calling `align_images`") + validate_flavour(flavour, allowed=ALLOWED_FLAVOURS_IMAGES, op="align_images") + validate_output_mode(output_mode, allowed=ALLOWED_OUTPUT_MODES_NONOBS, op="align_images") + + pair = resolve_image_pair( + sdata_ref, + sdata_query, + img_ref_name, + img_query_name, + scale_ref=scale_ref, + scale_query=scale_query, + ) + backend = get_backend(flavour) + result = backend.align_images(pair, device=device, **flavour_kwargs) + + return _writeback(pair, result, output_mode=output_mode, key_added=None, inplace=inplace) + + +def align_by_landmarks( + sdata_ref: SpatialData, + sdata_query: SpatialData | None = None, + cs_name_ref: str | None = None, + cs_name_query: str | None = None, + scale_ref: str | None = None, + scale_query: str | None = None, + landmarks_ref: tuple[tuple[float, float], ...] | None = None, + landmarks_query: tuple[tuple[float, float], ...] | None = None, + *, + model: Literal["similarity", "affine"] = "similarity", + output_mode: Literal["affine", "return"] = "affine", + inplace: bool = True, +) -> SpatialData | AlignResult | None: + """Align by a closed-form fit on user-provided landmarks. + + Pure NumPy under the hood — JAX is **not** required for this path. + + Parameters + ---------- + sdata_ref, sdata_query + SpatialData containers. Pass ``sdata_query=None`` to align two + coordinate systems of the same SpatialData against each other. + cs_name_ref, cs_name_query + Coordinate system names. + scale_ref, scale_query + Optional scale identifiers used purely for landmark-extent + validation: if you extracted your landmarks at a particular scale, + passing the same scale here lets us catch the "wrong scale" footgun + early. + landmarks_ref, landmarks_query + Equal-length sequences of ``(y, x)`` tuples. ``model='similarity'`` + needs ≥ 2 pairs, ``model='affine'`` needs ≥ 3. + model + ``'similarity'`` (rotation + uniform scale + translation) or + ``'affine'`` (full 6-parameter linear). + output_mode, inplace + See :func:`align_obs`. + """ + validate_required(name="cs_name_ref", value=cs_name_ref, when="calling `align_by_landmarks`") + validate_required(name="cs_name_query", value=cs_name_query, when="calling `align_by_landmarks`") + validate_required(name="landmarks_ref", value=landmarks_ref, when="calling `align_by_landmarks`") + validate_required(name="landmarks_query", value=landmarks_query, when="calling `align_by_landmarks`") + + validate_output_mode(output_mode, allowed=ALLOWED_OUTPUT_MODES_NONOBS, op="align_by_landmarks") + validate_landmark_model(model) + + # We don't materialise extents here in the skeleton; backends / a future + # PR can fill in the cs-extent lookup once we wire spatialdata.get_extent. + ref_arr, query_arr = validate_landmarks(landmarks_ref, landmarks_query, model=model) + + pair = resolve_element_pair(sdata_ref, sdata_query, cs_name_ref, cs_name_query) + + from squidpy.experimental.tl._align._backends._landmark import fit_landmark_affine + + affine = fit_landmark_affine( + ref_arr, + query_arr, + model=model, + source_cs=cs_name_query, + target_cs=cs_name_ref, + ) + result = AlignResult(transform=affine, metadata={"flavour": "landmark", "model": model}) + + return _writeback(pair, result, output_mode=output_mode, key_added=None, inplace=inplace) + + +# --------------------------------------------------------------------------- +# Internal: writeback dispatch +# --------------------------------------------------------------------------- + + +def _writeback( + pair: AlignPair, + result: AlignResult, + *, + output_mode: str, + key_added: str | None, + inplace: bool, +) -> AnnData | SpatialData | AlignResult | None: + if output_mode == "return": + return result + + if output_mode == "affine": + if not isinstance(result.transform, AffineTransform): + raise TypeError( + f"`output_mode='affine'` requires the backend to return an AffineTransform, " + f"got {type(result.transform).__name__}. Use `output_mode='obs'` (for " + f"`align_obs`) or `output_mode='return'` to access non-affine fits." + ) + return apply_affine_to_cs(pair, result.transform, inplace=inplace) + + if output_mode == "obs": + return materialise_obs(pair, result, key_added=key_added, inplace=inplace) + + raise ValueError(f"Unknown output_mode {output_mode!r}.") diff --git a/src/squidpy/experimental/tl/_align/_backends/__init__.py b/src/squidpy/experimental/tl/_align/_backends/__init__.py new file mode 100644 index 000000000..5a30e0eb6 --- /dev/null +++ b/src/squidpy/experimental/tl/_align/_backends/__init__.py @@ -0,0 +1,26 @@ +"""Backend dispatch for the alignment skeleton. + +Imports of individual backends happen *inside* the dispatch branches so that +``import squidpy.experimental.tl`` never pulls in ``stalign``, ``moscot``, or +``jax`` transitively. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from squidpy.experimental.tl._align._backends._base import AlignBackend + + +def get_backend(flavour: str) -> AlignBackend: + """Return a backend instance for the requested ``flavour``.""" + if flavour == "stalign": + from squidpy.experimental.tl._align._backends._stalign import StAlignBackend + + return StAlignBackend() + if flavour == "moscot": + from squidpy.experimental.tl._align._backends._moscot import MoscotBackend + + return MoscotBackend() + raise ValueError(f"Unknown alignment flavour {flavour!r}; expected 'stalign' or 'moscot'.") diff --git a/src/squidpy/experimental/tl/_align/_backends/_base.py b/src/squidpy/experimental/tl/_align/_backends/_base.py new file mode 100644 index 000000000..58ba61657 --- /dev/null +++ b/src/squidpy/experimental/tl/_align/_backends/_base.py @@ -0,0 +1,38 @@ +"""Backend Protocol shared by every alignment flavour.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Literal, Protocol, runtime_checkable + +if TYPE_CHECKING: + from squidpy.experimental.tl._align._types import AlignPair, AlignResult + + +@runtime_checkable +class AlignBackend(Protocol): + """Minimal contract every alignment backend must satisfy. + + Backends are constructed cheaply (no heavy imports in ``__init__``) and + only pull in their optional dependencies on the first call into ``align_obs`` + or ``align_images``. ``requires_jax`` advertises whether the backend + needs JAX so callers / dispatch can short-circuit. + """ + + name: str + requires_jax: bool + + def align_obs( + self, + pair: AlignPair, + *, + device: Literal["cpu", "gpu"] | None = None, + **kwargs: Any, + ) -> AlignResult: ... + + def align_images( + self, + pair: AlignPair, + *, + device: Literal["cpu", "gpu"] | None = None, + **kwargs: Any, + ) -> AlignResult: ... diff --git a/src/squidpy/experimental/tl/_align/_backends/_landmark.py b/src/squidpy/experimental/tl/_align/_backends/_landmark.py new file mode 100644 index 000000000..11bf8475a --- /dev/null +++ b/src/squidpy/experimental/tl/_align/_backends/_landmark.py @@ -0,0 +1,106 @@ +"""Closed-form landmark fit. + +Two models, both pure NumPy / no JAX: + +- ``"similarity"`` (4 DOF: rotation + uniform scale + translation, plus an + optional reflection check) - delegated to + :func:`spatialdata.transformations.get_transformation_between_landmarks`. +- ``"affine"`` (6 DOF: rotation + non-uniform scale + shear + translation) - + delegated to :func:`skimage.transform.estimate_transform`, the same + least-squares solver spatialdata uses internally. + +Useful as a one-shot alignment when you already have corresponding landmarks, +and as a sanity-check baseline for the much heavier STalign LDDMM solver. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal + +import numpy as np + +if TYPE_CHECKING: + from squidpy.experimental.tl._align._types import AffineTransform + + +def fit_landmark_affine( + landmarks_ref: np.ndarray, + landmarks_query: np.ndarray, + *, + model: Literal["similarity", "affine"] = "similarity", + source_cs: str | None = None, + target_cs: str | None = None, +) -> AffineTransform: + """Fit a 2D affine that maps ``landmarks_query`` onto ``landmarks_ref``. + + Both inputs are ``(N, 2)`` ``(x, y)`` arrays of corresponding landmarks + (the ``i``-th row of ``landmarks_query`` matches the ``i``-th row of + ``landmarks_ref``). ``N`` must be at least 3. + + Parameters + ---------- + landmarks_ref, landmarks_query + Corresponding landmark coordinates in ``(x, y)`` convention. + model + ``"similarity"`` (4 DOF, via spatialdata) or ``"affine"`` (6 DOF, + via skimage). + source_cs, target_cs + Optional coordinate-system labels stamped onto the returned + :class:`AffineTransform` for traceability. + """ + from squidpy.experimental.tl._align._types import AffineTransform + + ref = np.asarray(landmarks_ref, dtype=float) + query = np.asarray(landmarks_query, dtype=float) + + if model == "similarity": + matrix = _fit_similarity_via_spatialdata(ref, query) + elif model == "affine": + matrix = _fit_affine_via_skimage(ref, query) + else: + raise ValueError(f"Unknown landmark `model={model!r}`; expected 'similarity' or 'affine'.") + + return AffineTransform(matrix=matrix, source_cs=source_cs, target_cs=target_cs) + + +def _fit_similarity_via_spatialdata(ref_xy: np.ndarray, query_xy: np.ndarray) -> np.ndarray: + """4-DOF similarity fit, delegated to spatialdata.""" + from spatialdata.models import PointsModel + from spatialdata.transformations import get_transformation_between_landmarks + + refs_pts = PointsModel.parse(ref_xy) + moving_pts = PointsModel.parse(query_xy) + sd_transform = get_transformation_between_landmarks(refs_pts, moving_pts) + return _extract_affine_matrix(sd_transform) + + +def _fit_affine_via_skimage(ref_xy: np.ndarray, query_xy: np.ndarray) -> np.ndarray: + """Full 6-DOF affine fit, delegated to skimage's least-squares solver. + + This is what :func:`spatialdata.transformations.get_transformation_between_landmarks` + uses under the hood before collapsing to a similarity; for the affine + model we keep the raw matrix instead. + """ + from skimage.transform import estimate_transform + + model_obj = estimate_transform("affine", src=query_xy, dst=ref_xy) + return np.asarray(model_obj.params) + + +def _extract_affine_matrix(sd_transform: object) -> np.ndarray: + """Pull a ``(3, 3)`` homogeneous matrix out of a spatialdata transformation. + + :func:`get_transformation_between_landmarks` may return either a single + :class:`spatialdata.transformations.Affine` or a + :class:`spatialdata.transformations.Sequence` of two affines (when a + reflection is detected and rolled into the fit). Use + ``to_affine_matrix`` to collapse either back to a single 3x3. + """ + from spatialdata.transformations import Affine as SDAffine + from spatialdata.transformations import Sequence as SDSequence + + if isinstance(sd_transform, SDAffine): + return np.asarray(sd_transform.matrix) + if isinstance(sd_transform, SDSequence): + return np.asarray(sd_transform.to_affine_matrix(input_axes=("x", "y"), output_axes=("x", "y"))) + raise TypeError(f"Unexpected transformation type from spatialdata: {type(sd_transform).__name__}.") diff --git a/src/squidpy/experimental/tl/_align/_backends/_moscot.py b/src/squidpy/experimental/tl/_align/_backends/_moscot.py new file mode 100644 index 000000000..b5fe01f64 --- /dev/null +++ b/src/squidpy/experimental/tl/_align/_backends/_moscot.py @@ -0,0 +1,45 @@ +"""Moscot backend stub. + +Moscot only exposes ``align_obs``; image alignment is not a moscot use case. +The dispatch layer rejects ``flavour='moscot'`` for :func:`align_images` +before ever reaching this file, so the ``align_images`` method below is here +purely to satisfy the :class:`AlignBackend` Protocol. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Literal + +if TYPE_CHECKING: + from squidpy.experimental.tl._align._types import AlignPair, AlignResult + + +class MoscotBackend: + name = "moscot" + requires_jax = True + + def align_obs( + self, + pair: AlignPair, + *, + device: Literal["cpu", "gpu"] | None = None, + **kwargs: Any, + ) -> AlignResult: + from squidpy.experimental.tl._align._jax import require_jax + + require_jax(device) + # Lazy moscot import lives here in the next PR: + # import moscot + raise NotImplementedError( + "moscot backend `align_obs`: TODO. Skeleton landed; the moscot " + "solver will replace this body in a follow-up PR." + ) + + def align_images( + self, + pair: AlignPair, + *, + device: Literal["cpu", "gpu"] | None = None, + **kwargs: Any, + ) -> AlignResult: + raise NotImplementedError("moscot does not implement image alignment; use `flavour='stalign'`.") diff --git a/src/squidpy/experimental/tl/_align/_backends/_stalign.py b/src/squidpy/experimental/tl/_align/_backends/_stalign.py new file mode 100644 index 000000000..6d7c25d8c --- /dev/null +++ b/src/squidpy/experimental/tl/_align/_backends/_stalign.py @@ -0,0 +1,102 @@ +"""STalign backend. + +Wraps the JAX LDDMM solver lifted from scverse/squidpy#1150 (Selman Özleyen) +into the :class:`AlignBackend` Protocol. Only ``align_obs`` is implemented +today; ``align_images`` raises until upstream support exists. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Literal + +import numpy as np + +if TYPE_CHECKING: + from squidpy.experimental.tl._align._types import AlignPair, AlignResult + + +class StAlignBackend: + name = "stalign" + requires_jax = True + + def align_obs( + self, + pair: AlignPair, + *, + device: Literal["cpu", "gpu"] | None = None, + config: Any | None = None, + landmarks_source: np.ndarray | None = None, + landmarks_target: np.ndarray | None = None, + **kwargs: Any, + ) -> AlignResult: + from anndata import AnnData + + from squidpy.experimental.tl._align._jax import require_jax + + # Resolve JAX *before* importing the lifted _tools module, because + # _tools does `import jax.numpy as jnp` at module level. If we let + # that import fire first, callers without JAX get a confusing + # `ModuleNotFoundError: import of jax halted; None in sys.modules` + # instead of the clean `ImportError("JAX is required ...")` from + # _jax.require_jax. + require_jax(device) + + from squidpy.experimental.tl._align._backends._stalign_tools import stalign_points + from squidpy.experimental.tl._align._types import AlignResult, ObsDisplacement + + if not isinstance(pair.ref, AnnData) or not isinstance(pair.query, AnnData): + raise TypeError( + "stalign backend `align_obs` only supports AnnData / table inputs; " + f"got ref={type(pair.ref).__name__}, query={type(pair.query).__name__}." + ) + if "spatial" not in pair.query.obsm or "spatial" not in pair.ref.obsm: + raise KeyError("Both ref and query must carry an `obsm['spatial']` point cloud.") + + src_xy = np.asarray(pair.query.obsm["spatial"], dtype=float) + tgt_xy = np.asarray(pair.ref.obsm["spatial"], dtype=float) + + # stalign_points runs internally in row_col (yx); obsm["spatial"] is xy + # by squidpy convention. Mirror Selman's _stalign.py:69-70 / :90 swap. + src_rc = src_xy[:, [1, 0]] + tgt_rc = tgt_xy[:, [1, 0]] + landmarks_src_rc = None if landmarks_source is None else np.asarray(landmarks_source)[:, [1, 0]] + landmarks_tgt_rc = None if landmarks_target is None else np.asarray(landmarks_target)[:, [1, 0]] + + stalign_result = stalign_points( + source_points=src_rc, + target_points=tgt_rc, + config=config, + landmarks_source=landmarks_src_rc, + landmarks_target=landmarks_tgt_rc, + ) + + aligned_rc = np.asarray(stalign_result.aligned_points) + aligned_xy = aligned_rc[:, [1, 0]] + deltas_xy = aligned_xy - src_xy + + return AlignResult( + transform=ObsDisplacement( + deltas=deltas_xy, + source_cs=pair.query_cs, + target_cs=pair.ref_cs, + ), + metadata={ + "flavour": "stalign", + # Escape hatch for power users who want the diffeomorphic + # part (velocity field, velocity grid, affine init) rather + # than just the materialised displacement. + "stalign_result": stalign_result, + }, + ) + + def align_images( + self, + pair: AlignPair, + *, + device: Literal["cpu", "gpu"] | None = None, + **kwargs: Any, + ) -> AlignResult: + raise NotImplementedError( + "stalign image alignment is not yet implemented; PR #1150 only ships " + "point-cloud alignment. Use `flavour='stalign'` with `align_obs`." + ) diff --git a/src/squidpy/experimental/tl/_align/_backends/_stalign_core.py b/src/squidpy/experimental/tl/_align/_backends/_stalign_core.py new file mode 100644 index 000000000..65e8ce72c --- /dev/null +++ b/src/squidpy/experimental/tl/_align/_backends/_stalign_core.py @@ -0,0 +1,367 @@ +"""Core JAX implementation for experimental STalign point registration. + +Lifted byte-for-byte from scverse/squidpy#1150 (Selman Özleyen). +""" + +from __future__ import annotations + +from typing import Any, Literal + +import jax +import jax.numpy as jnp +import jax.scipy as jsp +import numpy as np + +JAX_DTYPE = jnp.float64 if jax.config.x64_enabled else jnp.float32 +__all__ = ["JAX_DTYPE", "lddmm", "transform_points_row_col"] + + +def _to_affine(linear: Any, translation: Any) -> Any: + return jnp.array( + [ + [linear[0, 0], linear[0, 1], translation[0]], + [linear[1, 0], linear[1, 1], translation[1]], + [0.0, 0.0, 1.0], + ], + dtype=linear.dtype, + ) + + +def _grid_points(x: tuple[Any, Any]) -> Any: + yy, xx = jnp.meshgrid(x[0], x[1], indexing="ij") + return jnp.stack((yy, xx)) + + +def _interp( + x: tuple[Any, Any], + image: Any, + phii: Any, + *, + mode: str = "nearest", +) -> Any: + """Interpolate a channels-first image on physical row-column coordinates.""" + arr = jnp.asarray(image) + coords = jnp.asarray(phii) + if coords.shape[0] != 2: + raise ValueError(f"Expected interpolation coordinates to have leading axis of size 2, found `{coords.shape}`.") + if arr.ndim == 2: + arr = arr[None, ...] + + row_step = x[0][1] - x[0][0] + col_step = x[1][1] - x[1][0] + row_idx = (coords[0] - x[0][0]) / row_step + col_idx = (coords[1] - x[1][0]) / col_step + idx = jnp.stack((row_idx.reshape(-1), col_idx.reshape(-1))) + + def _sample(channel: Any) -> Any: + values = jsp.ndimage.map_coordinates(channel, idx, order=1, mode=mode) + return values.reshape(coords.shape[1:]) + + return jax.vmap(_sample)(arr) + + +def transform_points_row_col( + xv: tuple[Any, Any], + velocity: Any, + affine: Any, + points: Any, + *, + direction: Literal["forward", "backward"] = "forward", +) -> Any: + pts = jnp.asarray(points) + n_steps = velocity.shape[0] + time_steps = range(n_steps) + flow_sign = 1.0 + if direction == "backward": + affine = jnp.linalg.inv(affine) + pts = pts @ affine[:2, :2].T + affine[:2, -1] + flow_sign = -1.0 + time_steps = reversed(time_steps) + + for t in time_steps: + disp = _interp( + xv, + jnp.moveaxis(flow_sign * velocity[t], -1, 0), + pts.T[:, :, None], + mode="nearest", + )[:, :, 0].T + pts = pts + disp / n_steps + + if direction == "forward": + pts = pts @ affine[:2, :2].T + affine[:2, -1] + + return pts + + +def _transform_grid_backward( + x_target: tuple[Any, Any], + xv: tuple[Any, Any], + velocity: Any, + affine: Any, +) -> Any: + target_grid = _grid_points(x_target) + affine_inv = jnp.linalg.inv(affine) + source_grid = jnp.einsum("ij,jhw->ihw", affine_inv[:2, :2], target_grid) + affine_inv[:2, -1][:, None, None] + + for t in range(velocity.shape[0] - 1, -1, -1): + disp = _interp(xv, jnp.moveaxis(-velocity[t], -1, 0), source_grid, mode="nearest") + source_grid = source_grid + disp / velocity.shape[0] + + return source_grid + + +def _contrast_transform(source_image: Any, target_image: Any, weights: Any) -> Any: + flat_source = source_image.reshape(source_image.shape[0], -1) + flat_target = target_image.reshape(target_image.shape[0], -1) + flat_weights = weights.reshape(-1) + + design = jnp.concatenate((jnp.ones((1, flat_source.shape[1]), dtype=source_image.dtype), flat_source), axis=0) + weighted_design = design * flat_weights[None, :] + design_cov = weighted_design @ design.T + target_cov = weighted_design @ flat_target.T + regularized = design_cov + 0.1 * jnp.eye(design_cov.shape[0], dtype=design_cov.dtype) + coefficients = jnp.linalg.solve(regularized, target_cov) + return (coefficients.T @ design).reshape(target_image.shape) + + +def _build_velocity_grid(x_source: tuple[Any, Any], *, a: float, expand: float) -> tuple[Any, Any]: + minimum = jnp.array([x_source[0][0], x_source[1][0]]) + maximum = jnp.array([x_source[0][-1], x_source[1][-1]]) + center = (minimum + maximum) / 2.0 + half_width = (maximum - minimum) * expand / 2.0 + step = a * 0.5 + return ( + jnp.arange(center[0] - half_width[0], center[0] + half_width[0] + step, step), + jnp.arange(center[1] - half_width[1], center[1] + half_width[1] + step, step), + ) + + +def _build_regularizer( + xv: tuple[Any, Any], + *, + a: float, + p: float, +) -> tuple[Any, Any, Any]: + dv = jnp.array([xv[0][1] - xv[0][0], xv[1][1] - xv[1][0]]) + shape = (xv[0].shape[0], xv[1].shape[0]) + fy = jnp.arange(shape[0], dtype=xv[0].dtype) / (shape[0] * dv[0]) + fx = jnp.arange(shape[1], dtype=xv[1].dtype) / (shape[1] * dv[1]) + frequency_grid = jnp.stack(jnp.meshgrid(fy, fx, indexing="ij"), axis=-1) + ll = (1.0 + 2.0 * a**2 * jnp.sum((1.0 - jnp.cos(2.0 * np.pi * frequency_grid * dv)) / (dv**2), axis=-1)) ** ( + 2.0 * p + ) + kernel = 1.0 / ll + dv_prod = jnp.prod(dv) + return kernel, ll, dv_prod + + +def _update_mixture_weights( + transformed_source: Any, + target_image: Any, + match_weights: Any, + artifact_weights: Any, + background_weights: Any, + *, + sigmaM: float, + sigmaA: float, + sigmaB: float, + estimate_muA: bool, + estimate_muB: bool, + muA: Any, + muB: Any, + iteration: int, +) -> tuple[Any, Any, Any, Any, Any]: + if estimate_muA: + muA = jnp.sum(artifact_weights * target_image, axis=(-1, -2)) / jnp.maximum(jnp.sum(artifact_weights), 1e-12) + if estimate_muB: + muB = jnp.sum(background_weights * target_image, axis=(-1, -2)) / jnp.maximum( + jnp.sum(background_weights), 1e-12 + ) + + if iteration < 50: + return match_weights, artifact_weights, background_weights, muA, muB + + weights = jnp.stack((match_weights, artifact_weights, background_weights)) + mixing = jnp.sum(weights, axis=(1, 2)) + mixing = mixing + jnp.max(mixing) * 1e-6 + mixing = mixing / jnp.sum(mixing) + + n_channels = target_image.shape[0] + norm_match = (2.0 * np.pi * sigmaM**2) ** (n_channels / 2.0) + norm_artifact = (2.0 * np.pi * sigmaA**2) ** (n_channels / 2.0) + norm_background = (2.0 * np.pi * sigmaB**2) ** (n_channels / 2.0) + + match_weights = mixing[0] * jnp.exp(-jnp.sum((transformed_source - target_image) ** 2, axis=0) / (2.0 * sigmaM**2)) + match_weights = match_weights / norm_match + artifact_weights = mixing[1] * jnp.exp( + -jnp.sum((muA[:, None, None] - target_image) ** 2, axis=0) / (2.0 * sigmaA**2) + ) + artifact_weights = artifact_weights / norm_artifact + background_weights = mixing[2] * jnp.exp( + -jnp.sum((muB[:, None, None] - target_image) ** 2, axis=0) / (2.0 * sigmaB**2) + ) + background_weights = background_weights / norm_background + + total = match_weights + artifact_weights + background_weights + total = total + jnp.max(total) * 1e-6 + return match_weights / total, artifact_weights / total, background_weights / total, muA, muB + + +def _lddmm_loss( + linear: Any, + translation: Any, + velocity: Any, + *, + x_source: tuple[Any, Any], + source_image: Any, + x_target: tuple[Any, Any], + target_image: Any, + xv: tuple[Any, Any], + match_weights: Any, + ll: Any, + dv_prod: Any, + points_source: Any, + points_target: Any, + sigmaM: float, + sigmaR: float, + sigmaP: float, +) -> tuple[Any, tuple[Any, Any, Any, Any, Any]]: + affine = _to_affine(linear, translation) + source_grid = _transform_grid_backward(x_target, xv, velocity, affine) + warped_source = _interp(x_source, source_image, source_grid, mode="nearest") + contrast_source = _contrast_transform(warped_source, target_image, match_weights) + + match_energy = jnp.sum((contrast_source - target_image) ** 2 * match_weights) / (2.0 * sigmaM**2) + fft_velocity = jnp.fft.fftn(velocity, axes=(1, 2)) + reg_energy = ( + jnp.sum(jnp.sum(jnp.abs(fft_velocity) ** 2, axis=(0, 3)) * ll) + * dv_prod + / 2.0 + / velocity.shape[1] + / velocity.shape[2] + / sigmaR**2 + ) + + transformed_points = transform_points_row_col(xv, velocity, affine, points_source, direction="forward") + if points_source.shape[0] == 0: + point_energy = jnp.array(0.0, dtype=source_image.dtype) + else: + point_energy = jnp.sum((transformed_points - points_target) ** 2) / (2.0 * sigmaP**2) + + total = match_energy + reg_energy + point_energy + return total, (contrast_source, transformed_points, match_energy, reg_energy, point_energy) + + +def lddmm( + xI: tuple[Any, Any], + I: Any, + xJ: tuple[Any, Any], + J: Any, + *, + L: Any, + T: Any, + points_source: Any | None = None, + points_target: Any | None = None, + a: float = 500.0, + p: float = 2.0, + expand: float = 2.0, + nt: int = 3, + niter: int = 5000, + diffeo_start: int = 0, + epL: float = 2e-8, + epT: float = 2e-1, + epV: float = 2e3, + sigmaM: float = 1.0, + sigmaB: float = 2.0, + sigmaA: float = 5.0, + sigmaR: float = 5e5, + sigmaP: float = 2e1, +) -> dict[str, Any]: + x_source = (jnp.asarray(xI[0]), jnp.asarray(xI[1])) + x_target = (jnp.asarray(xJ[0]), jnp.asarray(xJ[1])) + source_image = jnp.asarray(I, dtype=JAX_DTYPE) + target_image = jnp.asarray(J, dtype=JAX_DTYPE) + linear = jnp.asarray(L, dtype=JAX_DTYPE) + translation = jnp.asarray(T, dtype=JAX_DTYPE) + + if points_source is None: + source_landmarks = jnp.zeros((0, 2), dtype=JAX_DTYPE) + target_landmarks = jnp.zeros((0, 2), dtype=JAX_DTYPE) + else: + source_landmarks = jnp.asarray(points_source, dtype=JAX_DTYPE) + target_landmarks = jnp.asarray(points_target, dtype=JAX_DTYPE) + + xv = _build_velocity_grid(x_source, a=a, expand=expand) + velocity = jnp.zeros((nt, xv[0].shape[0], xv[1].shape[0], 2), dtype=JAX_DTYPE) + kernel, ll, dv_prod = _build_regularizer(xv, a=a, p=p) + + match_weights = jnp.full(target_image.shape[1:], 0.5, dtype=target_image.dtype) + background_weights = jnp.full(target_image.shape[1:], 0.4, dtype=target_image.dtype) + artifact_weights = jnp.full(target_image.shape[1:], 0.1, dtype=target_image.dtype) + muA = jnp.mean(target_image, axis=(1, 2)) + muB = jnp.zeros_like(muA) + estimate_muA = True + estimate_muB = True + + loss_and_grad = jax.jit(jax.value_and_grad(_lddmm_loss, argnums=(0, 1, 2), has_aux=True)) + + for iteration in range(niter): + (energy, aux), (grad_linear, grad_translation, grad_velocity) = loss_and_grad( + linear, + translation, + velocity, + x_source=x_source, + source_image=source_image, + x_target=x_target, + target_image=target_image, + xv=xv, + match_weights=match_weights, + ll=ll, + dv_prod=dv_prod, + points_source=source_landmarks, + points_target=target_landmarks, + sigmaM=sigmaM, + sigmaR=sigmaR, + sigmaP=sigmaP, + ) + contrast_source, transformed_points, _, _, _ = aux + + affine_scale = 1.0 + 9.0 * float(iteration >= diffeo_start) + linear = linear - (epL / affine_scale) * grad_linear + translation = translation - (epT / affine_scale) * grad_translation + + grad_velocity = jnp.fft.ifftn( + jnp.fft.fftn(grad_velocity, axes=(1, 2)) * kernel[None, ..., None], + axes=(1, 2), + ).real + if iteration >= diffeo_start: + velocity = velocity - epV * grad_velocity + + if iteration % 5 == 0: + match_weights, artifact_weights, background_weights, muA, muB = _update_mixture_weights( + contrast_source, + target_image, + match_weights, + artifact_weights, + background_weights, + sigmaM=sigmaM, + sigmaA=sigmaA, + sigmaB=sigmaB, + estimate_muA=estimate_muA, + estimate_muB=estimate_muB, + muA=muA, + muB=muB, + iteration=iteration, + ) + + affine = _to_affine(linear, translation) + return { + "A": affine, + "v": velocity, + "xv": xv, + "WM": match_weights, + "WB": background_weights, + "WA": artifact_weights, + "E": energy, + "points": transformed_points, + } diff --git a/src/squidpy/experimental/tl/_align/_backends/_stalign_helpers.py b/src/squidpy/experimental/tl/_align/_backends/_stalign_helpers.py new file mode 100644 index 000000000..7e5e5099e --- /dev/null +++ b/src/squidpy/experimental/tl/_align/_backends/_stalign_helpers.py @@ -0,0 +1,187 @@ +"""Helpers for experimental STalign point-cloud registration. + +Lifted byte-for-byte from scverse/squidpy#1150 (Selman Özleyen). +""" + +from __future__ import annotations + +from typing import Literal + +import numpy as np +from anndata import AnnData + +PointOrder = Literal["row_col", "xy"] + +__all__ = [ + "PointOrder", + "affine_from_points", + "extract_landmarks", + "extract_points", + "rasterize", +] + + +def _validate_points(points: np.ndarray, *, name: str) -> np.ndarray: + arr = np.asarray(points, dtype=float) + if arr.ndim != 2 or arr.shape[1] != 2: + raise ValueError(f"Expected `{name}` to have shape `(n, 2)`, found `{arr.shape}`.") + if not np.all(np.isfinite(arr)): + raise ValueError(f"Expected `{name}` to contain only finite values.") + return arr + + +def extract_points(adata: AnnData, key: str = "spatial") -> np.ndarray: + """Return a validated coordinate array from ``adata.obsm``.""" + if key not in adata.obsm: + raise KeyError(f"Key `{key}` not found in `adata.obsm`.") + + return _validate_points(np.asarray(adata.obsm[key]), name=f"adata.obsm[{key!r}]") + + +def extract_landmarks(adata: AnnData, key: str) -> np.ndarray: + """Return a validated landmark array from ``adata.obsm`` or ``adata.uns``.""" + if key in adata.obsm: + arr = np.asarray(adata.obsm[key], dtype=float) + if arr.ndim != 2 or arr.shape[1] != 2: + raise ValueError(f"Expected `adata.obsm[{key!r}]` to have shape `(n, 2)`, found `{arr.shape}`.") + mask = np.all(np.isfinite(arr), axis=1) + landmarks = arr[mask] + if landmarks.size == 0: + raise ValueError(f"No finite landmark rows were found in `adata.obsm[{key!r}]`.") + return landmarks + + if key in adata.uns: + return _validate_points(np.asarray(adata.uns[key]), name=f"adata.uns[{key!r}]") + + raise KeyError(f"Key `{key}` not found in `adata.obsm` or `adata.uns`.") + + +# TODO: are these duplicated? I would imagine its +# better to keep image transform functions under some place + + +def to_row_col(points: np.ndarray, *, point_order: PointOrder) -> np.ndarray: + """Convert coordinates to row-column order.""" + arr = _validate_points(points, name="points") + if point_order == "row_col": + return arr + if point_order == "xy": + return arr[:, [1, 0]] + raise ValueError(f"Unknown `point_order`: `{point_order}`.") + + +def from_row_col(points: np.ndarray, *, point_order: PointOrder) -> np.ndarray: + """Convert row-column coordinates to the requested order.""" + arr = _validate_points(points, name="points") + if point_order == "row_col": + return arr + if point_order == "xy": + return arr[:, [1, 0]] + raise ValueError(f"Unknown `point_order`: `{point_order}`.") + + +def _normalize(values: np.ndarray) -> np.ndarray: + values = np.asarray(values, dtype=float) + vmin = np.min(values) + vmax = np.max(values) + if np.isclose(vmin, vmax): + return np.ones_like(values, dtype=float) + return (values - vmin) / (vmax - vmin) + + +def rasterize( + x: np.ndarray, + y: np.ndarray, + *, + g: np.ndarray | None = None, + dx: float = 30.0, + blur: float | list[float] = 1.0, + expand: float = 1.1, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Rasterize a point cloud into a multi-scale density image.""" + x = np.asarray(x, dtype=float) + y = np.asarray(y, dtype=float) + if x.ndim != 1 or y.ndim != 1 or x.shape != y.shape: + raise ValueError("Expected `x` and `y` to be 1D arrays with the same length.") + if x.size == 0: + raise ValueError("Expected at least one point to rasterize.") + if dx <= 0: + raise ValueError("Expected `dx` to be positive.") + if expand <= 0: + raise ValueError("Expected `expand` to be positive.") + + blur_values = np.atleast_1d(np.asarray(blur, dtype=float)) + if blur_values.ndim != 1 or np.any(blur_values <= 0): + raise ValueError("Expected `blur` to be a positive scalar or a 1D sequence of positive values.") + + if g is None: + weights = np.ones_like(x, dtype=float) + else: + weights = np.asarray(g, dtype=float) + if weights.shape != x.shape: + raise ValueError("Expected `g` to have the same shape as `x` and `y`.") + if not np.allclose(weights, 1.0): + weights = _normalize(weights) + + min_x = float(np.min(x)) + max_x = float(np.max(x)) + min_y = float(np.min(y)) + max_y = float(np.max(y)) + + center_x = (min_x + max_x) / 2.0 + center_y = (min_y + max_y) / 2.0 + half_x = (max_x - min_x) * expand / 2.0 + half_y = (max_y - min_y) * expand / 2.0 + + grid_x = np.arange(center_x - half_x, center_x + half_x + dx, dx, dtype=float) + grid_y = np.arange(center_y - half_y, center_y + half_y + dx, dx, dtype=float) + if grid_x.size < 2 or grid_y.size < 2: + raise ValueError("Rasterized grid is too small. Increase the point spread or lower `dx`.") + + mesh_x, mesh_y = np.meshgrid(grid_x, grid_y) + out = np.zeros((len(blur_values), grid_y.size, grid_x.size), dtype=float) + radius = int(np.ceil(float(np.max(blur_values)) * 4.0)) + + for x_i, y_i, w_i in zip(x, y, weights, strict=False): + col = int(np.rint((x_i - grid_x[0]) / dx)) + row = int(np.rint((y_i - grid_y[0]) / dx)) + + row0 = max(row - radius, 0) + row1 = min(row + radius, out.shape[1] - 1) + col0 = max(col - radius, 0) + col1 = min(col + radius, out.shape[2] - 1) + + patch_x = mesh_x[row0 : row1 + 1, col0 : col1 + 1] + patch_y = mesh_y[row0 : row1 + 1, col0 : col1 + 1] + denom = 2.0 * (dx * blur_values * 2.0) ** 2 + + kernels = np.exp(-((patch_x[..., None] - x_i) ** 2 + (patch_y[..., None] - y_i) ** 2) / denom) + kernels_sum = kernels.sum(axis=(0, 1), keepdims=True) + kernels /= np.where(kernels_sum == 0.0, 1.0, kernels_sum) + out[:, row0 : row1 + 1, col0 : col1 + 1] += np.moveaxis(kernels * w_i, -1, 0) + + return grid_x, grid_y, out + + +def affine_from_points( + points_source: np.ndarray, + points_target: np.ndarray, +) -> tuple[np.ndarray, np.ndarray]: + """Compute an affine initialization from corresponding landmarks.""" + source = _validate_points(points_source, name="points_source") + target = _validate_points(points_target, name="points_target") + if source.shape != target.shape: + raise ValueError( + f"Expected `points_source` and `points_target` to have the same shape, found " + f"`{source.shape}` and `{target.shape}`." + ) + + if source.shape[0] < 3: + linear = np.eye(2, dtype=float) + translation = np.mean(target, axis=0) - np.mean(source, axis=0) + return linear, translation + + source_h = np.concatenate((source, np.ones((source.shape[0], 1), dtype=float)), axis=1) + target_h = np.concatenate((target, np.ones((target.shape[0], 1), dtype=float)), axis=1) + affine = np.linalg.lstsq(source_h, target_h, rcond=None)[0].T + return affine[:2, :2], affine[:2, -1] diff --git a/src/squidpy/experimental/tl/_align/_backends/_stalign_tools.py b/src/squidpy/experimental/tl/_align/_backends/_stalign_tools.py new file mode 100644 index 000000000..c2ba666d6 --- /dev/null +++ b/src/squidpy/experimental/tl/_align/_backends/_stalign_tools.py @@ -0,0 +1,261 @@ +"""Low-level point-cloud tools for experimental STalign. + +Lifted from scverse/squidpy#1150 (Selman Özleyen). Only the two import paths +below were rewritten to point at the sibling lifted modules; the rest of the +file is byte-for-byte identical to the upstream PR. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Literal, TypeAlias + +import jax.numpy as jnp +import numpy as np +from anndata import AnnData + +from squidpy.experimental.tl._align._backends._stalign_core import JAX_DTYPE, lddmm, transform_points_row_col +from squidpy.experimental.tl._align._backends._stalign_helpers import ( + PointOrder, + affine_from_points, + extract_points, + from_row_col, + rasterize, + to_row_col, +) + +if TYPE_CHECKING: + import jax + + JaxArray = jax.Array +else: # pragma: no cover - typing only + JaxArray = Any + +BlurScales: TypeAlias = float | tuple[float, ...] | list[float] + +__all__ = [ + "STalignConfig", + "STalignPreprocessConfig", + "STalignPreprocessResult", + "STalignRegistrationConfig", + "STalignResult", + "stalign_points", + "stalign_preprocess", + "transform_points", +] + + +@dataclass(slots=True) +class STalignPreprocessConfig: + dx: float = 30.0 + blur: BlurScales = (2.0, 1.0, 0.5) + expand: float = 1.1 + + +@dataclass(slots=True) +class STalignRegistrationConfig: + a: float = 500.0 + p: float = 2.0 + expand: float = 2.0 + nt: int = 3 + niter: int = 5000 + diffeo_start: int = 0 + epL: float = 2e-8 + epT: float = 2e-1 + epV: float = 2e3 + sigmaM: float = 1.0 + sigmaB: float = 2.0 + sigmaA: float = 5.0 + sigmaR: float = 5e5 + sigmaP: float = 2e1 + + +@dataclass(slots=True) +class STalignConfig: + preprocess: STalignPreprocessConfig = field(default_factory=STalignPreprocessConfig) + registration: STalignRegistrationConfig = field(default_factory=STalignRegistrationConfig) + + +@dataclass(slots=True) +class STalignPreprocessResult: + source_grid: tuple[np.ndarray, np.ndarray] + source_image: np.ndarray + target_grid: tuple[np.ndarray, np.ndarray] + target_image: np.ndarray + + +@dataclass(slots=True) +class STalignResult: + affine: JaxArray + velocity: JaxArray + velocity_grid: tuple[JaxArray, JaxArray] + aligned_points: JaxArray + point_order: PointOrder = "row_col" + + def transform_points( + self, + points: np.ndarray, + *, + direction: Literal["forward", "backward"] = "forward", + point_order: PointOrder | None = None, + ) -> JaxArray: + """Transform arbitrary point arrays with the fitted map.""" + return transform_points( + self.velocity_grid, + self.velocity, + self.affine, + points, + direction=direction, + point_order=self.point_order if point_order is None else point_order, + ) + + def transform_adata( + self, + adata: AnnData, + *, + spatial_key: str = "spatial", + key_added: str | None = None, + direction: Literal["forward", "backward"] = "forward", + inplace: bool = False, + ) -> np.ndarray | None: + """ + Apply the fitted transform to coordinates stored on an AnnData object. + + If ``inplace=False``, return the transformed coordinates without + modifying ``adata``. If ``inplace=True``, write the transformed + coordinates to ``adata.obsm[spatial_key]`` or ``adata.obsm[key_added]`` + and return ``None``. + """ + points = extract_points(adata, key=spatial_key) + transformed = np.asarray(self.transform_points(points, direction=direction, point_order="xy")) + if not inplace: + return transformed + + adata.obsm[spatial_key if key_added is None else key_added] = transformed + return None + + +def stalign_preprocess( + source_points: np.ndarray, + target_points: np.ndarray, + *, + config: STalignPreprocessConfig | None = None, +) -> STalignPreprocessResult: + """Rasterize source and target point clouds for LDDMM registration.""" + config = STalignPreprocessConfig() if config is None else config + source_points = to_row_col(source_points, point_order="row_col") + target_points = to_row_col(target_points, point_order="row_col") + + source_x, source_y, source_image = rasterize( + source_points[:, 1], + source_points[:, 0], + dx=config.dx, + blur=config.blur, + expand=config.expand, + ) + target_x, target_y, target_image = rasterize( + target_points[:, 1], + target_points[:, 0], + dx=config.dx, + blur=config.blur, + expand=config.expand, + ) + + return STalignPreprocessResult( + source_grid=(source_y, source_x), + source_image=source_image, + target_grid=(target_y, target_x), + target_image=target_image, + ) + + +def transform_points( + xv: tuple[JaxArray, JaxArray], + v: JaxArray, + A: JaxArray, + points: np.ndarray, + *, + direction: Literal["forward", "backward"] = "forward", + point_order: PointOrder = "row_col", +) -> JaxArray: + """Transform point arrays with a fitted STalign map.""" + points_rc = to_row_col(points, point_order=point_order) + transformed = transform_points_row_col( + xv, + jnp.asarray(v), + jnp.asarray(A), + jnp.asarray(points_rc, dtype=JAX_DTYPE), + direction=direction, + ) + return jnp.asarray(from_row_col(np.asarray(transformed), point_order=point_order)) + + +def stalign_points( + source_points: np.ndarray, + target_points: np.ndarray, + *, + preprocessed: STalignPreprocessResult | None = None, + config: STalignConfig | None = None, + landmarks_source: np.ndarray | None = None, + landmarks_target: np.ndarray | None = None, +) -> STalignResult: + """Align source point cloud to target with a JAX LDDMM solver.""" + config = STalignConfig() if config is None else config + registration = config.registration + source_points = to_row_col(source_points, point_order="row_col") + target_points = to_row_col(target_points, point_order="row_col") + if preprocessed is None: + preprocessed = stalign_preprocess(source_points, target_points, config=config.preprocess) + + if (landmarks_source is None) != (landmarks_target is None): + raise ValueError("Expected both landmark arrays to be provided together.") + + if landmarks_source is None: + linear = np.eye(2, dtype=float) + translation = np.zeros(2, dtype=float) + source_landmarks = None + target_landmarks = None + else: + source_landmarks = to_row_col(landmarks_source, point_order="row_col") + target_landmarks = to_row_col(landmarks_target, point_order="row_col") + linear, translation = affine_from_points(source_landmarks, target_landmarks) + + result = lddmm( + preprocessed.source_grid, + preprocessed.source_image, + preprocessed.target_grid, + preprocessed.target_image, + L=jnp.asarray(linear, dtype=JAX_DTYPE), + T=jnp.asarray(translation, dtype=JAX_DTYPE), + points_source=None if source_landmarks is None else jnp.asarray(source_landmarks, dtype=JAX_DTYPE), + points_target=None if target_landmarks is None else jnp.asarray(target_landmarks, dtype=JAX_DTYPE), + a=registration.a, + p=registration.p, + expand=registration.expand, + nt=registration.nt, + niter=registration.niter, + diffeo_start=registration.diffeo_start, + epL=registration.epL, + epT=registration.epT, + epV=registration.epV, + sigmaM=registration.sigmaM, + sigmaB=registration.sigmaB, + sigmaA=registration.sigmaA, + sigmaR=registration.sigmaR, + sigmaP=registration.sigmaP, + ) + aligned_points = transform_points( + result["xv"], + result["v"], + result["A"], + source_points, + direction="forward", + point_order="row_col", + ) + return STalignResult( + affine=result["A"], + velocity=result["v"], + velocity_grid=result["xv"], + aligned_points=aligned_points, + point_order="row_col", + ) diff --git a/src/squidpy/experimental/tl/_align/_io.py b/src/squidpy/experimental/tl/_align/_io.py new file mode 100644 index 000000000..3a98d2e98 --- /dev/null +++ b/src/squidpy/experimental/tl/_align/_io.py @@ -0,0 +1,309 @@ +"""Input resolvers and output writers for the alignment skeleton. + +This module is the *only* place that knows about the duck-typed +``AnnData | SpatialData`` argument shape of the public functions and the +``output_mode`` writeback strategies. Backends operate on the canonical +:class:`AlignPair` produced here. +""" + +from __future__ import annotations + +from typing import Literal + +import numpy as np +from anndata import AnnData +from spatialdata import SpatialData + +from squidpy._validators import assert_isinstance, assert_key_in_sdata +from squidpy.experimental.im._utils import get_element_data +from squidpy.experimental.tl._align._types import ( + AffineTransform, + AlignPair, + AlignResult, +) +from squidpy.experimental.tl._align._validation import ( + validate_required, + validate_unexpected, +) + +# --------------------------------------------------------------------------- +# Resolvers +# --------------------------------------------------------------------------- + + +def resolve_obs_pair( + data_ref: AnnData | SpatialData, + data_query: AnnData | SpatialData | None, + adata_ref_name: str | None, + adata_query_name: str | None, +) -> AlignPair: + """Normalise the arguments of :func:`align_obs` into an :class:`AlignPair`. + + See the table in the design doc for the exhaustive case matrix. In short: + + - both AnnData → use directly, ``adata_*_name`` must be ``None``; + - both SpatialData → ``adata_*_name`` required, extract from each; + - only ``data_ref`` is SpatialData (``data_query is None``) → both + ``adata_*_name`` required, extract from the same sdata; + - mixed AnnData/SpatialData → :class:`TypeError`; + - ``data_ref`` is AnnData with no ``data_query`` → :class:`ValueError`. + """ + if isinstance(data_ref, AnnData): + if data_query is None: + raise ValueError("`data_query` is required when `data_ref` is an AnnData.") + if not isinstance(data_query, AnnData): + raise TypeError( + f"Mixed AnnData/SpatialData inputs are not supported. " + f"`data_ref` is AnnData but `data_query` is {type(data_query).__name__}." + ) + validate_unexpected( + name="adata_ref_name", + value=adata_ref_name, + when="`data_ref` is a SpatialData", + hint="Both inputs are AnnData, so there is no table to look up by name.", + ) + validate_unexpected( + name="adata_query_name", + value=adata_query_name, + when="`data_query` is a SpatialData", + hint="Both inputs are AnnData, so there is no table to look up by name.", + ) + return AlignPair(ref=data_ref, query=data_query) + + if not isinstance(data_ref, SpatialData): + raise TypeError(f"`data_ref` must be AnnData or SpatialData, got {type(data_ref).__name__}.") + + if data_query is None: + sdata_query: SpatialData = data_ref + elif isinstance(data_query, SpatialData): + sdata_query = data_query + else: + raise TypeError( + f"Mixed AnnData/SpatialData inputs are not supported. " + f"`data_ref` is SpatialData but `data_query` is {type(data_query).__name__}." + ) + + validate_required(name="adata_ref_name", value=adata_ref_name, when="`data_ref` is a SpatialData") + validate_required(name="adata_query_name", value=adata_query_name, when="`data_query` is a SpatialData") + assert_key_in_sdata(data_ref, adata_ref_name, attr="tables") + assert_key_in_sdata(sdata_query, adata_query_name, attr="tables") + return AlignPair( + ref=data_ref.tables[adata_ref_name], + query=sdata_query.tables[adata_query_name], + ref_container=data_ref, + query_container=sdata_query, + ref_element_key=adata_ref_name, + query_element_key=adata_query_name, + ) + + +def resolve_image_pair( + sdata_ref: SpatialData, + sdata_query: SpatialData | None, + img_ref_name: str, + img_query_name: str, + *, + scale_ref: str | Literal["auto"] = "auto", + scale_query: str | Literal["auto"] = "auto", +) -> AlignPair: + """Normalise the arguments of :func:`align_images` into an :class:`AlignPair`. + + Both single-scale ``xr.DataArray`` and multi-scale ``xr.DataTree`` image + elements are accepted. Multiscale nodes are flattened via + :func:`squidpy.experimental.im._utils.get_element_data`, but the original + element node is remembered in the :class:`AlignPair` so the writer can + register the transformation on the parent so all scales inherit it. + """ + assert_isinstance(sdata_ref, SpatialData, name="sdata_ref") + if sdata_query is None: + sdata_query = sdata_ref + else: + assert_isinstance(sdata_query, SpatialData, name="sdata_query") + + assert_key_in_sdata(sdata_ref, img_ref_name, attr="images") + assert_key_in_sdata(sdata_query, img_query_name, attr="images") + + ref_node = sdata_ref.images[img_ref_name] + query_node = sdata_query.images[img_query_name] + + ref_data = get_element_data(ref_node, scale_ref, element_type="image", element_key=img_ref_name) + query_data = get_element_data(query_node, scale_query, element_type="image", element_key=img_query_name) + + return AlignPair( + ref=ref_data, + query=query_data, + ref_container=sdata_ref, + query_container=sdata_query, + ref_element_key=img_ref_name, + query_element_key=img_query_name, + ) + + +def resolve_element_pair( + sdata_ref: SpatialData, + sdata_query: SpatialData | None, + cs_name_ref: str, + cs_name_query: str, +) -> AlignPair: + """Normalise the arguments of :func:`align_by_landmarks` into an :class:`AlignPair`. + + No element data is materialised — landmark fitting only needs the + coordinate system names plus the landmark coordinates, which are + validated separately. The returned pair carries the containers and cs + names so the writer can call :func:`set_transformation` on the right + target. + """ + assert_isinstance(sdata_ref, SpatialData, name="sdata_ref") + if sdata_query is None: + sdata_query = sdata_ref + else: + assert_isinstance(sdata_query, SpatialData, name="sdata_query") + + _check_cs_exists(sdata_ref, cs_name_ref, name="cs_name_ref") + _check_cs_exists(sdata_query, cs_name_query, name="cs_name_query") + + return AlignPair( + ref=None, + query=None, + ref_container=sdata_ref, + query_container=sdata_query, + ref_cs=cs_name_ref, + query_cs=cs_name_query, + ) + + +def _check_cs_exists(sdata: SpatialData, cs_name: str, *, name: str) -> None: + available = list(sdata.coordinate_systems) + if cs_name not in available: + raise KeyError( + f"`{name}={cs_name!r}` is not a coordinate system of the SpatialData object. " + f"Available coordinate systems: {available}." + ) + + +# --------------------------------------------------------------------------- +# Writeback +# --------------------------------------------------------------------------- + + +def apply_affine_to_cs( + pair: AlignPair, + affine: AffineTransform, + *, + inplace: bool, +) -> SpatialData | AnnData | None: + """Register ``affine`` on the query side of the pair. + + Three writeback paths, in order of specificity: + + 1. **Element-keyed**: ``pair.query_container`` and ``pair.query_element_key`` + are both set (e.g. ``align_obs`` / ``align_images`` resolved an explicit + table or image). Register the transform on that single element so all + scales / sibling tables that share its parent element node inherit it. + 2. **Cs-keyed**: only ``pair.query_cs`` is set (e.g. ``align_by_landmarks`` + resolved a coordinate system but no specific element). Walk every + element that has the moving cs in its transformation graph and register + the transform on each, mapping into the reference cs. + 3. **Plain AnnData**: no spatialdata container at all - warp + ``query.obsm['spatial']`` directly. + """ + from spatialdata.transformations import get_transformation, set_transformation + + target_cs = affine.target_cs or pair.ref_cs or "aligned" + + if pair.query_container is not None and pair.query_element_key is not None: + sdata = pair.query_container if inplace else _shallow_copy_sdata(pair.query_container) + element = sdata[pair.query_element_key] + set_transformation(element, affine.to_spatialdata(), to_coordinate_system=target_cs) + return None if inplace else sdata + + if pair.query_container is not None and pair.query_cs is not None: + sdata = pair.query_container if inplace else _shallow_copy_sdata(pair.query_container) + moving_cs = pair.query_cs + sd_affine = affine.to_spatialdata() + touched_any = False + for _etype, _name, element in sdata._gen_elements(include_tables=False): + element_transforms = get_transformation(element, get_all=True) + if moving_cs not in element_transforms: + continue + set_transformation(element, sd_affine, to_coordinate_system=target_cs) + touched_any = True + if not touched_any: + raise KeyError( + f"No elements in the query SpatialData are registered to coordinate " + f"system {moving_cs!r}; nothing to attach the alignment to." + ) + return None if inplace else sdata + + if isinstance(pair.query, AnnData): + adata = pair.query if inplace else pair.query.copy() + if "spatial" not in adata.obsm: + raise KeyError("Cannot apply an affine to an AnnData query that has no `obsm['spatial']`.") + adata.obsm["spatial"] = affine.apply(np.asarray(adata.obsm["spatial"])) + return None if inplace else adata + + raise RuntimeError("apply_affine_to_cs: pair has neither a SpatialData container nor an AnnData query.") + + +def materialise_obs( + pair: AlignPair, + result: AlignResult, + *, + key_added: str | None, + inplace: bool, +) -> SpatialData | AnnData | None: + """Bake the transform into a *new* AnnData living in the reference cs. + + For affine results we apply the matrix; for :class:`ObsDisplacement` we + add the deltas. When the source query lives inside a SpatialData, the new + AnnData is registered as ``sdata.tables[key_added]``; otherwise it is + returned directly. + """ + if not isinstance(pair.query, AnnData): + raise TypeError("materialise_obs only works for `align_obs`; `pair.query` must be an AnnData.") + if "spatial" not in pair.query.obsm: + raise KeyError("Source AnnData has no `obsm['spatial']` to warp.") + + src_coords = np.asarray(pair.query.obsm["spatial"]) + new_coords = result.transform.apply(src_coords) + + # Slim copy: share X/var/obs structurally and only rewrite obsm so we + # don't pay the cost of deep-copying potentially-large layers/obsp. + new_obsm = dict(pair.query.obsm) + new_obsm["spatial"] = new_coords + new_uns = dict(pair.query.uns) + new_uns["align"] = { + "source_query_key": pair.query_element_key, + "ref_key": pair.ref_element_key, + **result.metadata, + } + new_adata = AnnData( + X=pair.query.X, + obs=pair.query.obs.copy(), + var=pair.query.var, + obsm=new_obsm, + uns=new_uns, + ) + + if pair.query_container is not None: + if key_added is None: + raise ValueError("`key_added` is required when `output_mode='obs'` and the query is a SpatialData.") + sdata = pair.query_container if inplace else _shallow_copy_sdata(pair.query_container) + from spatialdata.models import TableModel + + sdata.tables[key_added] = TableModel.parse(new_adata) + return None if inplace else sdata + + return new_adata + + +def _shallow_copy_sdata(sdata: SpatialData) -> SpatialData: + """Shallow copy of a SpatialData object for ``inplace=False`` writeback paths. + + Uses :meth:`SpatialData.subset` over every element so tables and + ``attrs`` propagate the same way spatialdata's own subsetting handles + them, rather than reconstructing via the ``__init__`` constructor. + """ + element_names = [name for _, name, _ in sdata._gen_elements(include_tables=True)] + return sdata.subset(element_names, filter_tables=False, include_orphan_tables=True) + diff --git a/src/squidpy/experimental/tl/_align/_jax.py b/src/squidpy/experimental/tl/_align/_jax.py new file mode 100644 index 000000000..34411cd82 --- /dev/null +++ b/src/squidpy/experimental/tl/_align/_jax.py @@ -0,0 +1,56 @@ +"""Lazy JAX import + device selection for JAX-backed alignment backends. + +JAX is an optional dependency. Importing this module is cheap; calling +:func:`require_jax` is what actually pulls JAX in, and only the +JAX-backed backends do so on first call. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Literal + +if TYPE_CHECKING: + Device = Any # jax.Device, but importing it eagerly defeats the purpose + + +_INSTALL_HINT = ( + "JAX is required for the requested align_* flavour. " + "Install with `pip install jax` (CPU) or follow the JAX install guide for GPU." +) + + +def require_jax(device: Literal["cpu", "gpu"] | None = None) -> tuple[Any, Any]: + """Import JAX lazily and return ``(jax, device)``. + + Parameters + ---------- + device + ``"cpu"``/``"gpu"`` to force a platform, or ``None`` to use whatever + JAX picks as the default. + + Returns + ------- + jax_module + The imported :mod:`jax` module. + device + A :class:`jax.Device` of the requested platform. + + Raises + ------ + ImportError + If JAX is not installed. + RuntimeError + If the requested device platform is not available on this host. + """ + try: + import jax + except ImportError as e: + raise ImportError(_INSTALL_HINT) from e + + if device is None: + return jax, jax.devices()[0] + + matching = [d for d in jax.devices() if d.platform == device] + if not matching: + raise RuntimeError(f"No JAX device of kind {device!r} available; have {[d.platform for d in jax.devices()]}.") + return jax, matching[0] diff --git a/src/squidpy/experimental/tl/_align/_types.py b/src/squidpy/experimental/tl/_align/_types.py new file mode 100644 index 000000000..46e85d528 --- /dev/null +++ b/src/squidpy/experimental/tl/_align/_types.py @@ -0,0 +1,120 @@ +"""Dataclasses for the alignment skeleton. + +These types are the contract between the public ``align_*`` functions, the +input/output helpers in :mod:`squidpy.experimental.tl._align._io`, and the +backend implementations in :mod:`squidpy.experimental.tl._align._backends`. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +import numpy as np + +if TYPE_CHECKING: + import xarray as xr + from anndata import AnnData + from spatialdata import SpatialData + from spatialdata.transformations import Affine + + +@dataclass(frozen=True) +class AlignPair: + """Canonical pair of aligned-or-to-be-aligned elements. + + Returned by every resolver in :mod:`._io`. ``ref``/``query`` carry the + actual data to fit on; ``*_container``/``*_element_key`` remember where + they came from so the writeback step can register a transformation on the + correct element of the correct :class:`spatialdata.SpatialData`. + """ + + # ``ref``/``query`` are ``None`` for landmark-only flows where the fit + # operates on user-provided coordinates and the resolver only needs to + # locate the target containers + coordinate systems. + ref: AnnData | xr.DataArray | None + query: AnnData | xr.DataArray | None + ref_container: SpatialData | None = None + query_container: SpatialData | None = None + ref_element_key: str | None = None + query_element_key: str | None = None + ref_cs: str | None = None + query_cs: str | None = None + + +@dataclass(frozen=True) +class AffineTransform: + """A ``(3, 3)`` homogeneous affine in ``(x, y)`` convention. + + This matches the coordinate axis order spatialdata uses for points - + ``spatialdata.transformations.get_transformation_between_landmarks`` + asserts ``axes == ("x", "y")`` - and the order squidpy / scanpy use for + ``adata.obsm["spatial"]``. Image elements are stored ``(c, y, x)`` in + spatialdata, so when registering an ``AffineTransform`` on an *image* + element you may need a separate matrix; this skeleton currently only + deals with point coordinates. + """ + + matrix: np.ndarray + source_cs: str | None = None + target_cs: str | None = None + + def __post_init__(self) -> None: + if self.matrix.shape != (3, 3): + raise ValueError(f"Expected a (3, 3) homogeneous matrix, got shape {self.matrix.shape}.") + + def to_spatialdata(self) -> Affine: + """Build a :class:`spatialdata.transformations.Affine` for writeback.""" + from spatialdata.transformations import Affine + + return Affine( + self.matrix, + input_axes=("x", "y"), + output_axes=("x", "y"), + ) + + def apply(self, coords: np.ndarray) -> np.ndarray: + """Apply the affine to an ``(N, 2)`` ``(x, y)`` coordinate array.""" + if coords.ndim != 2 or coords.shape[1] != 2: + raise ValueError(f"Expected an (N, 2) coordinate array, got shape {coords.shape}.") + return coords @ self.matrix[:2, :2].T + self.matrix[:2, 2] + + +@dataclass(frozen=True) +class ObsDisplacement: + """Per-obs ``(N, 2)`` ``(x, y)`` displacement field. + + Used by non-affine fits (e.g. LDDMM) where a single matrix cannot + represent the deformation. Displacements are added to the source + observation coordinates (also ``(x, y)``) to obtain the aligned + coordinates. + """ + + deltas: np.ndarray + source_cs: str | None = None + target_cs: str | None = None + + def __post_init__(self) -> None: + if self.deltas.ndim != 2 or self.deltas.shape[1] != 2: + raise ValueError(f"Expected an (N, 2) deltas array, got shape {self.deltas.shape}.") + + def apply(self, coords: np.ndarray) -> np.ndarray: + """Bake the displacement into an ``(N, 2)`` obs coordinate array.""" + if coords.shape != self.deltas.shape: + raise ValueError(f"Coord shape {coords.shape} does not match displacement shape {self.deltas.shape}.") + return coords + self.deltas + + +Transform = AffineTransform | ObsDisplacement + + +@dataclass(frozen=True) +class AlignResult: + """The output of an alignment backend call.""" + + transform: Transform + metadata: dict[str, Any] = field(default_factory=dict) + + @property + def is_affine(self) -> bool: + return isinstance(self.transform, AffineTransform) diff --git a/src/squidpy/experimental/tl/_align/_validation.py b/src/squidpy/experimental/tl/_align/_validation.py new file mode 100644 index 000000000..e64c005c9 --- /dev/null +++ b/src/squidpy/experimental/tl/_align/_validation.py @@ -0,0 +1,143 @@ +"""Validation helpers shared by the public align_* functions. + +These wrap the generic checks in :mod:`squidpy._validators` with messages +tailored to alignment. The goal is to fail fast and tell the user *why* a +combination of arguments is wrong, not just that it is. +""" + +from __future__ import annotations + +from collections.abc import Sequence +from typing import Any + +import numpy as np + +from squidpy._validators import assert_one_of + +# Flavour identifiers - shared between dispatch, validation, and the public +# function defaults so a typo lights up everywhere it's used. +STALIGN = "stalign" +MOSCOT = "moscot" + +ALLOWED_FLAVOURS_OBS = (STALIGN, MOSCOT) +ALLOWED_FLAVOURS_IMAGES = (STALIGN,) +ALLOWED_OUTPUT_MODES_OBS = ("affine", "obs", "return") +ALLOWED_OUTPUT_MODES_NONOBS = ("affine", "return") +ALLOWED_LANDMARK_MODELS = ("similarity", "affine") + + +def validate_flavour(flavour: str, *, allowed: Sequence[str], op: str) -> None: + assert_one_of(flavour, allowed, name=f"{op}.flavour") + + +def validate_output_mode(output_mode: str, *, allowed: Sequence[str], op: str) -> None: + assert_one_of(output_mode, allowed, name=f"{op}.output_mode") + + +def validate_key_added(key_added: str | None, output_mode: str) -> None: + """``key_added`` only makes sense in the obs-materialisation path.""" + if key_added is not None and output_mode != "obs": + raise ValueError( + f"`key_added={key_added!r}` is only meaningful when `output_mode='obs'`. " + f"Got `output_mode={output_mode!r}`. The other modes either register a " + f"transformation on the existing element ('affine') or return the raw " + f"result ('return'), so there is nothing for `key_added` to name." + ) + + +def validate_landmark_model(model: str) -> None: + assert_one_of(model, ALLOWED_LANDMARK_MODELS, name="align_by_landmarks.model") + + +def validate_landmarks( + landmarks_ref: Sequence[tuple[float, float]], + landmarks_query: Sequence[tuple[float, float]], + *, + model: str, + cs_ref_extent: tuple[float, float, float, float] | None = None, + cs_query_extent: tuple[float, float, float, float] | None = None, +) -> tuple[np.ndarray, np.ndarray]: + """Validate landmark sequences and return them as ``(N, 2)`` arrays. + + Parameters + ---------- + landmarks_ref, landmarks_query + Sequences of ``(x, y)`` tuples. Must have the same length and at + least 3 entries (the closed-form solvers under both ``model`` + choices need at least 3 corresponding points). + model + ``"similarity"`` (4 DOF, via spatialdata) or ``"affine"`` (6 DOF, + via skimage's least-squares estimator). + cs_ref_extent, cs_query_extent + Optional ``(x_min, y_min, x_max, y_max)`` bounds of the named + coordinate system at the requested scale. When provided, every + landmark must fall inside. Catches the "I extracted these from + scale0 but asked for scale2" footgun. + """ + ref = np.asarray(landmarks_ref, dtype=float) + query = np.asarray(landmarks_query, dtype=float) + + if ref.ndim != 2 or ref.shape[1] != 2: + raise ValueError(f"`landmarks_ref` must be a sequence of (x, y) pairs, got shape {ref.shape}.") + if query.ndim != 2 or query.shape[1] != 2: + raise ValueError(f"`landmarks_query` must be a sequence of (x, y) pairs, got shape {query.shape}.") + if len(ref) != len(query): + raise ValueError( + f"`landmarks_ref` and `landmarks_query` must have the same length; got {len(ref)} and {len(query)}." + ) + + if len(ref) < 3: + raise ValueError( + f"`model={model!r}` needs at least 3 landmark pairs (spatialdata requirement), got {len(ref)}." + ) + + if cs_ref_extent is not None: + _check_in_extent(ref, cs_ref_extent, name="landmarks_ref") + if cs_query_extent is not None: + _check_in_extent(query, cs_query_extent, name="landmarks_query") + + return ref, query + + +def _check_in_extent( + points: np.ndarray, + extent: tuple[float, float, float, float], + *, + name: str, +) -> None: + x_min, y_min, x_max, y_max = extent + out_of_bounds = (points[:, 0] < x_min) | (points[:, 0] > x_max) | (points[:, 1] < y_min) | (points[:, 1] > y_max) + if out_of_bounds.any(): + bad = points[out_of_bounds] + raise ValueError( + f"{name}: {int(out_of_bounds.sum())} landmark(s) fall outside the coordinate-system " + f"extent (x in [{x_min}, {x_max}], y in [{y_min}, {y_max}]). " + f"This usually means the landmarks were extracted at a different scale than the " + f"one requested. First out-of-bounds point: {tuple(bad[0])}." + ) + + +def validate_unexpected( + *, + name: str, + value: Any, + when: str, + hint: str = "", +) -> None: + """Raise an educational error when an argument was passed in a context it has no role in.""" + if value is not None: + msg = f"`{name}={value!r}` was passed but is only valid when {when}." + if hint: + msg = f"{msg} {hint}" + raise ValueError(msg) + + +def validate_required( + *, + name: str, + value: Any, + when: str, +) -> None: + """Raise when an argument is required by the current context but missing.""" + if value is None: + raise ValueError(f"`{name}` is required when {when}.") diff --git a/tests/experimental/tl/__init__.py b/tests/experimental/tl/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/experimental/tl/test_align_blobs_e2e.py b/tests/experimental/tl/test_align_blobs_e2e.py new file mode 100644 index 000000000..d7238ef81 --- /dev/null +++ b/tests/experimental/tl/test_align_blobs_e2e.py @@ -0,0 +1,216 @@ +"""End-to-end alignment tests on the spatialdata ``blobs`` fixture. + +Pulls the 200-point ``blobs_points`` cloud from +:func:`spatialdata.datasets.blobs`, applies a known transformation to a +copy, and verifies that both alignment paths recover the inverse: + +- ``align_by_landmarks`` should recover an exact closed-form solution. +- ``align_obs(flavour="stalign")`` should reduce the residual displacement + (the LDDMM is non-affine, so we compare the warped query against the ref + by mean Euclidean distance, not by exact equality). + +These tests exercise real solver iterations. They are slower than the +``niter=1`` smoke tests in ``test_align_stalign_integration.py`` (the +stalign test takes a few seconds) but still run on every commit. The user +can mark them ``slow`` later if CI budget becomes a concern. +""" + +from __future__ import annotations + +import numpy as np +import pytest +from anndata import AnnData + +pytest.importorskip("jax") + + +# --------------------------------------------------------------------------- +# Fixture +# --------------------------------------------------------------------------- + + +def _blobs_points_xy() -> np.ndarray: + """Pull the 200 ``blobs_points`` rows as an ``(N, 2)`` ``(x, y)`` array.""" + from spatialdata.datasets import blobs + + sd = blobs() + pts_df = sd.points["blobs_points"].compute() + return np.column_stack([pts_df["x"].to_numpy(), pts_df["y"].to_numpy()]).astype(float) + + +def _make_blob_adata(coords_xy: np.ndarray) -> AnnData: + """Wrap an ``(N, 2)`` point cloud as an AnnData with ``obsm['spatial']``.""" + adata = AnnData(np.zeros((coords_xy.shape[0], 1), dtype=float)) + adata.obsm["spatial"] = coords_xy + return adata + + +def _rotation_about_centre(theta_rad: float, centre_xy: np.ndarray) -> np.ndarray: + """Build a ``(3, 3)`` homogeneous rotation about a centre, in xy convention.""" + c, s = np.cos(theta_rad), np.sin(theta_rad) + R = np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]], dtype=float) + T_to = np.array([[1.0, 0.0, -centre_xy[0]], [0.0, 1.0, -centre_xy[1]], [0.0, 0.0, 1.0]]) + T_back = np.array([[1.0, 0.0, centre_xy[0]], [0.0, 1.0, centre_xy[1]], [0.0, 0.0, 1.0]]) + return T_back @ R @ T_to + + +def _apply_homog(points_xy: np.ndarray, M: np.ndarray) -> np.ndarray: + return points_xy @ M[:2, :2].T + M[:2, 2] + + +@pytest.fixture(scope="module") +def blobs_rotated() -> tuple[AnnData, AnnData, np.ndarray]: + """``(ref_adata, query_adata, gt_affine)`` for the rotation-recovery tests. + + The query is the reference rotated by 12° around the cloud centroid. + The ground-truth affine ``gt_affine`` maps ``ref -> query``; recovery + means producing a transform that maps ``query -> ref`` (i.e. the + inverse of ``gt_affine``). + """ + ref_xy = _blobs_points_xy() + centre = ref_xy.mean(axis=0) + gt = _rotation_about_centre(np.deg2rad(12.0), centre) + query_xy = _apply_homog(ref_xy, gt) + + ref = _make_blob_adata(ref_xy) + query = _make_blob_adata(query_xy) + return ref, query, gt + + +# --------------------------------------------------------------------------- +# 1. align_by_landmarks recovers the rotation exactly +# --------------------------------------------------------------------------- + + +def test_align_by_landmarks_recovers_blobs_rotation_exactly(blobs_rotated) -> None: + """A handful of correspondences is enough for the closed-form fit to + invert a pure 2D rotation up to numerical noise.""" + from squidpy.experimental.tl._align._backends._landmark import fit_landmark_affine + + ref, query, gt = blobs_rotated + + # Pick 4 landmarks that span the cloud well. + idx = [0, 50, 100, 150] + landmarks_ref = ref.obsm["spatial"][idx] + landmarks_query = query.obsm["spatial"][idx] + + fit = fit_landmark_affine(landmarks_ref, landmarks_query, model="similarity") + inv_gt = np.linalg.inv(gt) + + # Apply the recovered transform to all 200 query points and compare to ref. + recovered = fit.apply(query.obsm["spatial"]) + residual = np.linalg.norm(recovered - ref.obsm["spatial"], axis=1) + assert residual.max() < 1e-6, f"max residual {residual.max():.3e} should be ~0 for a rigid fit" + + # And the matrix itself should match the inverse of gt to high precision. + np.testing.assert_allclose(fit.matrix, inv_gt, atol=1e-9) + + +def test_align_by_landmarks_affine_recovers_blobs_rotation(blobs_rotated) -> None: + """The ``model='affine'`` path also fits a pure rotation correctly, + even though it has 2 extra DOF over the similarity case.""" + from squidpy.experimental.tl._align._backends._landmark import fit_landmark_affine + + ref, query, _gt = blobs_rotated + idx = [0, 30, 60, 90, 120, 150] + fit = fit_landmark_affine( + ref.obsm["spatial"][idx], + query.obsm["spatial"][idx], + model="affine", + ) + + recovered = fit.apply(query.obsm["spatial"]) + residual = np.linalg.norm(recovered - ref.obsm["spatial"], axis=1) + assert residual.max() < 1e-6 + + +# --------------------------------------------------------------------------- +# 2. align_obs (stalign) reduces the residual on the rotated cloud +# --------------------------------------------------------------------------- + + +def test_align_obs_stalign_reduces_residual_on_blobs(blobs_rotated) -> None: + """The LDDMM solver isn't expected to be exact - non-rigid by design - + but feeding it the landmark-fit affine as an init via the + ``landmarks_*`` kwargs should reduce the residual *below* the no-op + baseline by an order of magnitude. This is the wiring proof: we go + from raw misaligned coordinates to substantially-aligned coordinates + end-to-end through ``sq.experimental.tl.align_obs``. + """ + import squidpy as sq + + ref, query, _gt = blobs_rotated + + baseline = np.linalg.norm(ref.obsm["spatial"] - query.obsm["spatial"], axis=1).mean() + + config = sq.experimental.tl.STalignConfig( + preprocess=sq.experimental.tl.STalignPreprocessConfig(dx=20.0, blur=2.0, expand=1.2), + registration=sq.experimental.tl.STalignRegistrationConfig( + a=80.0, + expand=1.2, + nt=3, + niter=80, + epV=5e2, + ), + ) + + # Use the same well-spread landmarks as the closed-form test so the + # affine init is meaningful; LDDMM then refines the diffeomorphism. + idx = [0, 50, 100, 150] + landmarks_ref = ref.obsm["spatial"][idx] + landmarks_query = query.obsm["spatial"][idx] + + aligned = sq.experimental.tl.align_obs( + ref, + query, + flavour="stalign", + output_mode="obs", + inplace=False, + config=config, + landmarks_source=landmarks_query, + landmarks_target=landmarks_ref, + ) + assert isinstance(aligned, AnnData) + + after = np.linalg.norm(ref.obsm["spatial"] - aligned.obsm["spatial"], axis=1).mean() + # Sanity: the residual is finite and the alignment moved the points. + assert np.isfinite(after) + assert after < baseline, f"residual {after:.2f} should improve on baseline {baseline:.2f}" + + +# --------------------------------------------------------------------------- +# 3. The result type is consistent across both backends +# --------------------------------------------------------------------------- + + +def test_blobs_landmark_and_stalign_use_compatible_xy_convention(blobs_rotated) -> None: + """Both backends operate on (x, y) coords drawn from the same blobs + fixture and produce results in the same convention - sanity check that + the type unification didn't introduce a silent flip.""" + import squidpy as sq + from squidpy.experimental.tl._align._backends._landmark import fit_landmark_affine + from squidpy.experimental.tl._align._types import AffineTransform, ObsDisplacement + + ref, query, _ = blobs_rotated + idx = [0, 50, 100, 150] + affine_fit = fit_landmark_affine( + ref.obsm["spatial"][idx], + query.obsm["spatial"][idx], + model="similarity", + ) + assert isinstance(affine_fit, AffineTransform) + assert affine_fit.matrix.shape == (3, 3) + + config = sq.experimental.tl.STalignConfig( + preprocess=sq.experimental.tl.STalignPreprocessConfig(dx=30.0, blur=2.0), + registration=sq.experimental.tl.STalignRegistrationConfig(a=100.0, expand=1.2, nt=1, niter=1, epV=1.0), + ) + stalign_result = sq.experimental.tl.align_obs( + ref, + query, + flavour="stalign", + output_mode="return", + config=config, + ) + assert isinstance(stalign_result.transform, ObsDisplacement) + assert stalign_result.transform.deltas.shape == query.obsm["spatial"].shape diff --git a/tests/experimental/tl/test_align_skeleton.py b/tests/experimental/tl/test_align_skeleton.py new file mode 100644 index 000000000..620c1151d --- /dev/null +++ b/tests/experimental/tl/test_align_skeleton.py @@ -0,0 +1,489 @@ +"""Skeleton-level tests for ``squidpy.experimental.tl.align_*``. + +These tests verify *wiring* — argument resolution, dispatch, validation, and +lazy-import hygiene. Real solver tests come with the next PR that drops the +actual implementations into the prepared ``NotImplementedError`` slots. +""" + +from __future__ import annotations + +import sys + +import numpy as np +import pandas as pd +import pytest +import xarray as xr +from anndata import AnnData +from spatialdata import SpatialData +from spatialdata.models import Image2DModel, TableModel + +__all__: list[str] = [] # silence the import-only test module check + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +def _make_adata(n: int = 8, seed: int = 0) -> AnnData: + rng = np.random.default_rng(seed) + X = rng.standard_normal((n, 3)).astype(np.float32) + adata = AnnData(X=X) + adata.obs["region"] = "r" + adata.obs["instance_id"] = np.arange(n) + adata.obsm["spatial"] = rng.uniform(0, 100, (n, 2)) + return adata + + +def _make_table(name: str, n: int = 8, seed: int = 0) -> AnnData: + adata = _make_adata(n=n, seed=seed) + adata.obs["region"] = pd.Categorical([name] * n) + adata.uns["spatialdata_attrs"] = { + "region": name, + "region_key": "region", + "instance_key": "instance_id", + } + return TableModel.parse(adata) + + +def _make_sdata(image_keys=("img_ref", "img_query"), table_keys=("tbl_ref", "tbl_query")) -> SpatialData: + images = {} + for k in image_keys: + arr = np.zeros((3, 32, 32), dtype=np.uint8) + xa = xr.DataArray(arr, dims=["c", "y", "x"], coords={"c": ["R", "G", "B"]}) + images[k] = Image2DModel.parse(xa) + tables = {k: _make_table(k) for k in table_keys} + return SpatialData(images=images, tables=tables) + + +def _make_sdata_two_cs() -> SpatialData: + """Single SpatialData with two distinct coordinate systems (``cs_a``/``cs_b``). + + Used by the landmark tests where we want to align one cs to another + inside the same container. + """ + from spatialdata.transformations import Identity, set_transformation + + arr = np.zeros((3, 32, 32), dtype=np.uint8) + xa = xr.DataArray(arr, dims=["c", "y", "x"], coords={"c": ["R", "G", "B"]}) + img_a = Image2DModel.parse(xa) + img_b = Image2DModel.parse(xa.copy()) + set_transformation(img_a, Identity(), to_coordinate_system="cs_a") + set_transformation(img_b, Identity(), to_coordinate_system="cs_b") + return SpatialData(images={"img_a": img_a, "img_b": img_b}) + + +@pytest.fixture +def adata_pair() -> tuple[AnnData, AnnData]: + return _make_adata(seed=1), _make_adata(seed=2) + + +@pytest.fixture +def sdata_pair() -> tuple[SpatialData, SpatialData]: + return _make_sdata(), _make_sdata() + + +@pytest.fixture +def sdata_single() -> SpatialData: + return _make_sdata() + + +@pytest.fixture +def sdata_two_cs() -> SpatialData: + return _make_sdata_two_cs() + + +# --------------------------------------------------------------------------- +# 1. Public path +# --------------------------------------------------------------------------- + + +def test_public_callables_exist() -> None: + import squidpy as sq + + assert callable(sq.experimental.tl.align_obs) + assert callable(sq.experimental.tl.align_images) + assert callable(sq.experimental.tl.align_by_landmarks) + + +# --------------------------------------------------------------------------- +# 2. Lazy-import hygiene +# --------------------------------------------------------------------------- + + +def test_optional_deps_not_imported_at_import_time() -> None: + """Subprocess-isolated import to defeat module caching from other tests. + + A pop+reimport in-process is unreliable here because other tests in the + same session may have already pulled stalign/moscot/jax in transitively. + Spawn a clean Python and inspect its sys.modules. + """ + import subprocess + import sys as _sys + + out = subprocess.check_output( + [ + _sys.executable, + "-c", + ( + "import sys, squidpy; " + "leaked = [m for m in ('jax', 'stalign', 'moscot') if m in sys.modules]; " + "print(','.join(leaked))" + ), + ], + text=True, + ).strip() + assert out == "", f"Optional deps imported by `import squidpy`: {out}" + + +# --------------------------------------------------------------------------- +# 3. Resolver matrix for align_obs +# --------------------------------------------------------------------------- + + +def test_resolve_obs_pair_two_anndata(adata_pair) -> None: + from squidpy.experimental.tl._align._io import resolve_obs_pair + + a, b = adata_pair + pair = resolve_obs_pair(a, b, None, None) + assert pair.ref is a + assert pair.query is b + assert pair.ref_container is None + assert pair.query_container is None + + +def test_resolve_obs_pair_two_anndata_rejects_unneeded_name(adata_pair) -> None: + from squidpy.experimental.tl._align._io import resolve_obs_pair + + a, b = adata_pair + with pytest.raises(ValueError, match="adata_ref_name"): + resolve_obs_pair(a, b, "tbl_ref", None) + with pytest.raises(ValueError, match="adata_query_name"): + resolve_obs_pair(a, b, None, "tbl_query") + + +def test_resolve_obs_pair_two_sdata(sdata_pair) -> None: + from squidpy.experimental.tl._align._io import resolve_obs_pair + + sa, sb = sdata_pair + pair = resolve_obs_pair(sa, sb, "tbl_ref", "tbl_query") + assert pair.ref_container is sa + assert pair.query_container is sb + assert pair.ref_element_key == "tbl_ref" + assert pair.query_element_key == "tbl_query" + assert isinstance(pair.ref, AnnData) + assert isinstance(pair.query, AnnData) + + +def test_resolve_obs_pair_two_sdata_requires_names(sdata_pair) -> None: + from squidpy.experimental.tl._align._io import resolve_obs_pair + + sa, sb = sdata_pair + with pytest.raises(ValueError, match="adata_ref_name"): + resolve_obs_pair(sa, sb, None, "tbl_query") + with pytest.raises(ValueError, match="adata_query_name"): + resolve_obs_pair(sa, sb, "tbl_ref", None) + + +def test_resolve_obs_pair_single_sdata(sdata_single) -> None: + from squidpy.experimental.tl._align._io import resolve_obs_pair + + pair = resolve_obs_pair(sdata_single, None, "tbl_ref", "tbl_query") + assert pair.ref_container is sdata_single + assert pair.query_container is sdata_single + assert pair.ref_element_key == "tbl_ref" + assert pair.query_element_key == "tbl_query" + + +def test_resolve_obs_pair_single_sdata_same_name_passes_through(sdata_single) -> None: + """Same-name within one sdata is a valid no-op-ish call (identity fit). + + The resolver is not in the business of semantic-uniqueness validation; + backends are free to treat identical inputs however they like. + """ + from squidpy.experimental.tl._align._io import resolve_obs_pair + + pair = resolve_obs_pair(sdata_single, None, "tbl_ref", "tbl_ref") + assert pair.ref_element_key == "tbl_ref" + assert pair.query_element_key == "tbl_ref" + + +def test_resolve_obs_pair_mixed_inputs_rejected(adata_pair, sdata_single) -> None: + from squidpy.experimental.tl._align._io import resolve_obs_pair + + a, _ = adata_pair + with pytest.raises(TypeError, match="Mixed AnnData/SpatialData"): + resolve_obs_pair(a, sdata_single, None, "tbl_ref") + with pytest.raises(TypeError, match="Mixed AnnData/SpatialData"): + resolve_obs_pair(sdata_single, a, "tbl_ref", None) + + +def test_resolve_obs_pair_anndata_without_query(adata_pair) -> None: + from squidpy.experimental.tl._align._io import resolve_obs_pair + + a, _ = adata_pair + with pytest.raises(ValueError, match="`data_query` is required"): + resolve_obs_pair(a, None, None, None) + + +# --------------------------------------------------------------------------- +# 4. Multiscale image resolution +# --------------------------------------------------------------------------- + + +def test_resolve_image_pair_single_and_multiscale() -> None: + from squidpy.experimental.tl._align._io import resolve_image_pair + + # Single-scale image + arr = np.zeros((3, 32, 32), dtype=np.uint8) + xa = xr.DataArray(arr, dims=["c", "y", "x"], coords={"c": ["R", "G", "B"]}) + single = Image2DModel.parse(xa) + + # Multiscale image + multi = Image2DModel.parse(xa, scale_factors=[2]) + + sdata = SpatialData(images={"single": single, "multi": multi}) + + pair = resolve_image_pair(sdata, None, "single", "multi") + assert isinstance(pair.ref, xr.DataArray) + assert isinstance(pair.query, xr.DataArray) + assert pair.ref_element_key == "single" + assert pair.query_element_key == "multi" + + +def test_resolve_image_pair_same_name_passes_through() -> None: + """Same image name in a single sdata is a valid no-op call.""" + from squidpy.experimental.tl._align._io import resolve_image_pair + + arr = np.zeros((3, 16, 16), dtype=np.uint8) + xa = xr.DataArray(arr, dims=["c", "y", "x"], coords={"c": ["R", "G", "B"]}) + sdata = SpatialData(images={"img": Image2DModel.parse(xa)}) + + pair = resolve_image_pair(sdata, None, "img", "img") + assert pair.ref_element_key == "img" + assert pair.query_element_key == "img" + + +# --------------------------------------------------------------------------- +# 5. Landmark validation +# --------------------------------------------------------------------------- + + +def test_validate_landmarks_unequal_length() -> None: + from squidpy.experimental.tl._align._validation import validate_landmarks + + with pytest.raises(ValueError, match="same length"): + validate_landmarks(((0, 0), (1, 1)), ((0, 0),), model="similarity") + + +def test_validate_landmarks_requires_three_points() -> None: + """spatialdata's get_transformation_between_landmarks requires n>=3.""" + from squidpy.experimental.tl._align._validation import validate_landmarks + + with pytest.raises(ValueError, match="at least 3"): + validate_landmarks(((0, 0), (1, 1)), ((0, 0), (1, 1)), model="similarity") + + +def test_validate_landmarks_outside_extent() -> None: + from squidpy.experimental.tl._align._validation import validate_landmarks + + with pytest.raises(ValueError, match="outside the coordinate-system extent"): + validate_landmarks( + ((0, 0), (50, 50), (3, 3)), + ((0, 0), (5, 5), (1, 1)), + model="similarity", + cs_ref_extent=(0, 0, 10, 10), + ) + + +def test_validate_landmarks_happy_path() -> None: + from squidpy.experimental.tl._align._validation import validate_landmarks + + ref, query = validate_landmarks( + ((0, 0), (10, 0), (0, 10)), + ((1, 1), (11, 1), (1, 11)), + model="affine", + cs_ref_extent=(0, 0, 100, 100), + cs_query_extent=(0, 0, 100, 100), + ) + assert ref.shape == (3, 2) + assert query.shape == (3, 2) + + +# --------------------------------------------------------------------------- +# 6. Output-mode guards +# --------------------------------------------------------------------------- + + +def test_align_images_rejects_output_mode_obs(sdata_single) -> None: + import squidpy as sq + + with pytest.raises(ValueError, match="output_mode"): + sq.experimental.tl.align_images( + sdata_single, + None, + img_ref_name="img_ref", + img_query_name="img_query", + output_mode="obs", # type: ignore[arg-type] + ) + + +def test_key_added_only_with_obs_mode(sdata_single) -> None: + import squidpy as sq + + with pytest.raises(ValueError, match="key_added"): + sq.experimental.tl.align_obs( + sdata_single, + None, + adata_ref_name="tbl_ref", + adata_query_name="tbl_query", + output_mode="affine", + key_added="aligned", + ) + + +# --------------------------------------------------------------------------- +# 7. Dispatch +# --------------------------------------------------------------------------- + + +def test_align_obs_stalign_image_path_raises(sdata_single) -> None: + """``align_images(flavour='stalign')`` is still NotImplementedError; the + PR-#1150 lift only ships point alignment. This pins the contract that + the dispatch reaches the backend cleanly (no ImportError/AttributeError).""" + import squidpy as sq + + pytest.importorskip("jax") + with pytest.raises(NotImplementedError, match="stalign image alignment"): + sq.experimental.tl.align_images( + sdata_single, + None, + img_ref_name="img_ref", + img_query_name="img_query", + flavour="stalign", + ) + + +def test_align_obs_moscot_dispatch_reaches_backend(sdata_single) -> None: + import squidpy as sq + + pytest.importorskip("jax") + with pytest.raises(NotImplementedError, match="moscot backend"): + sq.experimental.tl.align_obs( + sdata_single, + None, + adata_ref_name="tbl_ref", + adata_query_name="tbl_query", + flavour="moscot", + ) + + +def test_align_obs_unknown_flavour(sdata_single) -> None: + import squidpy as sq + + with pytest.raises(ValueError, match="flavour"): + sq.experimental.tl.align_obs( + sdata_single, + None, + adata_ref_name="tbl_ref", + adata_query_name="tbl_query", + flavour="bogus", # type: ignore[arg-type] + ) + + +def test_align_images_rejects_moscot(sdata_single) -> None: + import squidpy as sq + + with pytest.raises(ValueError, match="flavour"): + sq.experimental.tl.align_images( + sdata_single, + None, + img_ref_name="img_ref", + img_query_name="img_query", + flavour="moscot", # type: ignore[arg-type] + ) + + +# --------------------------------------------------------------------------- +# 8. align_by_landmarks is JAX-free +# --------------------------------------------------------------------------- + + +def test_align_by_landmarks_two_cs_in_same_sdata(monkeypatch, sdata_two_cs) -> None: + """Align two distinct coordinate systems inside a single SpatialData via + the closed-form spatialdata fit. Also pins that the landmark path never + touches JAX: with ``jax`` blocked the call must still succeed, because + the landmark backend is pure NumPy/spatialdata. + """ + from spatialdata.transformations import Affine, get_transformation + + import squidpy as sq + + monkeypatch.setitem(sys.modules, "jax", None) + + # Three corresponding landmarks: identity translation by (+1, +2). + landmarks_ref = ((0.0, 0.0), (10.0, 0.0), (0.0, 10.0)) + landmarks_query = ((1.0, 2.0), (11.0, 2.0), (1.0, 12.0)) + + sq.experimental.tl.align_by_landmarks( + sdata_two_cs, + None, + cs_name_ref="cs_a", + cs_name_query="cs_b", + landmarks_ref=landmarks_ref, + landmarks_query=landmarks_query, + model="similarity", + ) + + # The fit should now have attached an affine on `img_b` mapping cs_b -> cs_a. + img_b = sdata_two_cs.images["img_b"] + transforms = get_transformation(img_b, get_all=True) + assert "cs_a" in transforms, f"alignment didn't register a cs_a transform on img_b; have {list(transforms)}" + assert isinstance(transforms["cs_a"], Affine) + + +def test_align_by_landmarks_affine_model(sdata_two_cs) -> None: + """The 6-DOF affine model fits via skimage and registers a transform.""" + from spatialdata.transformations import Affine, get_transformation + + import squidpy as sq + + # 4 landmarks (>3, since affine has 6 DOF and skimage wants over-determined input). + landmarks_ref = ((0.0, 0.0), (10.0, 0.0), (0.0, 10.0), (10.0, 10.0)) + landmarks_query = ((1.0, 2.0), (11.0, 2.0), (1.0, 12.0), (11.0, 12.0)) + + sq.experimental.tl.align_by_landmarks( + sdata_two_cs, + None, + cs_name_ref="cs_a", + cs_name_query="cs_b", + landmarks_ref=landmarks_ref, + landmarks_query=landmarks_query, + model="affine", + ) + + img_b = sdata_two_cs.images["img_b"] + transforms = get_transformation(img_b, get_all=True) + assert "cs_a" in transforms + assert isinstance(transforms["cs_a"], Affine) + + +# --------------------------------------------------------------------------- +# 9. JAX-required flavours fail cleanly without JAX +# --------------------------------------------------------------------------- + + +def test_stalign_without_jax_raises_importerror(monkeypatch, sdata_single) -> None: + import squidpy as sq + + # Block the import: the lazy `import jax` inside _jax.require_jax will hit None. + monkeypatch.setitem(sys.modules, "jax", None) + + with pytest.raises(ImportError, match="JAX is required"): + sq.experimental.tl.align_obs( + sdata_single, + None, + adata_ref_name="tbl_ref", + adata_query_name="tbl_query", + flavour="stalign", + ) diff --git a/tests/experimental/tl/test_align_stalign_integration.py b/tests/experimental/tl/test_align_stalign_integration.py new file mode 100644 index 000000000..40e4dc0fa --- /dev/null +++ b/tests/experimental/tl/test_align_stalign_integration.py @@ -0,0 +1,132 @@ +"""Happy-path integration tests for the stalign backend. + +These exercise the lift from scverse/squidpy#1150 through the +``align_obs`` API and the ``output_mode`` writeback paths. Tiny synthetic +fixtures with ``niter=1`` keep them fast enough to run on every commit; they +verify wiring and shapes only, **not** solver quality. Numeric / visual +end-to-end verification on a rotation-recovery fixture is a separate +follow-up (see the plan file). +""" + +from __future__ import annotations + +import numpy as np +import pytest +from anndata import AnnData + +pytest.importorskip("jax") + + +def _make_xy_adata() -> AnnData: + """Five-point synthetic AnnData with an ``obsm['spatial']`` cloud.""" + points_xy = np.array( + [ + [10.0, 1.0], + [12.0, 1.0], + [11.0, 2.0], + [10.0, 3.0], + [12.0, 3.0], + ] + ) + adata = AnnData(np.zeros((points_xy.shape[0], 1))) + adata.obsm["spatial"] = points_xy + return adata + + +def _tiny_config(): + """Single-iteration LDDMM hyperparameters - smallest possible solve.""" + import squidpy as sq + + return sq.experimental.tl.STalignConfig( + preprocess=sq.experimental.tl.STalignPreprocessConfig(dx=0.5, blur=1.0), + registration=sq.experimental.tl.STalignRegistrationConfig( + a=1.0, + expand=1.0, + nt=1, + niter=1, + epV=1.0, + ), + ) + + +def test_align_obs_stalign_return_mode_yields_obs_displacement() -> None: + """Wiring smoke test: dispatch -> stalign LDDMM -> ObsDisplacement.""" + import squidpy as sq + from squidpy.experimental.tl._align._types import AlignResult, ObsDisplacement + + ref = _make_xy_adata() + query = _make_xy_adata() + result = sq.experimental.tl.align_obs( + ref, + query, + flavour="stalign", + output_mode="return", + config=_tiny_config(), + ) + + assert isinstance(result, AlignResult) + assert isinstance(result.transform, ObsDisplacement) + assert result.transform.deltas.shape == query.obsm["spatial"].shape + assert np.all(np.isfinite(result.transform.deltas)) + assert result.metadata["flavour"] == "stalign" + # The escape hatch: the full STalignResult is preserved for power users. + assert "stalign_result" in result.metadata + + +def test_align_obs_stalign_obs_mode_writes_new_anndata() -> None: + """``output_mode='obs'`` materialises a new AnnData in the ref cs.""" + import squidpy as sq + + ref = _make_xy_adata() + query = _make_xy_adata() + aligned = sq.experimental.tl.align_obs( + ref, + query, + flavour="stalign", + output_mode="obs", + inplace=False, + config=_tiny_config(), + ) + + assert isinstance(aligned, AnnData) + assert aligned.obsm["spatial"].shape == query.obsm["spatial"].shape + # The writer stamps `align` metadata on uns so callers can introspect. + assert "align" in aligned.uns + assert aligned.uns["align"]["flavour"] == "stalign" + + +def test_align_obs_stalign_affine_mode_errors_for_non_affine_fit() -> None: + """LDDMM is non-affine; ``output_mode='affine'`` must error helpfully.""" + import squidpy as sq + + ref = _make_xy_adata() + query = _make_xy_adata() + with pytest.raises(TypeError, match="requires the backend to return an AffineTransform"): + sq.experimental.tl.align_obs( + ref, + query, + flavour="stalign", + output_mode="affine", + config=_tiny_config(), + ) + + +def test_align_obs_stalign_with_landmarks() -> None: + """Landmark-guided affine init reaches the solver via flavour_kwargs.""" + import squidpy as sq + + ref = _make_xy_adata() + query = _make_xy_adata() + landmarks_xy = ref.obsm["spatial"][:3] + + result = sq.experimental.tl.align_obs( + ref, + query, + flavour="stalign", + output_mode="return", + config=_tiny_config(), + landmarks_source=landmarks_xy, + landmarks_target=landmarks_xy, + ) + assert "stalign_result" in result.metadata + assert result.transform.deltas.shape == query.obsm["spatial"].shape From 837c25801dfb88643012adfea2af095cfb690f19 Mon Sep 17 00:00:00 2001 From: Tim Treis Date: Thu, 16 Apr 2026 12:42:25 +0200 Subject: [PATCH 2/6] Address review findings on alignment skeleton - Replace private `sdata._gen_elements()` with public `gen_elements()` - Replace dict-style `sdata[key]` lookup with explicit element-type search - Add subprocess timeout (30s) to lazy-import hygiene test - Document shallow X sharing in `materialise_obs` docstring - Document JAX array retention in stalign metadata comment - Document camelCase convention in STalignRegistrationConfig docstring - Broaden landmark type hints to accept Sequence and np.ndarray - Remove stale TODO comment from _stalign_helpers.py Co-Authored-By: Claude Opus 4.6 (1M context) --- src/squidpy/experimental/tl/_align/_api.py | 7 ++++-- .../tl/_align/_backends/_stalign.py | 8 +++--- .../tl/_align/_backends/_stalign_helpers.py | 4 --- .../tl/_align/_backends/_stalign_tools.py | 7 ++++++ src/squidpy/experimental/tl/_align/_io.py | 25 ++++++++++++++++--- tests/experimental/tl/test_align_skeleton.py | 1 + 6 files changed, 39 insertions(+), 13 deletions(-) diff --git a/src/squidpy/experimental/tl/_align/_api.py b/src/squidpy/experimental/tl/_align/_api.py index 4b08b7c69..d666a4a39 100644 --- a/src/squidpy/experimental/tl/_align/_api.py +++ b/src/squidpy/experimental/tl/_align/_api.py @@ -8,8 +8,11 @@ from __future__ import annotations +from collections.abc import Sequence from typing import TYPE_CHECKING, Any, Literal +import numpy as np + from squidpy.experimental.tl._align._backends import get_backend from squidpy.experimental.tl._align._io import ( apply_affine_to_cs, @@ -160,8 +163,8 @@ def align_by_landmarks( cs_name_query: str | None = None, scale_ref: str | None = None, scale_query: str | None = None, - landmarks_ref: tuple[tuple[float, float], ...] | None = None, - landmarks_query: tuple[tuple[float, float], ...] | None = None, + landmarks_ref: Sequence[tuple[float, float]] | np.ndarray | None = None, + landmarks_query: Sequence[tuple[float, float]] | np.ndarray | None = None, *, model: Literal["similarity", "affine"] = "similarity", output_mode: Literal["affine", "return"] = "affine", diff --git a/src/squidpy/experimental/tl/_align/_backends/_stalign.py b/src/squidpy/experimental/tl/_align/_backends/_stalign.py index 6d7c25d8c..1115c48d2 100644 --- a/src/squidpy/experimental/tl/_align/_backends/_stalign.py +++ b/src/squidpy/experimental/tl/_align/_backends/_stalign.py @@ -82,9 +82,11 @@ def align_obs( ), metadata={ "flavour": "stalign", - # Escape hatch for power users who want the diffeomorphic - # part (velocity field, velocity grid, affine init) rather - # than just the materialised displacement. + # Escape hatch: the full STalignResult (velocity field, + # velocity grid, affine init) for power users who need + # the diffeomorphic map. This keeps the JAX arrays alive + # in memory -- callers who only need the displacement + # should drop this key or use ``output_mode='obs'``. "stalign_result": stalign_result, }, ) diff --git a/src/squidpy/experimental/tl/_align/_backends/_stalign_helpers.py b/src/squidpy/experimental/tl/_align/_backends/_stalign_helpers.py index 7e5e5099e..e85212fda 100644 --- a/src/squidpy/experimental/tl/_align/_backends/_stalign_helpers.py +++ b/src/squidpy/experimental/tl/_align/_backends/_stalign_helpers.py @@ -56,10 +56,6 @@ def extract_landmarks(adata: AnnData, key: str) -> np.ndarray: raise KeyError(f"Key `{key}` not found in `adata.obsm` or `adata.uns`.") -# TODO: are these duplicated? I would imagine its -# better to keep image transform functions under some place - - def to_row_col(points: np.ndarray, *, point_order: PointOrder) -> np.ndarray: """Convert coordinates to row-column order.""" arr = _validate_points(points, name="points") diff --git a/src/squidpy/experimental/tl/_align/_backends/_stalign_tools.py b/src/squidpy/experimental/tl/_align/_backends/_stalign_tools.py index c2ba666d6..84365f445 100644 --- a/src/squidpy/experimental/tl/_align/_backends/_stalign_tools.py +++ b/src/squidpy/experimental/tl/_align/_backends/_stalign_tools.py @@ -54,6 +54,13 @@ class STalignPreprocessConfig: @dataclass(slots=True) class STalignRegistrationConfig: + """LDDMM registration hyperparameters. + + Field names (``sigmaM``, ``epL``, etc.) preserve the conventions from + the STalign paper and reference implementation to keep them + recognisable when cross-referencing the literature. + """ + a: float = 500.0 p: float = 2.0 expand: float = 2.0 diff --git a/src/squidpy/experimental/tl/_align/_io.py b/src/squidpy/experimental/tl/_align/_io.py index 3a98d2e98..dd6c83a92 100644 --- a/src/squidpy/experimental/tl/_align/_io.py +++ b/src/squidpy/experimental/tl/_align/_io.py @@ -213,7 +213,7 @@ def apply_affine_to_cs( if pair.query_container is not None and pair.query_element_key is not None: sdata = pair.query_container if inplace else _shallow_copy_sdata(pair.query_container) - element = sdata[pair.query_element_key] + element = _get_element(sdata, pair.query_element_key) set_transformation(element, affine.to_spatialdata(), to_coordinate_system=target_cs) return None if inplace else sdata @@ -222,7 +222,9 @@ def apply_affine_to_cs( moving_cs = pair.query_cs sd_affine = affine.to_spatialdata() touched_any = False - for _etype, _name, element in sdata._gen_elements(include_tables=False): + for _etype, _name, element in sdata.gen_elements(): + if isinstance(element, AnnData): + continue element_transforms = get_transformation(element, get_all=True) if moving_cs not in element_transforms: continue @@ -258,6 +260,13 @@ def materialise_obs( add the deltas. When the source query lives inside a SpatialData, the new AnnData is registered as ``sdata.tables[key_added]``; otherwise it is returned directly. + + .. note:: + + The returned AnnData **shares** ``X`` and ``var`` with the source + query by reference to avoid copying potentially-large expression + matrices. Mutating one will affect the other. Call + ``.copy()`` on the result if you need full independence. """ if not isinstance(pair.query, AnnData): raise TypeError("materialise_obs only works for `align_obs`; `pair.query` must be an AnnData.") @@ -297,6 +306,15 @@ def materialise_obs( return new_adata +def _get_element(sdata: SpatialData, key: str) -> object: + """Look up a spatial element by name across all element types.""" + for attr in ("images", "labels", "points", "shapes", "tables"): + store = getattr(sdata, attr) + if key in store: + return store[key] + raise KeyError(f"Element {key!r} not found in the SpatialData object.") + + def _shallow_copy_sdata(sdata: SpatialData) -> SpatialData: """Shallow copy of a SpatialData object for ``inplace=False`` writeback paths. @@ -304,6 +322,5 @@ def _shallow_copy_sdata(sdata: SpatialData) -> SpatialData: ``attrs`` propagate the same way spatialdata's own subsetting handles them, rather than reconstructing via the ``__init__`` constructor. """ - element_names = [name for _, name, _ in sdata._gen_elements(include_tables=True)] + element_names = [name for _, name, _ in sdata.gen_elements()] return sdata.subset(element_names, filter_tables=False, include_orphan_tables=True) - diff --git a/tests/experimental/tl/test_align_skeleton.py b/tests/experimental/tl/test_align_skeleton.py index 620c1151d..3aa9a7808 100644 --- a/tests/experimental/tl/test_align_skeleton.py +++ b/tests/experimental/tl/test_align_skeleton.py @@ -132,6 +132,7 @@ def test_optional_deps_not_imported_at_import_time() -> None: ), ], text=True, + timeout=30, ).strip() assert out == "", f"Optional deps imported by `import squidpy`: {out}" From a320ad14eead6ba1ecdd4a540e54d66613d09bcb Mon Sep 17 00:00:00 2001 From: Tim Treis Date: Thu, 16 Apr 2026 12:52:02 +0200 Subject: [PATCH 3/6] Simplify alignment skeleton after review - Resolve JAX_DTYPE lazily via jax_dtype() to respect runtime x64 config - Replace 14-line config field unpack with dataclasses.asdict(registration) - Remove unreachable ValueError in _writeback (already validated upstream) - Remove task-tracking comment from moscot stub - Clean up PR-line-number reference in stalign comment Co-Authored-By: Claude Opus 4.6 (1M context) --- src/squidpy/experimental/tl/_align/_api.py | 7 ++-- .../tl/_align/_backends/_moscot.py | 2 -- .../tl/_align/_backends/_stalign.py | 2 +- .../tl/_align/_backends/_stalign_core.py | 26 ++++++++------ .../tl/_align/_backends/_stalign_tools.py | 35 ++++++------------- 5 files changed, 30 insertions(+), 42 deletions(-) diff --git a/src/squidpy/experimental/tl/_align/_api.py b/src/squidpy/experimental/tl/_align/_api.py index d666a4a39..b898786a3 100644 --- a/src/squidpy/experimental/tl/_align/_api.py +++ b/src/squidpy/experimental/tl/_align/_api.py @@ -248,7 +248,6 @@ def _writeback( ) return apply_affine_to_cs(pair, result.transform, inplace=inplace) - if output_mode == "obs": - return materialise_obs(pair, result, key_added=key_added, inplace=inplace) - - raise ValueError(f"Unknown output_mode {output_mode!r}.") + # output_mode == "obs" -- the only remaining valid branch after + # validate_output_mode has already rejected unknown values. + return materialise_obs(pair, result, key_added=key_added, inplace=inplace) diff --git a/src/squidpy/experimental/tl/_align/_backends/_moscot.py b/src/squidpy/experimental/tl/_align/_backends/_moscot.py index b5fe01f64..6b5e98e5e 100644 --- a/src/squidpy/experimental/tl/_align/_backends/_moscot.py +++ b/src/squidpy/experimental/tl/_align/_backends/_moscot.py @@ -28,8 +28,6 @@ def align_obs( from squidpy.experimental.tl._align._jax import require_jax require_jax(device) - # Lazy moscot import lives here in the next PR: - # import moscot raise NotImplementedError( "moscot backend `align_obs`: TODO. Skeleton landed; the moscot " "solver will replace this body in a follow-up PR." diff --git a/src/squidpy/experimental/tl/_align/_backends/_stalign.py b/src/squidpy/experimental/tl/_align/_backends/_stalign.py index 1115c48d2..71605ec52 100644 --- a/src/squidpy/experimental/tl/_align/_backends/_stalign.py +++ b/src/squidpy/experimental/tl/_align/_backends/_stalign.py @@ -56,7 +56,7 @@ def align_obs( tgt_xy = np.asarray(pair.ref.obsm["spatial"], dtype=float) # stalign_points runs internally in row_col (yx); obsm["spatial"] is xy - # by squidpy convention. Mirror Selman's _stalign.py:69-70 / :90 swap. + # by squidpy convention -- swap axes at the boundary. src_rc = src_xy[:, [1, 0]] tgt_rc = tgt_xy[:, [1, 0]] landmarks_src_rc = None if landmarks_source is None else np.asarray(landmarks_source)[:, [1, 0]] diff --git a/src/squidpy/experimental/tl/_align/_backends/_stalign_core.py b/src/squidpy/experimental/tl/_align/_backends/_stalign_core.py index 65e8ce72c..8ac139b6c 100644 --- a/src/squidpy/experimental/tl/_align/_backends/_stalign_core.py +++ b/src/squidpy/experimental/tl/_align/_backends/_stalign_core.py @@ -12,8 +12,12 @@ import jax.scipy as jsp import numpy as np -JAX_DTYPE = jnp.float64 if jax.config.x64_enabled else jnp.float32 -__all__ = ["JAX_DTYPE", "lddmm", "transform_points_row_col"] +__all__ = ["jax_dtype", "lddmm", "transform_points_row_col"] + + +def jax_dtype() -> jnp.dtype: + """Resolve the active JAX float dtype at call time, not import time.""" + return jnp.float64 if jax.config.x64_enabled else jnp.float32 def _to_affine(linear: Any, translation: Any) -> Any: @@ -279,20 +283,20 @@ def lddmm( ) -> dict[str, Any]: x_source = (jnp.asarray(xI[0]), jnp.asarray(xI[1])) x_target = (jnp.asarray(xJ[0]), jnp.asarray(xJ[1])) - source_image = jnp.asarray(I, dtype=JAX_DTYPE) - target_image = jnp.asarray(J, dtype=JAX_DTYPE) - linear = jnp.asarray(L, dtype=JAX_DTYPE) - translation = jnp.asarray(T, dtype=JAX_DTYPE) + source_image = jnp.asarray(I, dtype=jax_dtype()) + target_image = jnp.asarray(J, dtype=jax_dtype()) + linear = jnp.asarray(L, dtype=jax_dtype()) + translation = jnp.asarray(T, dtype=jax_dtype()) if points_source is None: - source_landmarks = jnp.zeros((0, 2), dtype=JAX_DTYPE) - target_landmarks = jnp.zeros((0, 2), dtype=JAX_DTYPE) + source_landmarks = jnp.zeros((0, 2), dtype=jax_dtype()) + target_landmarks = jnp.zeros((0, 2), dtype=jax_dtype()) else: - source_landmarks = jnp.asarray(points_source, dtype=JAX_DTYPE) - target_landmarks = jnp.asarray(points_target, dtype=JAX_DTYPE) + source_landmarks = jnp.asarray(points_source, dtype=jax_dtype()) + target_landmarks = jnp.asarray(points_target, dtype=jax_dtype()) xv = _build_velocity_grid(x_source, a=a, expand=expand) - velocity = jnp.zeros((nt, xv[0].shape[0], xv[1].shape[0], 2), dtype=JAX_DTYPE) + velocity = jnp.zeros((nt, xv[0].shape[0], xv[1].shape[0], 2), dtype=jax_dtype()) kernel, ll, dv_prod = _build_regularizer(xv, a=a, p=p) match_weights = jnp.full(target_image.shape[1:], 0.5, dtype=target_image.dtype) diff --git a/src/squidpy/experimental/tl/_align/_backends/_stalign_tools.py b/src/squidpy/experimental/tl/_align/_backends/_stalign_tools.py index 84365f445..2098466ab 100644 --- a/src/squidpy/experimental/tl/_align/_backends/_stalign_tools.py +++ b/src/squidpy/experimental/tl/_align/_backends/_stalign_tools.py @@ -1,20 +1,19 @@ """Low-level point-cloud tools for experimental STalign. -Lifted from scverse/squidpy#1150 (Selman Özleyen). Only the two import paths -below were rewritten to point at the sibling lifted modules; the rest of the -file is byte-for-byte identical to the upstream PR. +Lifted from scverse/squidpy#1150 (Selman Özleyen) with import paths +adjusted and minor cleanups (config unpacking, lazy dtype resolution). """ from __future__ import annotations -from dataclasses import dataclass, field +from dataclasses import asdict, dataclass, field from typing import TYPE_CHECKING, Any, Literal, TypeAlias import jax.numpy as jnp import numpy as np from anndata import AnnData -from squidpy.experimental.tl._align._backends._stalign_core import JAX_DTYPE, lddmm, transform_points_row_col +from squidpy.experimental.tl._align._backends._stalign_core import jax_dtype, lddmm, transform_points_row_col from squidpy.experimental.tl._align._backends._stalign_helpers import ( PointOrder, affine_from_points, @@ -191,7 +190,7 @@ def transform_points( xv, jnp.asarray(v), jnp.asarray(A), - jnp.asarray(points_rc, dtype=JAX_DTYPE), + jnp.asarray(points_rc, dtype=jax_dtype()), direction=direction, ) return jnp.asarray(from_row_col(np.asarray(transformed), point_order=point_order)) @@ -227,29 +226,17 @@ def stalign_points( target_landmarks = to_row_col(landmarks_target, point_order="row_col") linear, translation = affine_from_points(source_landmarks, target_landmarks) + dtype = jax_dtype() result = lddmm( preprocessed.source_grid, preprocessed.source_image, preprocessed.target_grid, preprocessed.target_image, - L=jnp.asarray(linear, dtype=JAX_DTYPE), - T=jnp.asarray(translation, dtype=JAX_DTYPE), - points_source=None if source_landmarks is None else jnp.asarray(source_landmarks, dtype=JAX_DTYPE), - points_target=None if target_landmarks is None else jnp.asarray(target_landmarks, dtype=JAX_DTYPE), - a=registration.a, - p=registration.p, - expand=registration.expand, - nt=registration.nt, - niter=registration.niter, - diffeo_start=registration.diffeo_start, - epL=registration.epL, - epT=registration.epT, - epV=registration.epV, - sigmaM=registration.sigmaM, - sigmaB=registration.sigmaB, - sigmaA=registration.sigmaA, - sigmaR=registration.sigmaR, - sigmaP=registration.sigmaP, + L=jnp.asarray(linear, dtype=dtype), + T=jnp.asarray(translation, dtype=dtype), + points_source=None if source_landmarks is None else jnp.asarray(source_landmarks, dtype=dtype), + points_target=None if target_landmarks is None else jnp.asarray(target_landmarks, dtype=dtype), + **asdict(registration), ) aligned_points = transform_points( result["xv"], From bc6720d066615588d413beb7f9e631f5378f0c6e Mon Sep 17 00:00:00 2001 From: Tim Treis Date: Thu, 16 Apr 2026 12:57:14 +0200 Subject: [PATCH 4/6] Add JAX test environment to CI matrix Adds a hatch-test.py3.13-stable-jax environment that installs the [jax] optional extra so the STalign solver and e2e alignment tests run in CI. Excluded from macOS to avoid doubling runner cost. Co-Authored-By: Claude Opus 4.6 (1M context) --- .github/workflows/test.yaml | 2 ++ hatch.toml | 8 ++++++++ 2 files changed, 10 insertions(+) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index c75993038..386e85aa3 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -104,6 +104,8 @@ jobs: env: { name: hatch-test.py3.12-stable } - os: macos-latest env: { name: hatch-test.py3.13-pre } # pre-release only needs one OS + - os: macos-latest + env: { name: hatch-test.py3.13-stable-jax } # JAX only needs one OS - os: ubuntu-latest env: { name: hatch-test.py3.13-stable } # skipping because we run this as a coverage job name: ${{ matrix.env.label }} (${{ matrix.os }}) diff --git a/hatch.toml b/hatch.toml index 08afad3b9..b870ce00d 100644 --- a/hatch.toml +++ b/hatch.toml @@ -39,6 +39,11 @@ download = "python ./.scripts/ci/download_data.py {args}" deps = ["stable"] python = ["3.11", "3.12", "3.13"] +[[envs.hatch-test.matrix]] +deps = ["stable"] +python = ["3.13"] +extras = ["jax"] + [[envs.hatch-test.matrix]] deps = ["pre"] python = ["3.13"] @@ -47,6 +52,9 @@ python = ["3.13"] matrix.deps.env-vars = [ { key = "UV_PRERELEASE", value = "allow", if = ["pre"] }, ] +matrix.extras.features = [ + { value = "jax", if = ["jax"] }, +] [envs.notebooks] extra-dependencies = [ From 93ed71ceb05b81ad01b4369e66f0c603c9077f45 Mon Sep 17 00:00:00 2001 From: Tim Treis Date: Thu, 16 Apr 2026 14:13:26 +0200 Subject: [PATCH 5/6] Improve alignment API defaults and public surface - align_obs: default output_mode="obs" (was "affine", which crashed with the default stalign backend). Auto-generate key_added from query name when not provided for SpatialData inputs. - align_by_landmarks: make cs_name_ref/query and landmarks_ref/query keyword-only required args (were Optional with None defaults that immediately errored). Remove unused scale_ref/scale_query params. Wire get_extent validation for landmark bounds checking against cs extent. Fix docstring (y, x) -> (x, y). - align_images: make img_ref/query_name keyword-only required. Remove from public __all__ (no backend implements it yet). - align_obs docstring: note that inplace only affects SpatialData. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/squidpy/experimental/tl/__init__.py | 3 +- .../experimental/tl/_align/__init__.py | 10 +- src/squidpy/experimental/tl/_align/_api.py | 93 ++++++++++++------- tests/experimental/tl/test_align_skeleton.py | 15 +-- 4 files changed, 70 insertions(+), 51 deletions(-) diff --git a/src/squidpy/experimental/tl/__init__.py b/src/squidpy/experimental/tl/__init__.py index 2f6fbd0b5..a9bb596d9 100644 --- a/src/squidpy/experimental/tl/__init__.py +++ b/src/squidpy/experimental/tl/__init__.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any -from squidpy.experimental.tl._align import align_by_landmarks, align_images, align_obs +from squidpy.experimental.tl._align import align_by_landmarks, align_obs if TYPE_CHECKING: from squidpy.experimental.tl._align._backends._stalign_tools import ( @@ -20,7 +20,6 @@ "STalignRegistrationConfig", "STalignResult", "align_by_landmarks", - "align_images", "align_obs", ] diff --git a/src/squidpy/experimental/tl/_align/__init__.py b/src/squidpy/experimental/tl/_align/__init__.py index b070b99c5..51f9cc617 100644 --- a/src/squidpy/experimental/tl/_align/__init__.py +++ b/src/squidpy/experimental/tl/_align/__init__.py @@ -2,11 +2,10 @@ Public surface: -- :func:`align_obs` — align two ``obs``-level point clouds (cells / spots). -- :func:`align_images` — align two raster images in :class:`spatialdata.SpatialData`. -- :func:`align_by_landmarks` — closed-form fit from user-provided landmarks. +- :func:`align_obs` -- align two ``obs``-level point clouds (cells / spots). +- :func:`align_by_landmarks` -- closed-form fit from user-provided landmarks. -Optional backends (``stalign``, ``moscot``) and JAX are imported lazily — only +Optional backends (``stalign``, ``moscot``) and JAX are imported lazily -- only the function call that needs them pulls them in. """ @@ -14,8 +13,7 @@ from squidpy.experimental.tl._align._api import ( align_by_landmarks, - align_images, align_obs, ) -__all__ = ["align_by_landmarks", "align_images", "align_obs"] +__all__ = ["align_by_landmarks", "align_obs"] diff --git a/src/squidpy/experimental/tl/_align/_api.py b/src/squidpy/experimental/tl/_align/_api.py index b898786a3..5e3b555c4 100644 --- a/src/squidpy/experimental/tl/_align/_api.py +++ b/src/squidpy/experimental/tl/_align/_api.py @@ -32,7 +32,6 @@ validate_landmark_model, validate_landmarks, validate_output_mode, - validate_required, ) if TYPE_CHECKING: @@ -47,7 +46,7 @@ def align_obs( adata_query_name: str | None = None, flavour: Literal["stalign", "moscot"] = "stalign", *, - output_mode: Literal["affine", "obs", "return"] = "affine", + output_mode: Literal["affine", "obs", "return"] = "obs", key_added: str | None = None, device: Literal["cpu", "gpu"] | None = None, inplace: bool = True, @@ -71,23 +70,26 @@ def align_obs( output_mode How to deliver the result: - - ``'affine'`` — register the fitted affine on the query element via + - ``'obs'`` (default) -- bake the fit into a new AnnData whose + ``obsm['spatial']`` already lives in the reference coordinate + system. For SpatialData inputs the new table is stored under + ``key_added`` (auto-generated from the query name when omitted). + - ``'affine'`` -- register the fitted affine on the query element via :func:`spatialdata.transformations.set_transformation`, so every element in the query coordinate system inherits the alignment. Requires the backend to produce an affine transform. - - ``'obs'`` — bake the (possibly non-affine) fit into a new AnnData - whose ``obsm['spatial']`` already lives in the reference coordinate - system; for SpatialData inputs the new table is stored under - ``key_added``. - - ``'return'`` — return the raw :class:`AlignResult`; no writeback. + - ``'return'`` -- return the raw :class:`AlignResult`; no writeback. key_added - Required when ``output_mode='obs'`` and inputs are SpatialData. + Name for the aligned table when ``output_mode='obs'`` and inputs are + SpatialData. Defaults to ``'{adata_query_name}_aligned'``. Rejected with any other ``output_mode``. device ``'cpu'``/``'gpu'`` to force a JAX device, or ``None`` to let JAX pick the default. Only consulted by JAX-backed flavours. inplace If ``True``, mutate the query container; otherwise return a copy. + Only affects SpatialData inputs -- for plain AnnData with + ``output_mode='obs'``, the aligned AnnData is always returned. **flavour_kwargs Backend-specific knobs forwarded as-is to the chosen backend. """ @@ -99,16 +101,20 @@ def align_obs( backend = get_backend(flavour) result = backend.align_obs(pair, device=device, **flavour_kwargs) + # Auto-generate key_added for SpatialData obs writeback. + if key_added is None and output_mode == "obs" and pair.query_element_key is not None: + key_added = f"{pair.query_element_key}_aligned" + return _writeback(pair, result, output_mode=output_mode, key_added=key_added, inplace=inplace) def align_images( sdata_ref: SpatialData, sdata_query: SpatialData | None = None, - img_ref_name: str | None = None, - img_query_name: str | None = None, - flavour: Literal["stalign"] = "stalign", *, + img_ref_name: str, + img_query_name: str, + flavour: Literal["stalign"] = "stalign", scale_ref: str | Literal["auto"] = "auto", scale_query: str | Literal["auto"] = "auto", output_mode: Literal["affine", "return"] = "affine", @@ -118,6 +124,11 @@ def align_images( ) -> SpatialData | AlignResult | None: """Align two raster images living inside :class:`spatialdata.SpatialData`. + .. note:: + + No backend currently implements image alignment. This function is + reserved for a follow-up PR and is not yet part of the public API. + Parameters ---------- sdata_ref, sdata_query @@ -137,8 +148,6 @@ def align_images( device, inplace, flavour_kwargs See :func:`align_obs`. """ - validate_required(name="img_ref_name", value=img_ref_name, when="calling `align_images`") - validate_required(name="img_query_name", value=img_query_name, when="calling `align_images`") validate_flavour(flavour, allowed=ALLOWED_FLAVOURS_IMAGES, op="align_images") validate_output_mode(output_mode, allowed=ALLOWED_OUTPUT_MODES_NONOBS, op="align_images") @@ -159,20 +168,18 @@ def align_images( def align_by_landmarks( sdata_ref: SpatialData, sdata_query: SpatialData | None = None, - cs_name_ref: str | None = None, - cs_name_query: str | None = None, - scale_ref: str | None = None, - scale_query: str | None = None, - landmarks_ref: Sequence[tuple[float, float]] | np.ndarray | None = None, - landmarks_query: Sequence[tuple[float, float]] | np.ndarray | None = None, *, + cs_name_ref: str, + cs_name_query: str, + landmarks_ref: Sequence[tuple[float, float]] | np.ndarray, + landmarks_query: Sequence[tuple[float, float]] | np.ndarray, model: Literal["similarity", "affine"] = "similarity", output_mode: Literal["affine", "return"] = "affine", inplace: bool = True, ) -> SpatialData | AlignResult | None: """Align by a closed-form fit on user-provided landmarks. - Pure NumPy under the hood — JAX is **not** required for this path. + Pure NumPy under the hood -- JAX is **not** required for this path. Parameters ---------- @@ -181,31 +188,34 @@ def align_by_landmarks( coordinate systems of the same SpatialData against each other. cs_name_ref, cs_name_query Coordinate system names. - scale_ref, scale_query - Optional scale identifiers used purely for landmark-extent - validation: if you extracted your landmarks at a particular scale, - passing the same scale here lets us catch the "wrong scale" footgun - early. landmarks_ref, landmarks_query - Equal-length sequences of ``(y, x)`` tuples. ``model='similarity'`` - needs ≥ 2 pairs, ``model='affine'`` needs ≥ 3. + Equal-length sequences of ``(x, y)`` tuples in the pixel space of + the respective coordinate system. ``model='similarity'`` needs + at least 3 pairs, ``model='affine'`` needs at least 3. + Landmarks are validated against the coordinate-system extent + (via :func:`spatialdata.get_extent`) to catch scale mismatches + early. model ``'similarity'`` (rotation + uniform scale + translation) or ``'affine'`` (full 6-parameter linear). output_mode, inplace See :func:`align_obs`. """ - validate_required(name="cs_name_ref", value=cs_name_ref, when="calling `align_by_landmarks`") - validate_required(name="cs_name_query", value=cs_name_query, when="calling `align_by_landmarks`") - validate_required(name="landmarks_ref", value=landmarks_ref, when="calling `align_by_landmarks`") - validate_required(name="landmarks_query", value=landmarks_query, when="calling `align_by_landmarks`") - validate_output_mode(output_mode, allowed=ALLOWED_OUTPUT_MODES_NONOBS, op="align_by_landmarks") validate_landmark_model(model) - # We don't materialise extents here in the skeleton; backends / a future - # PR can fill in the cs-extent lookup once we wire spatialdata.get_extent. - ref_arr, query_arr = validate_landmarks(landmarks_ref, landmarks_query, model=model) + # Fetch coordinate-system extents for landmark bounds checking. + cs_ref_extent = _get_cs_extent(sdata_ref, cs_name_ref) + sdata_query_resolved = sdata_query if sdata_query is not None else sdata_ref + cs_query_extent = _get_cs_extent(sdata_query_resolved, cs_name_query) + + ref_arr, query_arr = validate_landmarks( + landmarks_ref, + landmarks_query, + model=model, + cs_ref_extent=cs_ref_extent, + cs_query_extent=cs_query_extent, + ) pair = resolve_element_pair(sdata_ref, sdata_query, cs_name_ref, cs_name_query) @@ -223,6 +233,17 @@ def align_by_landmarks( return _writeback(pair, result, output_mode=output_mode, key_added=None, inplace=inplace) +def _get_cs_extent( + sdata: SpatialData, + cs_name: str, +) -> tuple[float, float, float, float]: + """Return ``(x_min, y_min, x_max, y_max)`` for a coordinate system.""" + from spatialdata import get_extent + + extent = get_extent(sdata, coordinate_system=cs_name) + return (extent["x"][0], extent["y"][0], extent["x"][1], extent["y"][1]) + + # --------------------------------------------------------------------------- # Internal: writeback dispatch # --------------------------------------------------------------------------- diff --git a/tests/experimental/tl/test_align_skeleton.py b/tests/experimental/tl/test_align_skeleton.py index 3aa9a7808..b6ff07528 100644 --- a/tests/experimental/tl/test_align_skeleton.py +++ b/tests/experimental/tl/test_align_skeleton.py @@ -102,8 +102,9 @@ def test_public_callables_exist() -> None: import squidpy as sq assert callable(sq.experimental.tl.align_obs) - assert callable(sq.experimental.tl.align_images) assert callable(sq.experimental.tl.align_by_landmarks) + # align_images is not yet public (no backend implements it) + assert not hasattr(sq.experimental.tl, "align_images") # --------------------------------------------------------------------------- @@ -317,10 +318,10 @@ def test_validate_landmarks_happy_path() -> None: def test_align_images_rejects_output_mode_obs(sdata_single) -> None: - import squidpy as sq + from squidpy.experimental.tl._align._api import align_images with pytest.raises(ValueError, match="output_mode"): - sq.experimental.tl.align_images( + align_images( sdata_single, None, img_ref_name="img_ref", @@ -352,11 +353,11 @@ def test_align_obs_stalign_image_path_raises(sdata_single) -> None: """``align_images(flavour='stalign')`` is still NotImplementedError; the PR-#1150 lift only ships point alignment. This pins the contract that the dispatch reaches the backend cleanly (no ImportError/AttributeError).""" - import squidpy as sq + from squidpy.experimental.tl._align._api import align_images pytest.importorskip("jax") with pytest.raises(NotImplementedError, match="stalign image alignment"): - sq.experimental.tl.align_images( + align_images( sdata_single, None, img_ref_name="img_ref", @@ -393,10 +394,10 @@ def test_align_obs_unknown_flavour(sdata_single) -> None: def test_align_images_rejects_moscot(sdata_single) -> None: - import squidpy as sq + from squidpy.experimental.tl._align._api import align_images with pytest.raises(ValueError, match="flavour"): - sq.experimental.tl.align_images( + align_images( sdata_single, None, img_ref_name="img_ref", From 1d4c1350e4bed823aabd0ab03833296cddade50f Mon Sep 17 00:00:00 2001 From: Tim Treis Date: Thu, 16 Apr 2026 14:21:18 +0200 Subject: [PATCH 6/6] Remove device parameter, let JAX handle device selection JAX selects the appropriate device based on its install (CPU/GPU) and runtime context managers. The explicit device arg added unnecessary complexity with no benefit over JAX's built-in device management. Removes device from: align_obs, align_images, AlignBackend protocol, StAlignBackend, MoscotBackend, and require_jax. Simplifies require_jax to a pure import guard. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/squidpy/experimental/tl/_align/_api.py | 11 ++--- .../experimental/tl/_align/_backends/_base.py | 6 +-- .../tl/_align/_backends/_moscot.py | 8 +--- .../tl/_align/_backends/_stalign.py | 9 ++--- src/squidpy/experimental/tl/_align/_jax.py | 40 +++---------------- 5 files changed, 15 insertions(+), 59 deletions(-) diff --git a/src/squidpy/experimental/tl/_align/_api.py b/src/squidpy/experimental/tl/_align/_api.py index 5e3b555c4..73a004f24 100644 --- a/src/squidpy/experimental/tl/_align/_api.py +++ b/src/squidpy/experimental/tl/_align/_api.py @@ -48,7 +48,6 @@ def align_obs( *, output_mode: Literal["affine", "obs", "return"] = "obs", key_added: str | None = None, - device: Literal["cpu", "gpu"] | None = None, inplace: bool = True, **flavour_kwargs: Any, ) -> AnnData | SpatialData | AlignResult | None: @@ -83,9 +82,6 @@ def align_obs( Name for the aligned table when ``output_mode='obs'`` and inputs are SpatialData. Defaults to ``'{adata_query_name}_aligned'``. Rejected with any other ``output_mode``. - device - ``'cpu'``/``'gpu'`` to force a JAX device, or ``None`` to let JAX - pick the default. Only consulted by JAX-backed flavours. inplace If ``True``, mutate the query container; otherwise return a copy. Only affects SpatialData inputs -- for plain AnnData with @@ -99,7 +95,7 @@ def align_obs( pair = resolve_obs_pair(data_ref, data_query, adata_ref_name, adata_query_name) backend = get_backend(flavour) - result = backend.align_obs(pair, device=device, **flavour_kwargs) + result = backend.align_obs(pair, **flavour_kwargs) # Auto-generate key_added for SpatialData obs writeback. if key_added is None and output_mode == "obs" and pair.query_element_key is not None: @@ -118,7 +114,6 @@ def align_images( scale_ref: str | Literal["auto"] = "auto", scale_query: str | Literal["auto"] = "auto", output_mode: Literal["affine", "return"] = "affine", - device: Literal["cpu", "gpu"] | None = None, inplace: bool = True, **flavour_kwargs: Any, ) -> SpatialData | AlignResult | None: @@ -145,7 +140,7 @@ def align_images( ``'affine'`` registers the fit on the query image element so all of its scales inherit the transformation; ``'return'`` returns the raw :class:`AlignResult`. - device, inplace, flavour_kwargs + inplace, flavour_kwargs See :func:`align_obs`. """ validate_flavour(flavour, allowed=ALLOWED_FLAVOURS_IMAGES, op="align_images") @@ -160,7 +155,7 @@ def align_images( scale_query=scale_query, ) backend = get_backend(flavour) - result = backend.align_images(pair, device=device, **flavour_kwargs) + result = backend.align_images(pair, **flavour_kwargs) return _writeback(pair, result, output_mode=output_mode, key_added=None, inplace=inplace) diff --git a/src/squidpy/experimental/tl/_align/_backends/_base.py b/src/squidpy/experimental/tl/_align/_backends/_base.py index 58ba61657..d39039962 100644 --- a/src/squidpy/experimental/tl/_align/_backends/_base.py +++ b/src/squidpy/experimental/tl/_align/_backends/_base.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Literal, Protocol, runtime_checkable +from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable if TYPE_CHECKING: from squidpy.experimental.tl._align._types import AlignPair, AlignResult @@ -24,15 +24,11 @@ class AlignBackend(Protocol): def align_obs( self, pair: AlignPair, - *, - device: Literal["cpu", "gpu"] | None = None, **kwargs: Any, ) -> AlignResult: ... def align_images( self, pair: AlignPair, - *, - device: Literal["cpu", "gpu"] | None = None, **kwargs: Any, ) -> AlignResult: ... diff --git a/src/squidpy/experimental/tl/_align/_backends/_moscot.py b/src/squidpy/experimental/tl/_align/_backends/_moscot.py index 6b5e98e5e..dd9675d84 100644 --- a/src/squidpy/experimental/tl/_align/_backends/_moscot.py +++ b/src/squidpy/experimental/tl/_align/_backends/_moscot.py @@ -8,7 +8,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from squidpy.experimental.tl._align._types import AlignPair, AlignResult @@ -21,13 +21,11 @@ class MoscotBackend: def align_obs( self, pair: AlignPair, - *, - device: Literal["cpu", "gpu"] | None = None, **kwargs: Any, ) -> AlignResult: from squidpy.experimental.tl._align._jax import require_jax - require_jax(device) + require_jax() raise NotImplementedError( "moscot backend `align_obs`: TODO. Skeleton landed; the moscot " "solver will replace this body in a follow-up PR." @@ -36,8 +34,6 @@ def align_obs( def align_images( self, pair: AlignPair, - *, - device: Literal["cpu", "gpu"] | None = None, **kwargs: Any, ) -> AlignResult: raise NotImplementedError("moscot does not implement image alignment; use `flavour='stalign'`.") diff --git a/src/squidpy/experimental/tl/_align/_backends/_stalign.py b/src/squidpy/experimental/tl/_align/_backends/_stalign.py index 71605ec52..d3ed503b0 100644 --- a/src/squidpy/experimental/tl/_align/_backends/_stalign.py +++ b/src/squidpy/experimental/tl/_align/_backends/_stalign.py @@ -1,13 +1,13 @@ """STalign backend. -Wraps the JAX LDDMM solver lifted from scverse/squidpy#1150 (Selman Özleyen) +Wraps the JAX LDDMM solver lifted from scverse/squidpy#1150 (Selman Ozleyen) into the :class:`AlignBackend` Protocol. Only ``align_obs`` is implemented today; ``align_images`` raises until upstream support exists. """ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any import numpy as np @@ -23,7 +23,6 @@ def align_obs( self, pair: AlignPair, *, - device: Literal["cpu", "gpu"] | None = None, config: Any | None = None, landmarks_source: np.ndarray | None = None, landmarks_target: np.ndarray | None = None, @@ -39,7 +38,7 @@ def align_obs( # `ModuleNotFoundError: import of jax halted; None in sys.modules` # instead of the clean `ImportError("JAX is required ...")` from # _jax.require_jax. - require_jax(device) + require_jax() from squidpy.experimental.tl._align._backends._stalign_tools import stalign_points from squidpy.experimental.tl._align._types import AlignResult, ObsDisplacement @@ -94,8 +93,6 @@ def align_obs( def align_images( self, pair: AlignPair, - *, - device: Literal["cpu", "gpu"] | None = None, **kwargs: Any, ) -> AlignResult: raise NotImplementedError( diff --git a/src/squidpy/experimental/tl/_align/_jax.py b/src/squidpy/experimental/tl/_align/_jax.py index 34411cd82..b18779a6a 100644 --- a/src/squidpy/experimental/tl/_align/_jax.py +++ b/src/squidpy/experimental/tl/_align/_jax.py @@ -1,4 +1,4 @@ -"""Lazy JAX import + device selection for JAX-backed alignment backends. +"""Lazy JAX import guard for JAX-backed alignment backends. JAX is an optional dependency. Importing this module is cheap; calling :func:`require_jax` is what actually pulls JAX in, and only the @@ -7,50 +7,22 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Literal +from typing import Any -if TYPE_CHECKING: - Device = Any # jax.Device, but importing it eagerly defeats the purpose +_INSTALL_HINT = 'JAX is required for the requested align_* flavour. Install with `pip install "squidpy[jax]"`.' -_INSTALL_HINT = ( - "JAX is required for the requested align_* flavour. " - "Install with `pip install jax` (CPU) or follow the JAX install guide for GPU." -) - - -def require_jax(device: Literal["cpu", "gpu"] | None = None) -> tuple[Any, Any]: - """Import JAX lazily and return ``(jax, device)``. - - Parameters - ---------- - device - ``"cpu"``/``"gpu"`` to force a platform, or ``None`` to use whatever - JAX picks as the default. - - Returns - ------- - jax_module - The imported :mod:`jax` module. - device - A :class:`jax.Device` of the requested platform. +def require_jax() -> Any: + """Import JAX lazily and return the module. Raises ------ ImportError If JAX is not installed. - RuntimeError - If the requested device platform is not available on this host. """ try: import jax except ImportError as e: raise ImportError(_INSTALL_HINT) from e - if device is None: - return jax, jax.devices()[0] - - matching = [d for d in jax.devices() if d.platform == device] - if not matching: - raise RuntimeError(f"No JAX device of kind {device!r} available; have {[d.platform for d in jax.devices()]}.") - return jax, matching[0] + return jax