diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index c75993038..386e85aa3 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -104,6 +104,8 @@ jobs: env: { name: hatch-test.py3.12-stable } - os: macos-latest env: { name: hatch-test.py3.13-pre } # pre-release only needs one OS + - os: macos-latest + env: { name: hatch-test.py3.13-stable-jax } # JAX only needs one OS - os: ubuntu-latest env: { name: hatch-test.py3.13-stable } # skipping because we run this as a coverage job name: ${{ matrix.env.label }} (${{ matrix.os }}) diff --git a/hatch.toml b/hatch.toml index 08afad3b9..b870ce00d 100644 --- a/hatch.toml +++ b/hatch.toml @@ -39,6 +39,11 @@ download = "python ./.scripts/ci/download_data.py {args}" deps = ["stable"] python = ["3.11", "3.12", "3.13"] +[[envs.hatch-test.matrix]] +deps = ["stable"] +python = ["3.13"] +extras = ["jax"] + [[envs.hatch-test.matrix]] deps = ["pre"] python = ["3.13"] @@ -47,6 +52,9 @@ python = ["3.13"] matrix.deps.env-vars = [ { key = "UV_PRERELEASE", value = "allow", if = ["pre"] }, ] +matrix.extras.features = [ + { value = "jax", if = ["jax"] }, +] [envs.notebooks] extra-dependencies = [ diff --git a/pyproject.toml b/pyproject.toml index 06e9dfc5a..8b936d749 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -102,6 +102,9 @@ optional-dependencies.docs = [ "sphinxcontrib-bibtex>=2.3", "sphinxcontrib-spelling>=7.6.2", ] +optional-dependencies.jax = [ + "jax", +] optional-dependencies.leiden = [ "leidenalg", "spatialleiden>=0.4", diff --git a/src/squidpy/experimental/__init__.py b/src/squidpy/experimental/__init__.py index 435cd0098..5f4c695ab 100644 --- a/src/squidpy/experimental/__init__.py +++ b/src/squidpy/experimental/__init__.py @@ -6,6 +6,6 @@ from __future__ import annotations -from . import im, pl +from . 
import im, pl, tl -__all__ = ["im", "pl"] +__all__ = ["im", "pl", "tl"] diff --git a/src/squidpy/experimental/tl/__init__.py b/src/squidpy/experimental/tl/__init__.py new file mode 100644 index 000000000..a9bb596d9 --- /dev/null +++ b/src/squidpy/experimental/tl/__init__.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from squidpy.experimental.tl._align import align_by_landmarks, align_obs + +if TYPE_CHECKING: + from squidpy.experimental.tl._align._backends._stalign_tools import ( + STalignConfig, + STalignPreprocessConfig, + STalignPreprocessResult, + STalignRegistrationConfig, + STalignResult, + ) + +__all__ = [ + "STalignConfig", + "STalignPreprocessConfig", + "STalignPreprocessResult", + "STalignRegistrationConfig", + "STalignResult", + "align_by_landmarks", + "align_obs", +] + +_STALIGN_REEXPORTS = frozenset( + { + "STalignConfig", + "STalignPreprocessConfig", + "STalignPreprocessResult", + "STalignRegistrationConfig", + "STalignResult", + } +) + + +def __getattr__(name: str) -> Any: + """Lazy access to the JAX-only STalign config dataclasses. + + Importing :mod:`squidpy.experimental.tl._align._backends._stalign_tools` pulls in + :mod:`jax` at module-load time, so we defer the import until the first + attribute access. This preserves the lazy-import contract pinned by + ``test_optional_deps_not_imported_at_import_time``. + """ + if name in _STALIGN_REEXPORTS: + try: + from squidpy.experimental.tl._align._backends import _stalign_tools as _tools + except ModuleNotFoundError as e: + if e.name == "jax": + raise ImportError( + 'STalign requires the optional dependency `jax`. Install it with `pip install "squidpy[jax]"`.' 
+ ) from e + raise + return getattr(_tools, name) + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/squidpy/experimental/tl/_align/__init__.py b/src/squidpy/experimental/tl/_align/__init__.py new file mode 100644 index 000000000..51f9cc617 --- /dev/null +++ b/src/squidpy/experimental/tl/_align/__init__.py @@ -0,0 +1,19 @@ +"""Alignment skeleton under :mod:`squidpy.experimental.tl`. + +Public surface: + +- :func:`align_obs` -- align two ``obs``-level point clouds (cells / spots). +- :func:`align_by_landmarks` -- closed-form fit from user-provided landmarks. + +Optional backends (``stalign``, ``moscot``) and JAX are imported lazily -- only +the function call that needs them pulls them in. +""" + +from __future__ import annotations + +from squidpy.experimental.tl._align._api import ( + align_by_landmarks, + align_obs, +) + +__all__ = ["align_by_landmarks", "align_obs"] diff --git a/src/squidpy/experimental/tl/_align/_api.py b/src/squidpy/experimental/tl/_align/_api.py new file mode 100644 index 000000000..73a004f24 --- /dev/null +++ b/src/squidpy/experimental/tl/_align/_api.py @@ -0,0 +1,269 @@ +"""Public ``align_*`` orchestrators. + +Each function is intentionally thin: resolve inputs, validate, dispatch to a +backend, write the result back. All branching on argument shape lives in +:mod:`._io`; all backend selection lives in :mod:`._backends`; all validation +of "passed-but-unneeded" combinations lives in :mod:`._validation`. 
+""" + +from __future__ import annotations + +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any, Literal + +import numpy as np + +from squidpy.experimental.tl._align._backends import get_backend +from squidpy.experimental.tl._align._io import ( + apply_affine_to_cs, + materialise_obs, + resolve_element_pair, + resolve_image_pair, + resolve_obs_pair, +) +from squidpy.experimental.tl._align._types import AffineTransform, AlignPair, AlignResult +from squidpy.experimental.tl._align._validation import ( + ALLOWED_FLAVOURS_IMAGES, + ALLOWED_FLAVOURS_OBS, + ALLOWED_OUTPUT_MODES_NONOBS, + ALLOWED_OUTPUT_MODES_OBS, + validate_flavour, + validate_key_added, + validate_landmark_model, + validate_landmarks, + validate_output_mode, +) + +if TYPE_CHECKING: + from anndata import AnnData + from spatialdata import SpatialData + + +def align_obs( + data_ref: AnnData | SpatialData, + data_query: AnnData | SpatialData | None = None, + adata_ref_name: str | None = None, + adata_query_name: str | None = None, + flavour: Literal["stalign", "moscot"] = "stalign", + *, + output_mode: Literal["affine", "obs", "return"] = "obs", + key_added: str | None = None, + inplace: bool = True, + **flavour_kwargs: Any, +) -> AnnData | SpatialData | AlignResult | None: + """Align two ``obs``-level point clouds (cells / spots). + + Parameters + ---------- + data_ref, data_query + Either both :class:`anndata.AnnData`, both :class:`spatialdata.SpatialData`, + or ``data_ref`` a SpatialData and ``data_query=None`` (in which case + ``adata_ref_name`` and ``adata_query_name`` select two different + tables of the same SpatialData object). + adata_ref_name, adata_query_name + Required only when SpatialData inputs are used. Passing them with + AnnData inputs raises an educational :class:`ValueError`. + flavour + Backend to use. ``'stalign'`` is the default LDDMM-based fit; + ``'moscot'`` is OT-based. 
+ output_mode + How to deliver the result: + + - ``'obs'`` (default) -- bake the fit into a new AnnData whose + ``obsm['spatial']`` already lives in the reference coordinate + system. For SpatialData inputs the new table is stored under + ``key_added`` (auto-generated from the query name when omitted). + - ``'affine'`` -- register the fitted affine on the query element via + :func:`spatialdata.transformations.set_transformation`, so every + element in the query coordinate system inherits the alignment. + Requires the backend to produce an affine transform. + - ``'return'`` -- return the raw :class:`AlignResult`; no writeback. + key_added + Name for the aligned table when ``output_mode='obs'`` and inputs are + SpatialData. Defaults to ``'{adata_query_name}_aligned'``. + Rejected with any other ``output_mode``. + inplace + If ``True``, mutate the query container; otherwise return a copy. + Only affects SpatialData inputs -- for plain AnnData with + ``output_mode='obs'``, the aligned AnnData is always returned. + **flavour_kwargs + Backend-specific knobs forwarded as-is to the chosen backend. + """ + validate_flavour(flavour, allowed=ALLOWED_FLAVOURS_OBS, op="align_obs") + validate_output_mode(output_mode, allowed=ALLOWED_OUTPUT_MODES_OBS, op="align_obs") + validate_key_added(key_added, output_mode) + + pair = resolve_obs_pair(data_ref, data_query, adata_ref_name, adata_query_name) + backend = get_backend(flavour) + result = backend.align_obs(pair, **flavour_kwargs) + + # Auto-generate key_added for SpatialData obs writeback. 
+ if key_added is None and output_mode == "obs" and pair.query_element_key is not None: + key_added = f"{pair.query_element_key}_aligned" + + return _writeback(pair, result, output_mode=output_mode, key_added=key_added, inplace=inplace) + + +def align_images( + sdata_ref: SpatialData, + sdata_query: SpatialData | None = None, + *, + img_ref_name: str, + img_query_name: str, + flavour: Literal["stalign"] = "stalign", + scale_ref: str | Literal["auto"] = "auto", + scale_query: str | Literal["auto"] = "auto", + output_mode: Literal["affine", "return"] = "affine", + inplace: bool = True, + **flavour_kwargs: Any, +) -> SpatialData | AlignResult | None: + """Align two raster images living inside :class:`spatialdata.SpatialData`. + + .. note:: + + No backend currently implements image alignment. This function is + reserved for a follow-up PR and is not yet part of the public API. + + Parameters + ---------- + sdata_ref, sdata_query + SpatialData containers. Pass ``sdata_query=None`` to align two + images of the same SpatialData against each other. + img_ref_name, img_query_name + Image element keys. + flavour + Only ``'stalign'`` is currently supported. + scale_ref, scale_query + Scale level for multi-scale image elements. ``'auto'`` picks the + coarsest level. Single-scale images ignore this parameter. + output_mode + ``'affine'`` registers the fit on the query image element so all of + its scales inherit the transformation; ``'return'`` returns the raw + :class:`AlignResult`. + inplace, flavour_kwargs + See :func:`align_obs`. 
+ """ + validate_flavour(flavour, allowed=ALLOWED_FLAVOURS_IMAGES, op="align_images") + validate_output_mode(output_mode, allowed=ALLOWED_OUTPUT_MODES_NONOBS, op="align_images") + + pair = resolve_image_pair( + sdata_ref, + sdata_query, + img_ref_name, + img_query_name, + scale_ref=scale_ref, + scale_query=scale_query, + ) + backend = get_backend(flavour) + result = backend.align_images(pair, **flavour_kwargs) + + return _writeback(pair, result, output_mode=output_mode, key_added=None, inplace=inplace) + + +def align_by_landmarks( + sdata_ref: SpatialData, + sdata_query: SpatialData | None = None, + *, + cs_name_ref: str, + cs_name_query: str, + landmarks_ref: Sequence[tuple[float, float]] | np.ndarray, + landmarks_query: Sequence[tuple[float, float]] | np.ndarray, + model: Literal["similarity", "affine"] = "similarity", + output_mode: Literal["affine", "return"] = "affine", + inplace: bool = True, +) -> SpatialData | AlignResult | None: + """Align by a closed-form fit on user-provided landmarks. + + Pure NumPy under the hood -- JAX is **not** required for this path. + + Parameters + ---------- + sdata_ref, sdata_query + SpatialData containers. Pass ``sdata_query=None`` to align two + coordinate systems of the same SpatialData against each other. + cs_name_ref, cs_name_query + Coordinate system names. + landmarks_ref, landmarks_query + Equal-length sequences of ``(x, y)`` tuples in the pixel space of + the respective coordinate system. ``model='similarity'`` needs + at least 3 pairs, ``model='affine'`` needs at least 3. + Landmarks are validated against the coordinate-system extent + (via :func:`spatialdata.get_extent`) to catch scale mismatches + early. + model + ``'similarity'`` (rotation + uniform scale + translation) or + ``'affine'`` (full 6-parameter linear). + output_mode, inplace + See :func:`align_obs`. 
+ """ + validate_output_mode(output_mode, allowed=ALLOWED_OUTPUT_MODES_NONOBS, op="align_by_landmarks") + validate_landmark_model(model) + + # Fetch coordinate-system extents for landmark bounds checking. + cs_ref_extent = _get_cs_extent(sdata_ref, cs_name_ref) + sdata_query_resolved = sdata_query if sdata_query is not None else sdata_ref + cs_query_extent = _get_cs_extent(sdata_query_resolved, cs_name_query) + + ref_arr, query_arr = validate_landmarks( + landmarks_ref, + landmarks_query, + model=model, + cs_ref_extent=cs_ref_extent, + cs_query_extent=cs_query_extent, + ) + + pair = resolve_element_pair(sdata_ref, sdata_query, cs_name_ref, cs_name_query) + + from squidpy.experimental.tl._align._backends._landmark import fit_landmark_affine + + affine = fit_landmark_affine( + ref_arr, + query_arr, + model=model, + source_cs=cs_name_query, + target_cs=cs_name_ref, + ) + result = AlignResult(transform=affine, metadata={"flavour": "landmark", "model": model}) + + return _writeback(pair, result, output_mode=output_mode, key_added=None, inplace=inplace) + + +def _get_cs_extent( + sdata: SpatialData, + cs_name: str, +) -> tuple[float, float, float, float]: + """Return ``(x_min, y_min, x_max, y_max)`` for a coordinate system.""" + from spatialdata import get_extent + + extent = get_extent(sdata, coordinate_system=cs_name) + return (extent["x"][0], extent["y"][0], extent["x"][1], extent["y"][1]) + + +# --------------------------------------------------------------------------- +# Internal: writeback dispatch +# --------------------------------------------------------------------------- + + +def _writeback( + pair: AlignPair, + result: AlignResult, + *, + output_mode: str, + key_added: str | None, + inplace: bool, +) -> AnnData | SpatialData | AlignResult | None: + if output_mode == "return": + return result + + if output_mode == "affine": + if not isinstance(result.transform, AffineTransform): + raise TypeError( + f"`output_mode='affine'` requires the backend to return an 
AffineTransform, " + f"got {type(result.transform).__name__}. Use `output_mode='obs'` (for " + f"`align_obs`) or `output_mode='return'` to access non-affine fits." + ) + return apply_affine_to_cs(pair, result.transform, inplace=inplace) + + # output_mode == "obs" -- the only remaining valid branch after + # validate_output_mode has already rejected unknown values. + return materialise_obs(pair, result, key_added=key_added, inplace=inplace) diff --git a/src/squidpy/experimental/tl/_align/_backends/__init__.py b/src/squidpy/experimental/tl/_align/_backends/__init__.py new file mode 100644 index 000000000..5a30e0eb6 --- /dev/null +++ b/src/squidpy/experimental/tl/_align/_backends/__init__.py @@ -0,0 +1,26 @@ +"""Backend dispatch for the alignment skeleton. + +Imports of individual backends happen *inside* the dispatch branches so that +``import squidpy.experimental.tl`` never pulls in ``stalign``, ``moscot``, or +``jax`` transitively. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from squidpy.experimental.tl._align._backends._base import AlignBackend + + +def get_backend(flavour: str) -> AlignBackend: + """Return a backend instance for the requested ``flavour``.""" + if flavour == "stalign": + from squidpy.experimental.tl._align._backends._stalign import StAlignBackend + + return StAlignBackend() + if flavour == "moscot": + from squidpy.experimental.tl._align._backends._moscot import MoscotBackend + + return MoscotBackend() + raise ValueError(f"Unknown alignment flavour {flavour!r}; expected 'stalign' or 'moscot'.") diff --git a/src/squidpy/experimental/tl/_align/_backends/_base.py b/src/squidpy/experimental/tl/_align/_backends/_base.py new file mode 100644 index 000000000..d39039962 --- /dev/null +++ b/src/squidpy/experimental/tl/_align/_backends/_base.py @@ -0,0 +1,34 @@ +"""Backend Protocol shared by every alignment flavour.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, 
Protocol, runtime_checkable + +if TYPE_CHECKING: + from squidpy.experimental.tl._align._types import AlignPair, AlignResult + + +@runtime_checkable +class AlignBackend(Protocol): + """Minimal contract every alignment backend must satisfy. + + Backends are constructed cheaply (no heavy imports in ``__init__``) and + only pull in their optional dependencies on the first call into ``align_obs`` + or ``align_images``. ``requires_jax`` advertises whether the backend + needs JAX so callers / dispatch can short-circuit. + """ + + name: str + requires_jax: bool + + def align_obs( + self, + pair: AlignPair, + **kwargs: Any, + ) -> AlignResult: ... + + def align_images( + self, + pair: AlignPair, + **kwargs: Any, + ) -> AlignResult: ... diff --git a/src/squidpy/experimental/tl/_align/_backends/_landmark.py b/src/squidpy/experimental/tl/_align/_backends/_landmark.py new file mode 100644 index 000000000..11bf8475a --- /dev/null +++ b/src/squidpy/experimental/tl/_align/_backends/_landmark.py @@ -0,0 +1,106 @@ +"""Closed-form landmark fit. + +Two models, both pure NumPy / no JAX: + +- ``"similarity"`` (4 DOF: rotation + uniform scale + translation, plus an + optional reflection check) - delegated to + :func:`spatialdata.transformations.get_transformation_between_landmarks`. +- ``"affine"`` (6 DOF: rotation + non-uniform scale + shear + translation) - + delegated to :func:`skimage.transform.estimate_transform`, the same + least-squares solver spatialdata uses internally. + +Useful as a one-shot alignment when you already have corresponding landmarks, +and as a sanity-check baseline for the much heavier STalign LDDMM solver. 
+""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal + +import numpy as np + +if TYPE_CHECKING: + from squidpy.experimental.tl._align._types import AffineTransform + + +def fit_landmark_affine( + landmarks_ref: np.ndarray, + landmarks_query: np.ndarray, + *, + model: Literal["similarity", "affine"] = "similarity", + source_cs: str | None = None, + target_cs: str | None = None, +) -> AffineTransform: + """Fit a 2D affine that maps ``landmarks_query`` onto ``landmarks_ref``. + + Both inputs are ``(N, 2)`` ``(x, y)`` arrays of corresponding landmarks + (the ``i``-th row of ``landmarks_query`` matches the ``i``-th row of + ``landmarks_ref``). ``N`` must be at least 3. + + Parameters + ---------- + landmarks_ref, landmarks_query + Corresponding landmark coordinates in ``(x, y)`` convention. + model + ``"similarity"`` (4 DOF, via spatialdata) or ``"affine"`` (6 DOF, + via skimage). + source_cs, target_cs + Optional coordinate-system labels stamped onto the returned + :class:`AffineTransform` for traceability. 
+ """ + from squidpy.experimental.tl._align._types import AffineTransform + + ref = np.asarray(landmarks_ref, dtype=float) + query = np.asarray(landmarks_query, dtype=float) + + if model == "similarity": + matrix = _fit_similarity_via_spatialdata(ref, query) + elif model == "affine": + matrix = _fit_affine_via_skimage(ref, query) + else: + raise ValueError(f"Unknown landmark `model={model!r}`; expected 'similarity' or 'affine'.") + + return AffineTransform(matrix=matrix, source_cs=source_cs, target_cs=target_cs) + + +def _fit_similarity_via_spatialdata(ref_xy: np.ndarray, query_xy: np.ndarray) -> np.ndarray: + """4-DOF similarity fit, delegated to spatialdata.""" + from spatialdata.models import PointsModel + from spatialdata.transformations import get_transformation_between_landmarks + + refs_pts = PointsModel.parse(ref_xy) + moving_pts = PointsModel.parse(query_xy) + sd_transform = get_transformation_between_landmarks(refs_pts, moving_pts) + return _extract_affine_matrix(sd_transform) + + +def _fit_affine_via_skimage(ref_xy: np.ndarray, query_xy: np.ndarray) -> np.ndarray: + """Full 6-DOF affine fit, delegated to skimage's least-squares solver. + + This is what :func:`spatialdata.transformations.get_transformation_between_landmarks` + uses under the hood before collapsing to a similarity; for the affine + model we keep the raw matrix instead. + """ + from skimage.transform import estimate_transform + + model_obj = estimate_transform("affine", src=query_xy, dst=ref_xy) + return np.asarray(model_obj.params) + + +def _extract_affine_matrix(sd_transform: object) -> np.ndarray: + """Pull a ``(3, 3)`` homogeneous matrix out of a spatialdata transformation. + + :func:`get_transformation_between_landmarks` may return either a single + :class:`spatialdata.transformations.Affine` or a + :class:`spatialdata.transformations.Sequence` of two affines (when a + reflection is detected and rolled into the fit). Use + ``to_affine_matrix`` to collapse either back to a single 3x3. 
+ """ + from spatialdata.transformations import Affine as SDAffine + from spatialdata.transformations import Sequence as SDSequence + + if isinstance(sd_transform, SDAffine): + return np.asarray(sd_transform.matrix) + if isinstance(sd_transform, SDSequence): + return np.asarray(sd_transform.to_affine_matrix(input_axes=("x", "y"), output_axes=("x", "y"))) + raise TypeError(f"Unexpected transformation type from spatialdata: {type(sd_transform).__name__}.") diff --git a/src/squidpy/experimental/tl/_align/_backends/_moscot.py b/src/squidpy/experimental/tl/_align/_backends/_moscot.py new file mode 100644 index 000000000..dd9675d84 --- /dev/null +++ b/src/squidpy/experimental/tl/_align/_backends/_moscot.py @@ -0,0 +1,39 @@ +"""Moscot backend stub. + +Moscot only exposes ``align_obs``; image alignment is not a moscot use case. +The dispatch layer rejects ``flavour='moscot'`` for :func:`align_images` +before ever reaching this file, so the ``align_images`` method below is here +purely to satisfy the :class:`AlignBackend` Protocol. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from squidpy.experimental.tl._align._types import AlignPair, AlignResult + + +class MoscotBackend: + name = "moscot" + requires_jax = True + + def align_obs( + self, + pair: AlignPair, + **kwargs: Any, + ) -> AlignResult: + from squidpy.experimental.tl._align._jax import require_jax + + require_jax() + raise NotImplementedError( + "moscot backend `align_obs`: TODO. Skeleton landed; the moscot " + "solver will replace this body in a follow-up PR." 
+ ) + + def align_images( + self, + pair: AlignPair, + **kwargs: Any, + ) -> AlignResult: + raise NotImplementedError("moscot does not implement image alignment; use `flavour='stalign'`.") diff --git a/src/squidpy/experimental/tl/_align/_backends/_stalign.py b/src/squidpy/experimental/tl/_align/_backends/_stalign.py new file mode 100644 index 000000000..d3ed503b0 --- /dev/null +++ b/src/squidpy/experimental/tl/_align/_backends/_stalign.py @@ -0,0 +1,101 @@ +"""STalign backend. + +Wraps the JAX LDDMM solver lifted from scverse/squidpy#1150 (Selman Ozleyen) +into the :class:`AlignBackend` Protocol. Only ``align_obs`` is implemented +today; ``align_images`` raises until upstream support exists. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import numpy as np + +if TYPE_CHECKING: + from squidpy.experimental.tl._align._types import AlignPair, AlignResult + + +class StAlignBackend: + name = "stalign" + requires_jax = True + + def align_obs( + self, + pair: AlignPair, + *, + config: Any | None = None, + landmarks_source: np.ndarray | None = None, + landmarks_target: np.ndarray | None = None, + **kwargs: Any, + ) -> AlignResult: + from anndata import AnnData + + from squidpy.experimental.tl._align._jax import require_jax + + # Resolve JAX *before* importing the lifted _tools module, because + # _tools does `import jax.numpy as jnp` at module level. If we let + # that import fire first, callers without JAX get a confusing + # `ModuleNotFoundError: import of jax halted; None in sys.modules` + # instead of the clean `ImportError("JAX is required ...")` from + # _jax.require_jax. 
+ require_jax() + + from squidpy.experimental.tl._align._backends._stalign_tools import stalign_points + from squidpy.experimental.tl._align._types import AlignResult, ObsDisplacement + + if not isinstance(pair.ref, AnnData) or not isinstance(pair.query, AnnData): + raise TypeError( + "stalign backend `align_obs` only supports AnnData / table inputs; " + f"got ref={type(pair.ref).__name__}, query={type(pair.query).__name__}." + ) + if "spatial" not in pair.query.obsm or "spatial" not in pair.ref.obsm: + raise KeyError("Both ref and query must carry an `obsm['spatial']` point cloud.") + + src_xy = np.asarray(pair.query.obsm["spatial"], dtype=float) + tgt_xy = np.asarray(pair.ref.obsm["spatial"], dtype=float) + + # stalign_points runs internally in row_col (yx); obsm["spatial"] is xy + # by squidpy convention -- swap axes at the boundary. + src_rc = src_xy[:, [1, 0]] + tgt_rc = tgt_xy[:, [1, 0]] + landmarks_src_rc = None if landmarks_source is None else np.asarray(landmarks_source)[:, [1, 0]] + landmarks_tgt_rc = None if landmarks_target is None else np.asarray(landmarks_target)[:, [1, 0]] + + stalign_result = stalign_points( + source_points=src_rc, + target_points=tgt_rc, + config=config, + landmarks_source=landmarks_src_rc, + landmarks_target=landmarks_tgt_rc, + ) + + aligned_rc = np.asarray(stalign_result.aligned_points) + aligned_xy = aligned_rc[:, [1, 0]] + deltas_xy = aligned_xy - src_xy + + return AlignResult( + transform=ObsDisplacement( + deltas=deltas_xy, + source_cs=pair.query_cs, + target_cs=pair.ref_cs, + ), + metadata={ + "flavour": "stalign", + # Escape hatch: the full STalignResult (velocity field, + # velocity grid, affine init) for power users who need + # the diffeomorphic map. This keeps the JAX arrays alive + # in memory -- callers who only need the displacement + # should drop this key or use ``output_mode='obs'``. 
+ "stalign_result": stalign_result, + }, + ) + + def align_images( + self, + pair: AlignPair, + **kwargs: Any, + ) -> AlignResult: + raise NotImplementedError( + "stalign image alignment is not yet implemented; PR #1150 only ships " + "point-cloud alignment. Use `flavour='stalign'` with `align_obs`." + ) diff --git a/src/squidpy/experimental/tl/_align/_backends/_stalign_core.py b/src/squidpy/experimental/tl/_align/_backends/_stalign_core.py new file mode 100644 index 000000000..8ac139b6c --- /dev/null +++ b/src/squidpy/experimental/tl/_align/_backends/_stalign_core.py @@ -0,0 +1,371 @@ +"""Core JAX implementation for experimental STalign point registration. + +Lifted byte-for-byte from scverse/squidpy#1150 (Selman Özleyen). +""" + +from __future__ import annotations + +from typing import Any, Literal + +import jax +import jax.numpy as jnp +import jax.scipy as jsp +import numpy as np + +__all__ = ["jax_dtype", "lddmm", "transform_points_row_col"] + + +def jax_dtype() -> jnp.dtype: + """Resolve the active JAX float dtype at call time, not import time.""" + return jnp.float64 if jax.config.x64_enabled else jnp.float32 + + +def _to_affine(linear: Any, translation: Any) -> Any: + return jnp.array( + [ + [linear[0, 0], linear[0, 1], translation[0]], + [linear[1, 0], linear[1, 1], translation[1]], + [0.0, 0.0, 1.0], + ], + dtype=linear.dtype, + ) + + +def _grid_points(x: tuple[Any, Any]) -> Any: + yy, xx = jnp.meshgrid(x[0], x[1], indexing="ij") + return jnp.stack((yy, xx)) + + +def _interp( + x: tuple[Any, Any], + image: Any, + phii: Any, + *, + mode: str = "nearest", +) -> Any: + """Interpolate a channels-first image on physical row-column coordinates.""" + arr = jnp.asarray(image) + coords = jnp.asarray(phii) + if coords.shape[0] != 2: + raise ValueError(f"Expected interpolation coordinates to have leading axis of size 2, found `{coords.shape}`.") + if arr.ndim == 2: + arr = arr[None, ...] 
+ + row_step = x[0][1] - x[0][0] + col_step = x[1][1] - x[1][0] + row_idx = (coords[0] - x[0][0]) / row_step + col_idx = (coords[1] - x[1][0]) / col_step + idx = jnp.stack((row_idx.reshape(-1), col_idx.reshape(-1))) + + def _sample(channel: Any) -> Any: + values = jsp.ndimage.map_coordinates(channel, idx, order=1, mode=mode) + return values.reshape(coords.shape[1:]) + + return jax.vmap(_sample)(arr) + + +def transform_points_row_col( + xv: tuple[Any, Any], + velocity: Any, + affine: Any, + points: Any, + *, + direction: Literal["forward", "backward"] = "forward", +) -> Any: + pts = jnp.asarray(points) + n_steps = velocity.shape[0] + time_steps = range(n_steps) + flow_sign = 1.0 + if direction == "backward": + affine = jnp.linalg.inv(affine) + pts = pts @ affine[:2, :2].T + affine[:2, -1] + flow_sign = -1.0 + time_steps = reversed(time_steps) + + for t in time_steps: + disp = _interp( + xv, + jnp.moveaxis(flow_sign * velocity[t], -1, 0), + pts.T[:, :, None], + mode="nearest", + )[:, :, 0].T + pts = pts + disp / n_steps + + if direction == "forward": + pts = pts @ affine[:2, :2].T + affine[:2, -1] + + return pts + + +def _transform_grid_backward( + x_target: tuple[Any, Any], + xv: tuple[Any, Any], + velocity: Any, + affine: Any, +) -> Any: + target_grid = _grid_points(x_target) + affine_inv = jnp.linalg.inv(affine) + source_grid = jnp.einsum("ij,jhw->ihw", affine_inv[:2, :2], target_grid) + affine_inv[:2, -1][:, None, None] + + for t in range(velocity.shape[0] - 1, -1, -1): + disp = _interp(xv, jnp.moveaxis(-velocity[t], -1, 0), source_grid, mode="nearest") + source_grid = source_grid + disp / velocity.shape[0] + + return source_grid + + +def _contrast_transform(source_image: Any, target_image: Any, weights: Any) -> Any: + flat_source = source_image.reshape(source_image.shape[0], -1) + flat_target = target_image.reshape(target_image.shape[0], -1) + flat_weights = weights.reshape(-1) + + design = jnp.concatenate((jnp.ones((1, flat_source.shape[1]), 
def _build_velocity_grid(x_source: tuple[Any, Any], *, a: float, expand: float) -> tuple[Any, Any]:
    """Build the (row, col) sample grid for the time-dependent velocity field.

    The grid covers the source extent scaled by ``expand`` about its center
    and is sampled at half the smoothness scale ``a`` (``step = a * 0.5``).
    """
    minimum = jnp.array([x_source[0][0], x_source[1][0]])
    maximum = jnp.array([x_source[0][-1], x_source[1][-1]])
    center = (minimum + maximum) / 2.0
    half_width = (maximum - minimum) * expand / 2.0
    step = a * 0.5
    return (
        jnp.arange(center[0] - half_width[0], center[0] + half_width[0] + step, step),
        jnp.arange(center[1] - half_width[1], center[1] + half_width[1] + step, step),
    )


def _build_regularizer(
    xv: tuple[Any, Any],
    *,
    a: float,
    p: float,
) -> tuple[Any, Any, Any]:
    """Build the Fourier-domain smoothing kernel and metric for the velocity.

    Returns
    -------
    ``(kernel, ll, dv_prod)`` where ``ll`` is the discrete differential
    operator raised to ``2p`` on the frequency grid of ``xv``,
    ``kernel = 1 / ll`` is its inverse (used to smooth gradients), and
    ``dv_prod`` is the grid cell area used to scale the energy.
    """
    # Grid spacing per axis; assumes each ``xv`` axis has >= 2 uniform samples.
    dv = jnp.array([xv[0][1] - xv[0][0], xv[1][1] - xv[1][0]])
    shape = (xv[0].shape[0], xv[1].shape[0])
    fy = jnp.arange(shape[0], dtype=xv[0].dtype) / (shape[0] * dv[0])
    fx = jnp.arange(shape[1], dtype=xv[1].dtype) / (shape[1] * dv[1])
    frequency_grid = jnp.stack(jnp.meshgrid(fy, fx, indexing="ij"), axis=-1)
    ll = (1.0 + 2.0 * a**2 * jnp.sum((1.0 - jnp.cos(2.0 * np.pi * frequency_grid * dv)) / (dv**2), axis=-1)) ** (
        2.0 * p
    )
    kernel = 1.0 / ll
    dv_prod = jnp.prod(dv)
    return kernel, ll, dv_prod


def _update_mixture_weights(
    transformed_source: Any,
    target_image: Any,
    match_weights: Any,
    artifact_weights: Any,
    background_weights: Any,
    *,
    sigmaM: float,
    sigmaA: float,
    sigmaB: float,
    estimate_muA: bool,
    estimate_muB: bool,
    muA: Any,
    muB: Any,
    iteration: int,
) -> tuple[Any, Any, Any, Any, Any]:
    """EM-style update of the match / artifact / background posteriors.

    The artifact and background means ``muA`` / ``muB`` are re-estimated from
    the current weights; the pixel-wise posteriors themselves are only
    refreshed once ``iteration >= 50`` (burn-in keeps the initial weights
    fixed while the registration stabilizes).
    """
    if estimate_muA:
        muA = jnp.sum(artifact_weights * target_image, axis=(-1, -2)) / jnp.maximum(jnp.sum(artifact_weights), 1e-12)
    if estimate_muB:
        muB = jnp.sum(background_weights * target_image, axis=(-1, -2)) / jnp.maximum(
            jnp.sum(background_weights), 1e-12
        )

    if iteration < 50:
        return match_weights, artifact_weights, background_weights, muA, muB

    # Mixing proportions from the current soft assignments; the tiny additive
    # term keeps them strictly positive before normalization.
    weights = jnp.stack((match_weights, artifact_weights, background_weights))
    mixing = jnp.sum(weights, axis=(1, 2))
    mixing = mixing + jnp.max(mixing) * 1e-6
    mixing = mixing / jnp.sum(mixing)

    # Gaussian normalizing constants per component (isotropic over channels).
    n_channels = target_image.shape[0]
    norm_match = (2.0 * np.pi * sigmaM**2) ** (n_channels / 2.0)
    norm_artifact = (2.0 * np.pi * sigmaA**2) ** (n_channels / 2.0)
    norm_background = (2.0 * np.pi * sigmaB**2) ** (n_channels / 2.0)

    match_weights = mixing[0] * jnp.exp(-jnp.sum((transformed_source - target_image) ** 2, axis=0) / (2.0 * sigmaM**2))
    match_weights = match_weights / norm_match
    artifact_weights = mixing[1] * jnp.exp(
        -jnp.sum((muA[:, None, None] - target_image) ** 2, axis=0) / (2.0 * sigmaA**2)
    )
    artifact_weights = artifact_weights / norm_artifact
    background_weights = mixing[2] * jnp.exp(
        -jnp.sum((muB[:, None, None] - target_image) ** 2, axis=0) / (2.0 * sigmaB**2)
    )
    background_weights = background_weights / norm_background

    # Normalize to posteriors; the additive fudge avoids division by zero.
    total = match_weights + artifact_weights + background_weights
    total = total + jnp.max(total) * 1e-6
    return match_weights / total, artifact_weights / total, background_weights / total, muA, muB


def _lddmm_loss(
    linear: Any,
    translation: Any,
    velocity: Any,
    *,
    x_source: tuple[Any, Any],
    source_image: Any,
    x_target: tuple[Any, Any],
    target_image: Any,
    xv: tuple[Any, Any],
    match_weights: Any,
    ll: Any,
    dv_prod: Any,
    points_source: Any,
    points_target: Any,
    sigmaM: float,
    sigmaR: float,
    sigmaP: float,
) -> tuple[Any, tuple[Any, Any, Any, Any, Any]]:
    """Total LDDMM objective: weighted match + velocity regularizer + landmarks.

    Returns the scalar energy plus an aux tuple
    ``(contrast_source, transformed_points, match_energy, reg_energy, point_energy)``
    for use with ``jax.value_and_grad(..., has_aux=True)``.
    """
    affine = _to_affine(linear, translation)
    # Pull the source image back through the flow + affine onto the target grid.
    source_grid = _transform_grid_backward(x_target, xv, velocity, affine)
    warped_source = _interp(x_source, source_image, source_grid, mode="nearest")
    contrast_source = _contrast_transform(warped_source, target_image, match_weights)

    match_energy = jnp.sum((contrast_source - target_image) ** 2 * match_weights) / (2.0 * sigmaM**2)
    # Velocity regularizer evaluated in the Fourier domain against ``ll``.
    fft_velocity = jnp.fft.fftn(velocity, axes=(1, 2))
    reg_energy = (
        jnp.sum(jnp.sum(jnp.abs(fft_velocity) ** 2, axis=(0, 3)) * ll)
        * dv_prod
        / 2.0
        / velocity.shape[1]
        / velocity.shape[2]
        / sigmaR**2
    )

    transformed_points = transform_points_row_col(xv, velocity, affine, points_source, direction="forward")
    if points_source.shape[0] == 0:
        point_energy = jnp.array(0.0, dtype=source_image.dtype)
    else:
        point_energy = jnp.sum((transformed_points - points_target) ** 2) / (2.0 * sigmaP**2)

    total = match_energy + reg_energy + point_energy
    return total, (contrast_source, transformed_points, match_energy, reg_energy, point_energy)


def lddmm(
    xI: tuple[Any, Any],
    I: Any,
    xJ: tuple[Any, Any],
    J: Any,
    *,
    L: Any,
    T: Any,
    points_source: Any | None = None,
    points_target: Any | None = None,
    a: float = 500.0,
    p: float = 2.0,
    expand: float = 2.0,
    nt: int = 3,
    niter: int = 5000,
    diffeo_start: int = 0,
    epL: float = 2e-8,
    epT: float = 2e-1,
    epV: float = 2e3,
    sigmaM: float = 1.0,
    sigmaB: float = 2.0,
    sigmaA: float = 5.0,
    sigmaR: float = 5e5,
    sigmaP: float = 2e1,
) -> dict[str, Any]:
    """Gradient-descent LDDMM registration of image ``I`` onto image ``J``.

    Parameter names (``L``, ``T``, ``sigmaM``, ``epV``, ...) follow the
    STalign conventions. Returns a dict with the fitted affine ``A``,
    velocity field ``v`` and its grid ``xv``, the mixture posteriors
    ``WM``/``WB``/``WA``, the final energy ``E``, and the transformed
    landmark ``points``.

    NOTE(review): ``energy`` / ``aux`` are first bound inside the loop, so
    ``niter < 1`` would raise ``NameError`` at the final return — confirm
    whether a positive-``niter`` guard is wanted upstream.
    """
    x_source = (jnp.asarray(xI[0]), jnp.asarray(xI[1]))
    x_target = (jnp.asarray(xJ[0]), jnp.asarray(xJ[1]))
    source_image = jnp.asarray(I, dtype=jax_dtype())
    target_image = jnp.asarray(J, dtype=jax_dtype())
    linear = jnp.asarray(L, dtype=jax_dtype())
    translation = jnp.asarray(T, dtype=jax_dtype())

    # Empty (0, 2) landmark arrays make the point term a no-op in the loss.
    if points_source is None:
        source_landmarks = jnp.zeros((0, 2), dtype=jax_dtype())
        target_landmarks = jnp.zeros((0, 2), dtype=jax_dtype())
    else:
        source_landmarks = jnp.asarray(points_source, dtype=jax_dtype())
        target_landmarks = jnp.asarray(points_target, dtype=jax_dtype())

    xv = _build_velocity_grid(x_source, a=a, expand=expand)
    velocity = jnp.zeros((nt, xv[0].shape[0], xv[1].shape[0], 2), dtype=jax_dtype())
    kernel, ll, dv_prod = _build_regularizer(xv, a=a, p=p)

    # Initial soft assignment: mostly "match", some background, little artifact.
    match_weights = jnp.full(target_image.shape[1:], 0.5, dtype=target_image.dtype)
    background_weights = jnp.full(target_image.shape[1:], 0.4, dtype=target_image.dtype)
    artifact_weights = jnp.full(target_image.shape[1:], 0.1, dtype=target_image.dtype)
    muA = jnp.mean(target_image, axis=(1, 2))
    muB = jnp.zeros_like(muA)
    estimate_muA = True
    estimate_muB = True

    loss_and_grad = jax.jit(jax.value_and_grad(_lddmm_loss, argnums=(0, 1, 2), has_aux=True))

    for iteration in range(niter):
        (energy, aux), (grad_linear, grad_translation, grad_velocity) = loss_and_grad(
            linear,
            translation,
            velocity,
            x_source=x_source,
            source_image=source_image,
            x_target=x_target,
            target_image=target_image,
            xv=xv,
            match_weights=match_weights,
            ll=ll,
            dv_prod=dv_prod,
            points_source=source_landmarks,
            points_target=target_landmarks,
            sigmaM=sigmaM,
            sigmaR=sigmaR,
            sigmaP=sigmaP,
        )
        contrast_source, transformed_points, _, _, _ = aux

        # Once the diffeomorphic phase starts, the affine steps shrink 10x.
        affine_scale = 1.0 + 9.0 * float(iteration >= diffeo_start)
        linear = linear - (epL / affine_scale) * grad_linear
        translation = translation - (epT / affine_scale) * grad_translation

        # Precondition the velocity gradient with the smoothing kernel (Fourier).
        grad_velocity = jnp.fft.ifftn(
            jnp.fft.fftn(grad_velocity, axes=(1, 2)) * kernel[None, ..., None],
            axes=(1, 2),
        ).real
        if iteration >= diffeo_start:
            velocity = velocity - epV * grad_velocity

        # Mixture posteriors are refreshed from the pre-step aux every 5 iters.
        if iteration % 5 == 0:
            match_weights, artifact_weights, background_weights, muA, muB = _update_mixture_weights(
                contrast_source,
                target_image,
                match_weights,
                artifact_weights,
                background_weights,
                sigmaM=sigmaM,
                sigmaA=sigmaA,
                sigmaB=sigmaB,
                estimate_muA=estimate_muA,
                estimate_muB=estimate_muB,
                muA=muA,
                muB=muB,
                iteration=iteration,
            )

    affine = _to_affine(linear, translation)
    return {
        "A": affine,
        "v": velocity,
        "xv": xv,
        "WM": match_weights,
        "WB": background_weights,
        "WA": artifact_weights,
        "E": energy,
        "points": transformed_points,
    }


# --- _stalign_helpers.py (separate module in the diff) ---------------------


def _validate_points(points: np.ndarray, *, name: str) -> np.ndarray:
    """Coerce to a float ``(n, 2)`` array, rejecting bad shape / non-finite."""
    arr = np.asarray(points, dtype=float)
    if arr.ndim != 2 or arr.shape[1] != 2:
        raise ValueError(f"Expected `{name}` to have shape `(n, 2)`, found `{arr.shape}`.")
    if not np.all(np.isfinite(arr)):
        raise ValueError(f"Expected `{name}` to contain only finite values.")
    return arr


def extract_points(adata: AnnData, key: str = "spatial") -> np.ndarray:
    """Return a validated coordinate array from ``adata.obsm``."""
    if key not in adata.obsm:
        raise KeyError(f"Key `{key}` not found in `adata.obsm`.")

    return _validate_points(np.asarray(adata.obsm[key]), name=f"adata.obsm[{key!r}]")


def extract_landmarks(adata: AnnData, key: str) -> np.ndarray:
    """Return a validated landmark array from ``adata.obsm`` or ``adata.uns``."""
    if key in adata.obsm:
        arr = np.asarray(adata.obsm[key], dtype=float)
        if arr.ndim != 2 or arr.shape[1] != 2:
            raise ValueError(f"Expected `adata.obsm[{key!r}]` to have shape `(n, 2)`, found `{arr.shape}`.")
        # obsm landmarks may carry NaN placeholder rows; keep only finite ones.
        mask = np.all(np.isfinite(arr), axis=1)
        landmarks = arr[mask]
        if landmarks.size == 0:
            raise ValueError(f"No finite landmark rows were found in `adata.obsm[{key!r}]`.")
        return landmarks

    if key in adata.uns:
        return _validate_points(np.asarray(adata.uns[key]), name=f"adata.uns[{key!r}]")

    raise KeyError(f"Key `{key}` not found in `adata.obsm` or `adata.uns`.")


def to_row_col(points: np.ndarray, *, point_order: PointOrder) -> np.ndarray:
    """Convert coordinates to row-column order."""
    arr = _validate_points(points, name="points")
    if point_order == "row_col":
        return arr
    if point_order == "xy":
        # (x, y) -> (row, col): swap the two columns.
        return arr[:, [1, 0]]
    raise ValueError(f"Unknown `point_order`: `{point_order}`.")


def from_row_col(points: np.ndarray, *, point_order: PointOrder) -> np.ndarray:
    """Convert row-column coordinates to the requested order.

    The column swap is an involution, so this is intentionally the same
    operation as :func:`to_row_col`; the two names exist for call-site clarity.
    """
    arr = _validate_points(points, name="points")
    if point_order == "row_col":
        return arr
    if point_order == "xy":
        return arr[:, [1, 0]]
    raise ValueError(f"Unknown `point_order`: `{point_order}`.")


def _normalize(values: np.ndarray) -> np.ndarray:
    """Min-max scale to [0, 1]; constant input maps to all ones."""
    values = np.asarray(values, dtype=float)
    vmin = np.min(values)
    vmax = np.max(values)
    if np.isclose(vmin, vmax):
        return np.ones_like(values, dtype=float)
    return (values - vmin) / (vmax - vmin)


def rasterize(
    x: np.ndarray,
    y: np.ndarray,
    *,
    g: np.ndarray | None = None,
    dx: float = 30.0,
    blur: float | list[float] = 1.0,
    expand: float = 1.1,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Rasterize a point cloud into a multi-scale density image.

    Returns ``(grid_x, grid_y, image)`` with ``image`` shaped
    ``(len(blur), len(grid_y), len(grid_x))`` — one channel per blur scale.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if x.ndim != 1 or y.ndim != 1 or x.shape != y.shape:
        raise ValueError("Expected `x` and `y` to be 1D arrays with the same length.")
    if x.size == 0:
        raise ValueError("Expected at least one point to rasterize.")
    if dx <= 0:
        raise ValueError("Expected `dx` to be positive.")
    if expand <= 0:
        raise ValueError("Expected `expand` to be positive.")

    blur_values = np.atleast_1d(np.asarray(blur, dtype=float))
    if blur_values.ndim != 1 or np.any(blur_values <= 0):
        raise ValueError("Expected `blur` to be a positive scalar or a 1D sequence of positive values.")

    if g is None:
        weights = np.ones_like(x, dtype=float)
    else:
        weights = np.asarray(g, dtype=float)
        if weights.shape != x.shape:
            raise ValueError("Expected `g` to have the same shape as `x` and `y`.")
        if not np.allclose(weights, 1.0):
            weights = _normalize(weights)

    min_x = float(np.min(x))
    max_x = float(np.max(x))
    min_y = float(np.min(y))
    max_y = float(np.max(y))

    # Grid covers the bounding box scaled by ``expand`` about its center.
    center_x = (min_x + max_x) / 2.0
    center_y = (min_y + max_y) / 2.0
    half_x = (max_x - min_x) * expand / 2.0
    half_y = (max_y - min_y) * expand / 2.0

    grid_x = np.arange(center_x - half_x, center_x + half_x + dx, dx, dtype=float)
    grid_y = np.arange(center_y - half_y, center_y + half_y + dx, dx, dtype=float)
    if grid_x.size < 2 or grid_y.size < 2:
        raise ValueError("Rasterized grid is too small. Increase the point spread or lower `dx`.")

    mesh_x, mesh_y = np.meshgrid(grid_x, grid_y)
    out = np.zeros((len(blur_values), grid_y.size, grid_x.size), dtype=float)
    # Kernel support in *grid cells*, taken from the largest blur scale.
    # NOTE(review): the Gaussian scale below is `dx * blur * 2`, so a radius of
    # `4 * blur` cells truncates at ~2 sigma — confirm this matches upstream.
    radius = int(np.ceil(float(np.max(blur_values)) * 4.0))

    for x_i, y_i, w_i in zip(x, y, weights, strict=False):
        col = int(np.rint((x_i - grid_x[0]) / dx))
        row = int(np.rint((y_i - grid_y[0]) / dx))

        # Clip the splat window to the image bounds.
        row0 = max(row - radius, 0)
        row1 = min(row + radius, out.shape[1] - 1)
        col0 = max(col - radius, 0)
        col1 = min(col + radius, out.shape[2] - 1)

        patch_x = mesh_x[row0 : row1 + 1, col0 : col1 + 1]
        patch_y = mesh_y[row0 : row1 + 1, col0 : col1 + 1]
        denom = 2.0 * (dx * blur_values * 2.0) ** 2

        # One Gaussian per blur scale, normalized to unit mass over the patch.
        kernels = np.exp(-((patch_x[..., None] - x_i) ** 2 + (patch_y[..., None] - y_i) ** 2) / denom)
        kernels_sum = kernels.sum(axis=(0, 1), keepdims=True)
        kernels /= np.where(kernels_sum == 0.0, 1.0, kernels_sum)
        out[:, row0 : row1 + 1, col0 : col1 + 1] += np.moveaxis(kernels * w_i, -1, 0)

    return grid_x, grid_y, out
def affine_from_points(
    points_source: np.ndarray,
    points_target: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
    """Compute an affine initialization from corresponding landmarks.

    Parameters
    ----------
    points_source
        ``(n, 2)`` landmark coordinates in the source frame.
    points_target
        ``(n, 2)`` landmark coordinates in the target frame, row-matched
        with ``points_source``.

    Returns
    -------
    ``(linear, translation)`` with ``linear`` of shape ``(2, 2)`` and
    ``translation`` of shape ``(2,)`` such that
    ``target ≈ source @ linear.T + translation``.

    Raises
    ------
    ValueError
        If either array is not a finite ``(n, 2)`` array, or the shapes differ.
    """
    source = _validate_points(points_source, name="points_source")
    target = _validate_points(points_target, name="points_target")
    if source.shape != target.shape:
        raise ValueError(
            f"Expected `points_source` and `points_target` to have the same shape, found "
            f"`{source.shape}` and `{target.shape}`."
        )

    # With zero pairs, np.mean over the empty axis would emit a RuntimeWarning
    # and silently produce a NaN translation; return the identity instead.
    if source.shape[0] == 0:
        return np.eye(2, dtype=float), np.zeros(2, dtype=float)

    # Fewer than 3 pairs under-determine a full affine: fit translation only.
    if source.shape[0] < 3:
        linear = np.eye(2, dtype=float)
        translation = np.mean(target, axis=0) - np.mean(source, axis=0)
        return linear, translation

    # Least-squares fit in homogeneous coordinates: target_h ≈ source_h @ affine.T.
    source_h = np.concatenate((source, np.ones((source.shape[0], 1), dtype=float)), axis=1)
    target_h = np.concatenate((target, np.ones((target.shape[0], 1), dtype=float)), axis=1)
    affine = np.linalg.lstsq(source_h, target_h, rcond=None)[0].T
    return affine[:2, :2], affine[:2, -1]
+""" + +from __future__ import annotations + +from dataclasses import asdict, dataclass, field +from typing import TYPE_CHECKING, Any, Literal, TypeAlias + +import jax.numpy as jnp +import numpy as np +from anndata import AnnData + +from squidpy.experimental.tl._align._backends._stalign_core import jax_dtype, lddmm, transform_points_row_col +from squidpy.experimental.tl._align._backends._stalign_helpers import ( + PointOrder, + affine_from_points, + extract_points, + from_row_col, + rasterize, + to_row_col, +) + +if TYPE_CHECKING: + import jax + + JaxArray = jax.Array +else: # pragma: no cover - typing only + JaxArray = Any + +BlurScales: TypeAlias = float | tuple[float, ...] | list[float] + +__all__ = [ + "STalignConfig", + "STalignPreprocessConfig", + "STalignPreprocessResult", + "STalignRegistrationConfig", + "STalignResult", + "stalign_points", + "stalign_preprocess", + "transform_points", +] + + +@dataclass(slots=True) +class STalignPreprocessConfig: + dx: float = 30.0 + blur: BlurScales = (2.0, 1.0, 0.5) + expand: float = 1.1 + + +@dataclass(slots=True) +class STalignRegistrationConfig: + """LDDMM registration hyperparameters. + + Field names (``sigmaM``, ``epL``, etc.) preserve the conventions from + the STalign paper and reference implementation to keep them + recognisable when cross-referencing the literature. 
+ """ + + a: float = 500.0 + p: float = 2.0 + expand: float = 2.0 + nt: int = 3 + niter: int = 5000 + diffeo_start: int = 0 + epL: float = 2e-8 + epT: float = 2e-1 + epV: float = 2e3 + sigmaM: float = 1.0 + sigmaB: float = 2.0 + sigmaA: float = 5.0 + sigmaR: float = 5e5 + sigmaP: float = 2e1 + + +@dataclass(slots=True) +class STalignConfig: + preprocess: STalignPreprocessConfig = field(default_factory=STalignPreprocessConfig) + registration: STalignRegistrationConfig = field(default_factory=STalignRegistrationConfig) + + +@dataclass(slots=True) +class STalignPreprocessResult: + source_grid: tuple[np.ndarray, np.ndarray] + source_image: np.ndarray + target_grid: tuple[np.ndarray, np.ndarray] + target_image: np.ndarray + + +@dataclass(slots=True) +class STalignResult: + affine: JaxArray + velocity: JaxArray + velocity_grid: tuple[JaxArray, JaxArray] + aligned_points: JaxArray + point_order: PointOrder = "row_col" + + def transform_points( + self, + points: np.ndarray, + *, + direction: Literal["forward", "backward"] = "forward", + point_order: PointOrder | None = None, + ) -> JaxArray: + """Transform arbitrary point arrays with the fitted map.""" + return transform_points( + self.velocity_grid, + self.velocity, + self.affine, + points, + direction=direction, + point_order=self.point_order if point_order is None else point_order, + ) + + def transform_adata( + self, + adata: AnnData, + *, + spatial_key: str = "spatial", + key_added: str | None = None, + direction: Literal["forward", "backward"] = "forward", + inplace: bool = False, + ) -> np.ndarray | None: + """ + Apply the fitted transform to coordinates stored on an AnnData object. + + If ``inplace=False``, return the transformed coordinates without + modifying ``adata``. If ``inplace=True``, write the transformed + coordinates to ``adata.obsm[spatial_key]`` or ``adata.obsm[key_added]`` + and return ``None``. 
+ """ + points = extract_points(adata, key=spatial_key) + transformed = np.asarray(self.transform_points(points, direction=direction, point_order="xy")) + if not inplace: + return transformed + + adata.obsm[spatial_key if key_added is None else key_added] = transformed + return None + + +def stalign_preprocess( + source_points: np.ndarray, + target_points: np.ndarray, + *, + config: STalignPreprocessConfig | None = None, +) -> STalignPreprocessResult: + """Rasterize source and target point clouds for LDDMM registration.""" + config = STalignPreprocessConfig() if config is None else config + source_points = to_row_col(source_points, point_order="row_col") + target_points = to_row_col(target_points, point_order="row_col") + + source_x, source_y, source_image = rasterize( + source_points[:, 1], + source_points[:, 0], + dx=config.dx, + blur=config.blur, + expand=config.expand, + ) + target_x, target_y, target_image = rasterize( + target_points[:, 1], + target_points[:, 0], + dx=config.dx, + blur=config.blur, + expand=config.expand, + ) + + return STalignPreprocessResult( + source_grid=(source_y, source_x), + source_image=source_image, + target_grid=(target_y, target_x), + target_image=target_image, + ) + + +def transform_points( + xv: tuple[JaxArray, JaxArray], + v: JaxArray, + A: JaxArray, + points: np.ndarray, + *, + direction: Literal["forward", "backward"] = "forward", + point_order: PointOrder = "row_col", +) -> JaxArray: + """Transform point arrays with a fitted STalign map.""" + points_rc = to_row_col(points, point_order=point_order) + transformed = transform_points_row_col( + xv, + jnp.asarray(v), + jnp.asarray(A), + jnp.asarray(points_rc, dtype=jax_dtype()), + direction=direction, + ) + return jnp.asarray(from_row_col(np.asarray(transformed), point_order=point_order)) + + +def stalign_points( + source_points: np.ndarray, + target_points: np.ndarray, + *, + preprocessed: STalignPreprocessResult | None = None, + config: STalignConfig | None = None, + 
landmarks_source: np.ndarray | None = None, + landmarks_target: np.ndarray | None = None, +) -> STalignResult: + """Align source point cloud to target with a JAX LDDMM solver.""" + config = STalignConfig() if config is None else config + registration = config.registration + source_points = to_row_col(source_points, point_order="row_col") + target_points = to_row_col(target_points, point_order="row_col") + if preprocessed is None: + preprocessed = stalign_preprocess(source_points, target_points, config=config.preprocess) + + if (landmarks_source is None) != (landmarks_target is None): + raise ValueError("Expected both landmark arrays to be provided together.") + + if landmarks_source is None: + linear = np.eye(2, dtype=float) + translation = np.zeros(2, dtype=float) + source_landmarks = None + target_landmarks = None + else: + source_landmarks = to_row_col(landmarks_source, point_order="row_col") + target_landmarks = to_row_col(landmarks_target, point_order="row_col") + linear, translation = affine_from_points(source_landmarks, target_landmarks) + + dtype = jax_dtype() + result = lddmm( + preprocessed.source_grid, + preprocessed.source_image, + preprocessed.target_grid, + preprocessed.target_image, + L=jnp.asarray(linear, dtype=dtype), + T=jnp.asarray(translation, dtype=dtype), + points_source=None if source_landmarks is None else jnp.asarray(source_landmarks, dtype=dtype), + points_target=None if target_landmarks is None else jnp.asarray(target_landmarks, dtype=dtype), + **asdict(registration), + ) + aligned_points = transform_points( + result["xv"], + result["v"], + result["A"], + source_points, + direction="forward", + point_order="row_col", + ) + return STalignResult( + affine=result["A"], + velocity=result["v"], + velocity_grid=result["xv"], + aligned_points=aligned_points, + point_order="row_col", + ) diff --git a/src/squidpy/experimental/tl/_align/_io.py b/src/squidpy/experimental/tl/_align/_io.py new file mode 100644 index 000000000..dd6c83a92 --- /dev/null 
"""Input resolvers and output writers for the alignment skeleton.

This module is the *only* place that knows about the duck-typed
``AnnData | SpatialData`` argument shape of the public functions and the
``output_mode`` writeback strategies. Backends operate on the canonical
:class:`AlignPair` produced here.
"""

from __future__ import annotations

from typing import Literal

import numpy as np
from anndata import AnnData
from spatialdata import SpatialData

from squidpy._validators import assert_isinstance, assert_key_in_sdata
from squidpy.experimental.im._utils import get_element_data
from squidpy.experimental.tl._align._types import (
    AffineTransform,
    AlignPair,
    AlignResult,
)
from squidpy.experimental.tl._align._validation import (
    validate_required,
    validate_unexpected,
)

# ---------------------------------------------------------------------------
# Resolvers
# ---------------------------------------------------------------------------


def resolve_obs_pair(
    data_ref: AnnData | SpatialData,
    data_query: AnnData | SpatialData | None,
    adata_ref_name: str | None,
    adata_query_name: str | None,
) -> AlignPair:
    """Normalise the arguments of :func:`align_obs` into an :class:`AlignPair`.

    See the table in the design doc for the exhaustive case matrix. In short:

    - both AnnData → use directly, ``adata_*_name`` must be ``None``;
    - both SpatialData → ``adata_*_name`` required, extract from each;
    - only ``data_ref`` is SpatialData (``data_query is None``) → both
      ``adata_*_name`` required, extract from the same sdata;
    - mixed AnnData/SpatialData → :class:`TypeError`;
    - ``data_ref`` is AnnData with no ``data_query`` → :class:`ValueError`.
    """
    if isinstance(data_ref, AnnData):
        if data_query is None:
            raise ValueError("`data_query` is required when `data_ref` is an AnnData.")
        if not isinstance(data_query, AnnData):
            raise TypeError(
                f"Mixed AnnData/SpatialData inputs are not supported. "
                f"`data_ref` is AnnData but `data_query` is {type(data_query).__name__}."
            )
        # Name args are only meaningful for SpatialData table lookups.
        validate_unexpected(
            name="adata_ref_name",
            value=adata_ref_name,
            when="`data_ref` is a SpatialData",
            hint="Both inputs are AnnData, so there is no table to look up by name.",
        )
        validate_unexpected(
            name="adata_query_name",
            value=adata_query_name,
            when="`data_query` is a SpatialData",
            hint="Both inputs are AnnData, so there is no table to look up by name.",
        )
        return AlignPair(ref=data_ref, query=data_query)

    if not isinstance(data_ref, SpatialData):
        raise TypeError(f"`data_ref` must be AnnData or SpatialData, got {type(data_ref).__name__}.")

    # data_query omitted → both tables come from the same SpatialData object.
    if data_query is None:
        sdata_query: SpatialData = data_ref
    elif isinstance(data_query, SpatialData):
        sdata_query = data_query
    else:
        raise TypeError(
            f"Mixed AnnData/SpatialData inputs are not supported. "
            f"`data_ref` is SpatialData but `data_query` is {type(data_query).__name__}."
        )

    validate_required(name="adata_ref_name", value=adata_ref_name, when="`data_ref` is a SpatialData")
    validate_required(name="adata_query_name", value=adata_query_name, when="`data_query` is a SpatialData")
    assert_key_in_sdata(data_ref, adata_ref_name, attr="tables")
    assert_key_in_sdata(sdata_query, adata_query_name, attr="tables")
    return AlignPair(
        ref=data_ref.tables[adata_ref_name],
        query=sdata_query.tables[adata_query_name],
        ref_container=data_ref,
        query_container=sdata_query,
        ref_element_key=adata_ref_name,
        query_element_key=adata_query_name,
    )


def resolve_image_pair(
    sdata_ref: SpatialData,
    sdata_query: SpatialData | None,
    img_ref_name: str,
    img_query_name: str,
    *,
    scale_ref: str | Literal["auto"] = "auto",
    scale_query: str | Literal["auto"] = "auto",
) -> AlignPair:
    """Normalise the arguments of :func:`align_images` into an :class:`AlignPair`.

    Both single-scale ``xr.DataArray`` and multi-scale ``xr.DataTree`` image
    elements are accepted. Multiscale nodes are flattened via
    :func:`squidpy.experimental.im._utils.get_element_data`, but the original
    element node is remembered in the :class:`AlignPair` so the writer can
    register the transformation on the parent so all scales inherit it.
    """
    assert_isinstance(sdata_ref, SpatialData, name="sdata_ref")
    if sdata_query is None:
        # Single-object flow: both images live in the same SpatialData.
        sdata_query = sdata_ref
    else:
        assert_isinstance(sdata_query, SpatialData, name="sdata_query")

    assert_key_in_sdata(sdata_ref, img_ref_name, attr="images")
    assert_key_in_sdata(sdata_query, img_query_name, attr="images")

    ref_node = sdata_ref.images[img_ref_name]
    query_node = sdata_query.images[img_query_name]

    ref_data = get_element_data(ref_node, scale_ref, element_type="image", element_key=img_ref_name)
    query_data = get_element_data(query_node, scale_query, element_type="image", element_key=img_query_name)

    return AlignPair(
        ref=ref_data,
        query=query_data,
        ref_container=sdata_ref,
        query_container=sdata_query,
        ref_element_key=img_ref_name,
        query_element_key=img_query_name,
    )


def resolve_element_pair(
    sdata_ref: SpatialData,
    sdata_query: SpatialData | None,
    cs_name_ref: str,
    cs_name_query: str,
) -> AlignPair:
    """Normalise the arguments of :func:`align_by_landmarks` into an :class:`AlignPair`.

    No element data is materialised — landmark fitting only needs the
    coordinate system names plus the landmark coordinates, which are
    validated separately. The returned pair carries the containers and cs
    names so the writer can call :func:`set_transformation` on the right
    target.
    """
    assert_isinstance(sdata_ref, SpatialData, name="sdata_ref")
    if sdata_query is None:
        sdata_query = sdata_ref
    else:
        assert_isinstance(sdata_query, SpatialData, name="sdata_query")

    _check_cs_exists(sdata_ref, cs_name_ref, name="cs_name_ref")
    _check_cs_exists(sdata_query, cs_name_query, name="cs_name_query")

    return AlignPair(
        ref=None,
        query=None,
        ref_container=sdata_ref,
        query_container=sdata_query,
        ref_cs=cs_name_ref,
        query_cs=cs_name_query,
    )


def _check_cs_exists(sdata: SpatialData, cs_name: str, *, name: str) -> None:
    """Raise :class:`KeyError` if ``cs_name`` is not a coordinate system of ``sdata``."""
    available = list(sdata.coordinate_systems)
    if cs_name not in available:
        raise KeyError(
            f"`{name}={cs_name!r}` is not a coordinate system of the SpatialData object. "
            f"Available coordinate systems: {available}."
        )


# ---------------------------------------------------------------------------
# Writeback
# ---------------------------------------------------------------------------


def apply_affine_to_cs(
    pair: AlignPair,
    affine: AffineTransform,
    *,
    inplace: bool,
) -> SpatialData | AnnData | None:
    """Register ``affine`` on the query side of the pair.

    Three writeback paths, in order of specificity:

    1. **Element-keyed**: ``pair.query_container`` and ``pair.query_element_key``
       are both set (e.g. ``align_obs`` / ``align_images`` resolved an explicit
       table or image). Register the transform on that single element so all
       scales / sibling tables that share its parent element node inherit it.
    2. **Cs-keyed**: only ``pair.query_cs`` is set (e.g. ``align_by_landmarks``
       resolved a coordinate system but no specific element). Walk every
       element that has the moving cs in its transformation graph and register
       the transform on each, mapping into the reference cs.
    3. **Plain AnnData**: no spatialdata container at all - warp
       ``query.obsm['spatial']`` directly.
    """
    from spatialdata.transformations import get_transformation, set_transformation

    # Prefer the affine's own target, then the reference cs, then "aligned".
    target_cs = affine.target_cs or pair.ref_cs or "aligned"

    if pair.query_container is not None and pair.query_element_key is not None:
        sdata = pair.query_container if inplace else _shallow_copy_sdata(pair.query_container)
        element = _get_element(sdata, pair.query_element_key)
        set_transformation(element, affine.to_spatialdata(), to_coordinate_system=target_cs)
        return None if inplace else sdata

    if pair.query_container is not None and pair.query_cs is not None:
        sdata = pair.query_container if inplace else _shallow_copy_sdata(pair.query_container)
        moving_cs = pair.query_cs
        sd_affine = affine.to_spatialdata()
        touched_any = False
        for _etype, _name, element in sdata.gen_elements():
            # Tables carry no spatial transformations; skip them.
            if isinstance(element, AnnData):
                continue
            element_transforms = get_transformation(element, get_all=True)
            if moving_cs not in element_transforms:
                continue
            set_transformation(element, sd_affine, to_coordinate_system=target_cs)
            touched_any = True
        if not touched_any:
            raise KeyError(
                f"No elements in the query SpatialData are registered to coordinate "
                f"system {moving_cs!r}; nothing to attach the alignment to."
            )
        return None if inplace else sdata

    if isinstance(pair.query, AnnData):
        adata = pair.query if inplace else pair.query.copy()
        if "spatial" not in adata.obsm:
            raise KeyError("Cannot apply an affine to an AnnData query that has no `obsm['spatial']`.")
        adata.obsm["spatial"] = affine.apply(np.asarray(adata.obsm["spatial"]))
        return None if inplace else adata

    raise RuntimeError("apply_affine_to_cs: pair has neither a SpatialData container nor an AnnData query.")


def materialise_obs(
    pair: AlignPair,
    result: AlignResult,
    *,
    key_added: str | None,
    inplace: bool,
) -> SpatialData | AnnData | None:
    """Bake the transform into a *new* AnnData living in the reference cs.

    For affine results we apply the matrix; for :class:`ObsDisplacement` we
    add the deltas. When the source query lives inside a SpatialData, the new
    AnnData is registered as ``sdata.tables[key_added]``; otherwise it is
    returned directly.

    .. note::

        The returned AnnData **shares** ``X`` and ``var`` with the source
        query by reference to avoid copying potentially-large expression
        matrices. Mutating one will affect the other. Call
        ``.copy()`` on the result if you need full independence.
    """
    if not isinstance(pair.query, AnnData):
        raise TypeError("materialise_obs only works for `align_obs`; `pair.query` must be an AnnData.")
    if "spatial" not in pair.query.obsm:
        raise KeyError("Source AnnData has no `obsm['spatial']` to warp.")

    src_coords = np.asarray(pair.query.obsm["spatial"])
    new_coords = result.transform.apply(src_coords)

    # Slim copy: share X/var/obs structurally and only rewrite obsm so we
    # don't pay the cost of deep-copying potentially-large layers/obsp.
    new_obsm = dict(pair.query.obsm)
    new_obsm["spatial"] = new_coords
    new_uns = dict(pair.query.uns)
    # Provenance record so downstream code can trace where the table came from.
    new_uns["align"] = {
        "source_query_key": pair.query_element_key,
        "ref_key": pair.ref_element_key,
        **result.metadata,
    }
    new_adata = AnnData(
        X=pair.query.X,
        obs=pair.query.obs.copy(),
        var=pair.query.var,
        obsm=new_obsm,
        uns=new_uns,
    )

    if pair.query_container is not None:
        if key_added is None:
            raise ValueError("`key_added` is required when `output_mode='obs'` and the query is a SpatialData.")
        sdata = pair.query_container if inplace else _shallow_copy_sdata(pair.query_container)
        from spatialdata.models import TableModel

        sdata.tables[key_added] = TableModel.parse(new_adata)
        return None if inplace else sdata

    return new_adata


def _get_element(sdata: SpatialData, key: str) -> object:
    """Look up a spatial element by name across all element types."""
    for attr in ("images", "labels", "points", "shapes", "tables"):
        store = getattr(sdata, attr)
        if key in store:
            return store[key]
    raise KeyError(f"Element {key!r} not found in the SpatialData object.")


def _shallow_copy_sdata(sdata: SpatialData) -> SpatialData:
    """Shallow copy of a SpatialData object for ``inplace=False`` writeback paths.

    Uses :meth:`SpatialData.subset` over every element so tables and
    ``attrs`` propagate the same way spatialdata's own subsetting handles
    them, rather than reconstructing via the ``__init__`` constructor.
    """
    element_names = [name for _, name, _ in sdata.gen_elements()]
    return sdata.subset(element_names, filter_tables=False, include_orphan_tables=True)
_INSTALL_HINT = 'JAX is required for the requested align_* flavour. Install with `pip install "squidpy[jax]"`.'


def require_jax() -> Any:
    """Import JAX lazily and return the module.

    Raises
    ------
    ImportError
        If JAX is not installed (with an actionable install hint).
    """
    try:
        import jax
    except ImportError as e:
        raise ImportError(_INSTALL_HINT) from e

    return jax


@dataclass(frozen=True)
class AlignPair:
    """Canonical pair of aligned-or-to-be-aligned elements.

    Returned by every resolver in :mod:`._io`. ``ref``/``query`` carry the
    actual data to fit on; ``*_container``/``*_element_key`` remember where
    they came from so the writeback step can register a transformation on the
    correct element of the correct :class:`spatialdata.SpatialData`.
    """

    # ``ref``/``query`` are ``None`` for landmark-only flows where the fit
    # operates on user-provided coordinates and the resolver only needs to
    # locate the target containers + coordinate systems.
    ref: AnnData | xr.DataArray | None
    query: AnnData | xr.DataArray | None
    ref_container: SpatialData | None = None
    query_container: SpatialData | None = None
    ref_element_key: str | None = None
    query_element_key: str | None = None
    ref_cs: str | None = None
    query_cs: str | None = None


@dataclass(frozen=True)
class AffineTransform:
    """A ``(3, 3)`` homogeneous affine in ``(x, y)`` convention.

    This matches the coordinate axis order spatialdata uses for points -
    ``spatialdata.transformations.get_transformation_between_landmarks``
    asserts ``axes == ("x", "y")`` - and the order squidpy / scanpy use for
    ``adata.obsm["spatial"]``. Image elements are stored ``(c, y, x)`` in
    spatialdata, so when registering an ``AffineTransform`` on an *image*
    element you may need a separate matrix; this skeleton currently only
    deals with point coordinates.
    """

    matrix: np.ndarray
    source_cs: str | None = None
    target_cs: str | None = None

    def __post_init__(self) -> None:
        if self.matrix.shape != (3, 3):
            raise ValueError(f"Expected a (3, 3) homogeneous matrix, got shape {self.matrix.shape}.")

    def to_spatialdata(self) -> Affine:
        """Build a :class:`spatialdata.transformations.Affine` for writeback."""
        from spatialdata.transformations import Affine

        return Affine(
            self.matrix,
            input_axes=("x", "y"),
            output_axes=("x", "y"),
        )

    def apply(self, coords: np.ndarray) -> np.ndarray:
        """Apply the affine to an ``(N, 2)`` ``(x, y)`` coordinate array.

        Accepts any array-like (lists, tuples, ndarrays); previously a plain
        list raised ``AttributeError`` on ``coords.ndim`` before the intended
        ``ValueError`` could fire.
        """
        coords = np.asarray(coords)
        if coords.ndim != 2 or coords.shape[1] != 2:
            raise ValueError(f"Expected an (N, 2) coordinate array, got shape {coords.shape}.")
        return coords @ self.matrix[:2, :2].T + self.matrix[:2, 2]
+ """ + + deltas: np.ndarray + source_cs: str | None = None + target_cs: str | None = None + + def __post_init__(self) -> None: + if self.deltas.ndim != 2 or self.deltas.shape[1] != 2: + raise ValueError(f"Expected an (N, 2) deltas array, got shape {self.deltas.shape}.") + + def apply(self, coords: np.ndarray) -> np.ndarray: + """Bake the displacement into an ``(N, 2)`` obs coordinate array.""" + if coords.shape != self.deltas.shape: + raise ValueError(f"Coord shape {coords.shape} does not match displacement shape {self.deltas.shape}.") + return coords + self.deltas + + +Transform = AffineTransform | ObsDisplacement + + +@dataclass(frozen=True) +class AlignResult: + """The output of an alignment backend call.""" + + transform: Transform + metadata: dict[str, Any] = field(default_factory=dict) + + @property + def is_affine(self) -> bool: + return isinstance(self.transform, AffineTransform) diff --git a/src/squidpy/experimental/tl/_align/_validation.py b/src/squidpy/experimental/tl/_align/_validation.py new file mode 100644 index 000000000..e64c005c9 --- /dev/null +++ b/src/squidpy/experimental/tl/_align/_validation.py @@ -0,0 +1,143 @@ +"""Validation helpers shared by the public align_* functions. + +These wrap the generic checks in :mod:`squidpy._validators` with messages +tailored to alignment. The goal is to fail fast and tell the user *why* a +combination of arguments is wrong, not just that it is. +""" + +from __future__ import annotations + +from collections.abc import Sequence +from typing import Any + +import numpy as np + +from squidpy._validators import assert_one_of + +# Flavour identifiers - shared between dispatch, validation, and the public +# function defaults so a typo lights up everywhere it's used. 
+STALIGN = "stalign" +MOSCOT = "moscot" + +ALLOWED_FLAVOURS_OBS = (STALIGN, MOSCOT) +ALLOWED_FLAVOURS_IMAGES = (STALIGN,) +ALLOWED_OUTPUT_MODES_OBS = ("affine", "obs", "return") +ALLOWED_OUTPUT_MODES_NONOBS = ("affine", "return") +ALLOWED_LANDMARK_MODELS = ("similarity", "affine") + + +def validate_flavour(flavour: str, *, allowed: Sequence[str], op: str) -> None: + assert_one_of(flavour, allowed, name=f"{op}.flavour") + + +def validate_output_mode(output_mode: str, *, allowed: Sequence[str], op: str) -> None: + assert_one_of(output_mode, allowed, name=f"{op}.output_mode") + + +def validate_key_added(key_added: str | None, output_mode: str) -> None: + """``key_added`` only makes sense in the obs-materialisation path.""" + if key_added is not None and output_mode != "obs": + raise ValueError( + f"`key_added={key_added!r}` is only meaningful when `output_mode='obs'`. " + f"Got `output_mode={output_mode!r}`. The other modes either register a " + f"transformation on the existing element ('affine') or return the raw " + f"result ('return'), so there is nothing for `key_added` to name." + ) + + +def validate_landmark_model(model: str) -> None: + assert_one_of(model, ALLOWED_LANDMARK_MODELS, name="align_by_landmarks.model") + + +def validate_landmarks( + landmarks_ref: Sequence[tuple[float, float]], + landmarks_query: Sequence[tuple[float, float]], + *, + model: str, + cs_ref_extent: tuple[float, float, float, float] | None = None, + cs_query_extent: tuple[float, float, float, float] | None = None, +) -> tuple[np.ndarray, np.ndarray]: + """Validate landmark sequences and return them as ``(N, 2)`` arrays. + + Parameters + ---------- + landmarks_ref, landmarks_query + Sequences of ``(x, y)`` tuples. Must have the same length and at + least 3 entries (the closed-form solvers under both ``model`` + choices need at least 3 corresponding points). + model + ``"similarity"`` (4 DOF, via spatialdata) or ``"affine"`` (6 DOF, + via skimage's least-squares estimator). 
+ cs_ref_extent, cs_query_extent + Optional ``(x_min, y_min, x_max, y_max)`` bounds of the named + coordinate system at the requested scale. When provided, every + landmark must fall inside. Catches the "I extracted these from + scale0 but asked for scale2" footgun. + """ + ref = np.asarray(landmarks_ref, dtype=float) + query = np.asarray(landmarks_query, dtype=float) + + if ref.ndim != 2 or ref.shape[1] != 2: + raise ValueError(f"`landmarks_ref` must be a sequence of (x, y) pairs, got shape {ref.shape}.") + if query.ndim != 2 or query.shape[1] != 2: + raise ValueError(f"`landmarks_query` must be a sequence of (x, y) pairs, got shape {query.shape}.") + if len(ref) != len(query): + raise ValueError( + f"`landmarks_ref` and `landmarks_query` must have the same length; got {len(ref)} and {len(query)}." + ) + + if len(ref) < 3: + raise ValueError( + f"`model={model!r}` needs at least 3 landmark pairs (spatialdata requirement), got {len(ref)}." + ) + + if cs_ref_extent is not None: + _check_in_extent(ref, cs_ref_extent, name="landmarks_ref") + if cs_query_extent is not None: + _check_in_extent(query, cs_query_extent, name="landmarks_query") + + return ref, query + + +def _check_in_extent( + points: np.ndarray, + extent: tuple[float, float, float, float], + *, + name: str, +) -> None: + x_min, y_min, x_max, y_max = extent + out_of_bounds = (points[:, 0] < x_min) | (points[:, 0] > x_max) | (points[:, 1] < y_min) | (points[:, 1] > y_max) + if out_of_bounds.any(): + bad = points[out_of_bounds] + raise ValueError( + f"{name}: {int(out_of_bounds.sum())} landmark(s) fall outside the coordinate-system " + f"extent (x in [{x_min}, {x_max}], y in [{y_min}, {y_max}]). " + f"This usually means the landmarks were extracted at a different scale than the " + f"one requested. First out-of-bounds point: {tuple(bad[0])}." 
+ ) + + +def validate_unexpected( + *, + name: str, + value: Any, + when: str, + hint: str = "", +) -> None: + """Raise an educational error when an argument was passed in a context it has no role in.""" + if value is not None: + msg = f"`{name}={value!r}` was passed but is only valid when {when}." + if hint: + msg = f"{msg} {hint}" + raise ValueError(msg) + + +def validate_required( + *, + name: str, + value: Any, + when: str, +) -> None: + """Raise when an argument is required by the current context but missing.""" + if value is None: + raise ValueError(f"`{name}` is required when {when}.") diff --git a/tests/experimental/tl/__init__.py b/tests/experimental/tl/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/experimental/tl/test_align_blobs_e2e.py b/tests/experimental/tl/test_align_blobs_e2e.py new file mode 100644 index 000000000..d7238ef81 --- /dev/null +++ b/tests/experimental/tl/test_align_blobs_e2e.py @@ -0,0 +1,216 @@ +"""End-to-end alignment tests on the spatialdata ``blobs`` fixture. + +Pulls the 200-point ``blobs_points`` cloud from +:func:`spatialdata.datasets.blobs`, applies a known transformation to a +copy, and verifies that both alignment paths recover the inverse: + +- ``align_by_landmarks`` should recover an exact closed-form solution. +- ``align_obs(flavour="stalign")`` should reduce the residual displacement + (the LDDMM is non-affine, so we compare the warped query against the ref + by mean Euclidean distance, not by exact equality). + +These tests exercise real solver iterations. They are slower than the +``niter=1`` smoke tests in ``test_align_stalign_integration.py`` (the +stalign test takes a few seconds) but still run on every commit. The user +can mark them ``slow`` later if CI budget becomes a concern. 
+""" + +from __future__ import annotations + +import numpy as np +import pytest +from anndata import AnnData + +pytest.importorskip("jax") + + +# --------------------------------------------------------------------------- +# Fixture +# --------------------------------------------------------------------------- + + +def _blobs_points_xy() -> np.ndarray: + """Pull the 200 ``blobs_points`` rows as an ``(N, 2)`` ``(x, y)`` array.""" + from spatialdata.datasets import blobs + + sd = blobs() + pts_df = sd.points["blobs_points"].compute() + return np.column_stack([pts_df["x"].to_numpy(), pts_df["y"].to_numpy()]).astype(float) + + +def _make_blob_adata(coords_xy: np.ndarray) -> AnnData: + """Wrap an ``(N, 2)`` point cloud as an AnnData with ``obsm['spatial']``.""" + adata = AnnData(np.zeros((coords_xy.shape[0], 1), dtype=float)) + adata.obsm["spatial"] = coords_xy + return adata + + +def _rotation_about_centre(theta_rad: float, centre_xy: np.ndarray) -> np.ndarray: + """Build a ``(3, 3)`` homogeneous rotation about a centre, in xy convention.""" + c, s = np.cos(theta_rad), np.sin(theta_rad) + R = np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]], dtype=float) + T_to = np.array([[1.0, 0.0, -centre_xy[0]], [0.0, 1.0, -centre_xy[1]], [0.0, 0.0, 1.0]]) + T_back = np.array([[1.0, 0.0, centre_xy[0]], [0.0, 1.0, centre_xy[1]], [0.0, 0.0, 1.0]]) + return T_back @ R @ T_to + + +def _apply_homog(points_xy: np.ndarray, M: np.ndarray) -> np.ndarray: + return points_xy @ M[:2, :2].T + M[:2, 2] + + +@pytest.fixture(scope="module") +def blobs_rotated() -> tuple[AnnData, AnnData, np.ndarray]: + """``(ref_adata, query_adata, gt_affine)`` for the rotation-recovery tests. + + The query is the reference rotated by 12° around the cloud centroid. + The ground-truth affine ``gt_affine`` maps ``ref -> query``; recovery + means producing a transform that maps ``query -> ref`` (i.e. the + inverse of ``gt_affine``). 
+ """ + ref_xy = _blobs_points_xy() + centre = ref_xy.mean(axis=0) + gt = _rotation_about_centre(np.deg2rad(12.0), centre) + query_xy = _apply_homog(ref_xy, gt) + + ref = _make_blob_adata(ref_xy) + query = _make_blob_adata(query_xy) + return ref, query, gt + + +# --------------------------------------------------------------------------- +# 1. align_by_landmarks recovers the rotation exactly +# --------------------------------------------------------------------------- + + +def test_align_by_landmarks_recovers_blobs_rotation_exactly(blobs_rotated) -> None: + """A handful of correspondences is enough for the closed-form fit to + invert a pure 2D rotation up to numerical noise.""" + from squidpy.experimental.tl._align._backends._landmark import fit_landmark_affine + + ref, query, gt = blobs_rotated + + # Pick 4 landmarks that span the cloud well. + idx = [0, 50, 100, 150] + landmarks_ref = ref.obsm["spatial"][idx] + landmarks_query = query.obsm["spatial"][idx] + + fit = fit_landmark_affine(landmarks_ref, landmarks_query, model="similarity") + inv_gt = np.linalg.inv(gt) + + # Apply the recovered transform to all 200 query points and compare to ref. + recovered = fit.apply(query.obsm["spatial"]) + residual = np.linalg.norm(recovered - ref.obsm["spatial"], axis=1) + assert residual.max() < 1e-6, f"max residual {residual.max():.3e} should be ~0 for a rigid fit" + + # And the matrix itself should match the inverse of gt to high precision. 
+ np.testing.assert_allclose(fit.matrix, inv_gt, atol=1e-9) + + +def test_align_by_landmarks_affine_recovers_blobs_rotation(blobs_rotated) -> None: + """The ``model='affine'`` path also fits a pure rotation correctly, + even though it has 2 extra DOF over the similarity case.""" + from squidpy.experimental.tl._align._backends._landmark import fit_landmark_affine + + ref, query, _gt = blobs_rotated + idx = [0, 30, 60, 90, 120, 150] + fit = fit_landmark_affine( + ref.obsm["spatial"][idx], + query.obsm["spatial"][idx], + model="affine", + ) + + recovered = fit.apply(query.obsm["spatial"]) + residual = np.linalg.norm(recovered - ref.obsm["spatial"], axis=1) + assert residual.max() < 1e-6 + + +# --------------------------------------------------------------------------- +# 2. align_obs (stalign) reduces the residual on the rotated cloud +# --------------------------------------------------------------------------- + + +def test_align_obs_stalign_reduces_residual_on_blobs(blobs_rotated) -> None: + """The LDDMM solver isn't expected to be exact - non-rigid by design - + but feeding it the landmark-fit affine as an init via the + ``landmarks_*`` kwargs should reduce the residual *below* the no-op + baseline by an order of magnitude. This is the wiring proof: we go + from raw misaligned coordinates to substantially-aligned coordinates + end-to-end through ``sq.experimental.tl.align_obs``. + """ + import squidpy as sq + + ref, query, _gt = blobs_rotated + + baseline = np.linalg.norm(ref.obsm["spatial"] - query.obsm["spatial"], axis=1).mean() + + config = sq.experimental.tl.STalignConfig( + preprocess=sq.experimental.tl.STalignPreprocessConfig(dx=20.0, blur=2.0, expand=1.2), + registration=sq.experimental.tl.STalignRegistrationConfig( + a=80.0, + expand=1.2, + nt=3, + niter=80, + epV=5e2, + ), + ) + + # Use the same well-spread landmarks as the closed-form test so the + # affine init is meaningful; LDDMM then refines the diffeomorphism. 
+ idx = [0, 50, 100, 150] + landmarks_ref = ref.obsm["spatial"][idx] + landmarks_query = query.obsm["spatial"][idx] + + aligned = sq.experimental.tl.align_obs( + ref, + query, + flavour="stalign", + output_mode="obs", + inplace=False, + config=config, + landmarks_source=landmarks_query, + landmarks_target=landmarks_ref, + ) + assert isinstance(aligned, AnnData) + + after = np.linalg.norm(ref.obsm["spatial"] - aligned.obsm["spatial"], axis=1).mean() + # Sanity: the residual is finite and the alignment moved the points. + assert np.isfinite(after) + assert after < baseline, f"residual {after:.2f} should improve on baseline {baseline:.2f}" + + +# --------------------------------------------------------------------------- +# 3. The result type is consistent across both backends +# --------------------------------------------------------------------------- + + +def test_blobs_landmark_and_stalign_use_compatible_xy_convention(blobs_rotated) -> None: + """Both backends operate on (x, y) coords drawn from the same blobs + fixture and produce results in the same convention - sanity check that + the type unification didn't introduce a silent flip.""" + import squidpy as sq + from squidpy.experimental.tl._align._backends._landmark import fit_landmark_affine + from squidpy.experimental.tl._align._types import AffineTransform, ObsDisplacement + + ref, query, _ = blobs_rotated + idx = [0, 50, 100, 150] + affine_fit = fit_landmark_affine( + ref.obsm["spatial"][idx], + query.obsm["spatial"][idx], + model="similarity", + ) + assert isinstance(affine_fit, AffineTransform) + assert affine_fit.matrix.shape == (3, 3) + + config = sq.experimental.tl.STalignConfig( + preprocess=sq.experimental.tl.STalignPreprocessConfig(dx=30.0, blur=2.0), + registration=sq.experimental.tl.STalignRegistrationConfig(a=100.0, expand=1.2, nt=1, niter=1, epV=1.0), + ) + stalign_result = sq.experimental.tl.align_obs( + ref, + query, + flavour="stalign", + output_mode="return", + config=config, + ) + assert 
isinstance(stalign_result.transform, ObsDisplacement) + assert stalign_result.transform.deltas.shape == query.obsm["spatial"].shape diff --git a/tests/experimental/tl/test_align_skeleton.py b/tests/experimental/tl/test_align_skeleton.py new file mode 100644 index 000000000..b6ff07528 --- /dev/null +++ b/tests/experimental/tl/test_align_skeleton.py @@ -0,0 +1,491 @@ +"""Skeleton-level tests for ``squidpy.experimental.tl.align_*``. + +These tests verify *wiring* — argument resolution, dispatch, validation, and +lazy-import hygiene. Real solver tests come with the next PR that drops the +actual implementations into the prepared ``NotImplementedError`` slots. +""" + +from __future__ import annotations + +import sys + +import numpy as np +import pandas as pd +import pytest +import xarray as xr +from anndata import AnnData +from spatialdata import SpatialData +from spatialdata.models import Image2DModel, TableModel + +__all__: list[str] = [] # silence the import-only test module check + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +def _make_adata(n: int = 8, seed: int = 0) -> AnnData: + rng = np.random.default_rng(seed) + X = rng.standard_normal((n, 3)).astype(np.float32) + adata = AnnData(X=X) + adata.obs["region"] = "r" + adata.obs["instance_id"] = np.arange(n) + adata.obsm["spatial"] = rng.uniform(0, 100, (n, 2)) + return adata + + +def _make_table(name: str, n: int = 8, seed: int = 0) -> AnnData: + adata = _make_adata(n=n, seed=seed) + adata.obs["region"] = pd.Categorical([name] * n) + adata.uns["spatialdata_attrs"] = { + "region": name, + "region_key": "region", + "instance_key": "instance_id", + } + return TableModel.parse(adata) + + +def _make_sdata(image_keys=("img_ref", "img_query"), table_keys=("tbl_ref", "tbl_query")) -> SpatialData: + images = {} + for k in image_keys: + arr = np.zeros((3, 32, 32), dtype=np.uint8) + xa = 
xr.DataArray(arr, dims=["c", "y", "x"], coords={"c": ["R", "G", "B"]}) + images[k] = Image2DModel.parse(xa) + tables = {k: _make_table(k) for k in table_keys} + return SpatialData(images=images, tables=tables) + + +def _make_sdata_two_cs() -> SpatialData: + """Single SpatialData with two distinct coordinate systems (``cs_a``/``cs_b``). + + Used by the landmark tests where we want to align one cs to another + inside the same container. + """ + from spatialdata.transformations import Identity, set_transformation + + arr = np.zeros((3, 32, 32), dtype=np.uint8) + xa = xr.DataArray(arr, dims=["c", "y", "x"], coords={"c": ["R", "G", "B"]}) + img_a = Image2DModel.parse(xa) + img_b = Image2DModel.parse(xa.copy()) + set_transformation(img_a, Identity(), to_coordinate_system="cs_a") + set_transformation(img_b, Identity(), to_coordinate_system="cs_b") + return SpatialData(images={"img_a": img_a, "img_b": img_b}) + + +@pytest.fixture +def adata_pair() -> tuple[AnnData, AnnData]: + return _make_adata(seed=1), _make_adata(seed=2) + + +@pytest.fixture +def sdata_pair() -> tuple[SpatialData, SpatialData]: + return _make_sdata(), _make_sdata() + + +@pytest.fixture +def sdata_single() -> SpatialData: + return _make_sdata() + + +@pytest.fixture +def sdata_two_cs() -> SpatialData: + return _make_sdata_two_cs() + + +# --------------------------------------------------------------------------- +# 1. Public path +# --------------------------------------------------------------------------- + + +def test_public_callables_exist() -> None: + import squidpy as sq + + assert callable(sq.experimental.tl.align_obs) + assert callable(sq.experimental.tl.align_by_landmarks) + # align_images is not yet public (no backend implements it) + assert not hasattr(sq.experimental.tl, "align_images") + + +# --------------------------------------------------------------------------- +# 2. 
Lazy-import hygiene +# --------------------------------------------------------------------------- + + +def test_optional_deps_not_imported_at_import_time() -> None: + """Subprocess-isolated import to defeat module caching from other tests. + + A pop+reimport in-process is unreliable here because other tests in the + same session may have already pulled stalign/moscot/jax in transitively. + Spawn a clean Python and inspect its sys.modules. + """ + import subprocess + import sys as _sys + + out = subprocess.check_output( + [ + _sys.executable, + "-c", + ( + "import sys, squidpy; " + "leaked = [m for m in ('jax', 'stalign', 'moscot') if m in sys.modules]; " + "print(','.join(leaked))" + ), + ], + text=True, + timeout=30, + ).strip() + assert out == "", f"Optional deps imported by `import squidpy`: {out}" + + +# --------------------------------------------------------------------------- +# 3. Resolver matrix for align_obs +# --------------------------------------------------------------------------- + + +def test_resolve_obs_pair_two_anndata(adata_pair) -> None: + from squidpy.experimental.tl._align._io import resolve_obs_pair + + a, b = adata_pair + pair = resolve_obs_pair(a, b, None, None) + assert pair.ref is a + assert pair.query is b + assert pair.ref_container is None + assert pair.query_container is None + + +def test_resolve_obs_pair_two_anndata_rejects_unneeded_name(adata_pair) -> None: + from squidpy.experimental.tl._align._io import resolve_obs_pair + + a, b = adata_pair + with pytest.raises(ValueError, match="adata_ref_name"): + resolve_obs_pair(a, b, "tbl_ref", None) + with pytest.raises(ValueError, match="adata_query_name"): + resolve_obs_pair(a, b, None, "tbl_query") + + +def test_resolve_obs_pair_two_sdata(sdata_pair) -> None: + from squidpy.experimental.tl._align._io import resolve_obs_pair + + sa, sb = sdata_pair + pair = resolve_obs_pair(sa, sb, "tbl_ref", "tbl_query") + assert pair.ref_container is sa + assert pair.query_container is sb + assert 
pair.ref_element_key == "tbl_ref" + assert pair.query_element_key == "tbl_query" + assert isinstance(pair.ref, AnnData) + assert isinstance(pair.query, AnnData) + + +def test_resolve_obs_pair_two_sdata_requires_names(sdata_pair) -> None: + from squidpy.experimental.tl._align._io import resolve_obs_pair + + sa, sb = sdata_pair + with pytest.raises(ValueError, match="adata_ref_name"): + resolve_obs_pair(sa, sb, None, "tbl_query") + with pytest.raises(ValueError, match="adata_query_name"): + resolve_obs_pair(sa, sb, "tbl_ref", None) + + +def test_resolve_obs_pair_single_sdata(sdata_single) -> None: + from squidpy.experimental.tl._align._io import resolve_obs_pair + + pair = resolve_obs_pair(sdata_single, None, "tbl_ref", "tbl_query") + assert pair.ref_container is sdata_single + assert pair.query_container is sdata_single + assert pair.ref_element_key == "tbl_ref" + assert pair.query_element_key == "tbl_query" + + +def test_resolve_obs_pair_single_sdata_same_name_passes_through(sdata_single) -> None: + """Same-name within one sdata is a valid no-op-ish call (identity fit). + + The resolver is not in the business of semantic-uniqueness validation; + backends are free to treat identical inputs however they like. 
+ """ + from squidpy.experimental.tl._align._io import resolve_obs_pair + + pair = resolve_obs_pair(sdata_single, None, "tbl_ref", "tbl_ref") + assert pair.ref_element_key == "tbl_ref" + assert pair.query_element_key == "tbl_ref" + + +def test_resolve_obs_pair_mixed_inputs_rejected(adata_pair, sdata_single) -> None: + from squidpy.experimental.tl._align._io import resolve_obs_pair + + a, _ = adata_pair + with pytest.raises(TypeError, match="Mixed AnnData/SpatialData"): + resolve_obs_pair(a, sdata_single, None, "tbl_ref") + with pytest.raises(TypeError, match="Mixed AnnData/SpatialData"): + resolve_obs_pair(sdata_single, a, "tbl_ref", None) + + +def test_resolve_obs_pair_anndata_without_query(adata_pair) -> None: + from squidpy.experimental.tl._align._io import resolve_obs_pair + + a, _ = adata_pair + with pytest.raises(ValueError, match="`data_query` is required"): + resolve_obs_pair(a, None, None, None) + + +# --------------------------------------------------------------------------- +# 4. 
Multiscale image resolution +# --------------------------------------------------------------------------- + + +def test_resolve_image_pair_single_and_multiscale() -> None: + from squidpy.experimental.tl._align._io import resolve_image_pair + + # Single-scale image + arr = np.zeros((3, 32, 32), dtype=np.uint8) + xa = xr.DataArray(arr, dims=["c", "y", "x"], coords={"c": ["R", "G", "B"]}) + single = Image2DModel.parse(xa) + + # Multiscale image + multi = Image2DModel.parse(xa, scale_factors=[2]) + + sdata = SpatialData(images={"single": single, "multi": multi}) + + pair = resolve_image_pair(sdata, None, "single", "multi") + assert isinstance(pair.ref, xr.DataArray) + assert isinstance(pair.query, xr.DataArray) + assert pair.ref_element_key == "single" + assert pair.query_element_key == "multi" + + +def test_resolve_image_pair_same_name_passes_through() -> None: + """Same image name in a single sdata is a valid no-op call.""" + from squidpy.experimental.tl._align._io import resolve_image_pair + + arr = np.zeros((3, 16, 16), dtype=np.uint8) + xa = xr.DataArray(arr, dims=["c", "y", "x"], coords={"c": ["R", "G", "B"]}) + sdata = SpatialData(images={"img": Image2DModel.parse(xa)}) + + pair = resolve_image_pair(sdata, None, "img", "img") + assert pair.ref_element_key == "img" + assert pair.query_element_key == "img" + + +# --------------------------------------------------------------------------- +# 5. 
Landmark validation +# --------------------------------------------------------------------------- + + +def test_validate_landmarks_unequal_length() -> None: + from squidpy.experimental.tl._align._validation import validate_landmarks + + with pytest.raises(ValueError, match="same length"): + validate_landmarks(((0, 0), (1, 1)), ((0, 0),), model="similarity") + + +def test_validate_landmarks_requires_three_points() -> None: + """spatialdata's get_transformation_between_landmarks requires n>=3.""" + from squidpy.experimental.tl._align._validation import validate_landmarks + + with pytest.raises(ValueError, match="at least 3"): + validate_landmarks(((0, 0), (1, 1)), ((0, 0), (1, 1)), model="similarity") + + +def test_validate_landmarks_outside_extent() -> None: + from squidpy.experimental.tl._align._validation import validate_landmarks + + with pytest.raises(ValueError, match="outside the coordinate-system extent"): + validate_landmarks( + ((0, 0), (50, 50), (3, 3)), + ((0, 0), (5, 5), (1, 1)), + model="similarity", + cs_ref_extent=(0, 0, 10, 10), + ) + + +def test_validate_landmarks_happy_path() -> None: + from squidpy.experimental.tl._align._validation import validate_landmarks + + ref, query = validate_landmarks( + ((0, 0), (10, 0), (0, 10)), + ((1, 1), (11, 1), (1, 11)), + model="affine", + cs_ref_extent=(0, 0, 100, 100), + cs_query_extent=(0, 0, 100, 100), + ) + assert ref.shape == (3, 2) + assert query.shape == (3, 2) + + +# --------------------------------------------------------------------------- +# 6. 
Output-mode guards +# --------------------------------------------------------------------------- + + +def test_align_images_rejects_output_mode_obs(sdata_single) -> None: + from squidpy.experimental.tl._align._api import align_images + + with pytest.raises(ValueError, match="output_mode"): + align_images( + sdata_single, + None, + img_ref_name="img_ref", + img_query_name="img_query", + output_mode="obs", # type: ignore[arg-type] + ) + + +def test_key_added_only_with_obs_mode(sdata_single) -> None: + import squidpy as sq + + with pytest.raises(ValueError, match="key_added"): + sq.experimental.tl.align_obs( + sdata_single, + None, + adata_ref_name="tbl_ref", + adata_query_name="tbl_query", + output_mode="affine", + key_added="aligned", + ) + + +# --------------------------------------------------------------------------- +# 7. Dispatch +# --------------------------------------------------------------------------- + + +def test_align_obs_stalign_image_path_raises(sdata_single) -> None: + """``align_images(flavour='stalign')`` is still NotImplementedError; the + PR-#1150 lift only ships point alignment. 
This pins the contract that + the dispatch reaches the backend cleanly (no ImportError/AttributeError).""" + from squidpy.experimental.tl._align._api import align_images + + pytest.importorskip("jax") + with pytest.raises(NotImplementedError, match="stalign image alignment"): + align_images( + sdata_single, + None, + img_ref_name="img_ref", + img_query_name="img_query", + flavour="stalign", + ) + + +def test_align_obs_moscot_dispatch_reaches_backend(sdata_single) -> None: + import squidpy as sq + + pytest.importorskip("jax") + with pytest.raises(NotImplementedError, match="moscot backend"): + sq.experimental.tl.align_obs( + sdata_single, + None, + adata_ref_name="tbl_ref", + adata_query_name="tbl_query", + flavour="moscot", + ) + + +def test_align_obs_unknown_flavour(sdata_single) -> None: + import squidpy as sq + + with pytest.raises(ValueError, match="flavour"): + sq.experimental.tl.align_obs( + sdata_single, + None, + adata_ref_name="tbl_ref", + adata_query_name="tbl_query", + flavour="bogus", # type: ignore[arg-type] + ) + + +def test_align_images_rejects_moscot(sdata_single) -> None: + from squidpy.experimental.tl._align._api import align_images + + with pytest.raises(ValueError, match="flavour"): + align_images( + sdata_single, + None, + img_ref_name="img_ref", + img_query_name="img_query", + flavour="moscot", # type: ignore[arg-type] + ) + + +# --------------------------------------------------------------------------- +# 8. align_by_landmarks is JAX-free +# --------------------------------------------------------------------------- + + +def test_align_by_landmarks_two_cs_in_same_sdata(monkeypatch, sdata_two_cs) -> None: + """Align two distinct coordinate systems inside a single SpatialData via + the closed-form spatialdata fit. Also pins that the landmark path never + touches JAX: with ``jax`` blocked the call must still succeed, because + the landmark backend is pure NumPy/spatialdata. 
+ """ + from spatialdata.transformations import Affine, get_transformation + + import squidpy as sq + + monkeypatch.setitem(sys.modules, "jax", None) + + # Three corresponding landmarks: identity translation by (+1, +2). + landmarks_ref = ((0.0, 0.0), (10.0, 0.0), (0.0, 10.0)) + landmarks_query = ((1.0, 2.0), (11.0, 2.0), (1.0, 12.0)) + + sq.experimental.tl.align_by_landmarks( + sdata_two_cs, + None, + cs_name_ref="cs_a", + cs_name_query="cs_b", + landmarks_ref=landmarks_ref, + landmarks_query=landmarks_query, + model="similarity", + ) + + # The fit should now have attached an affine on `img_b` mapping cs_b -> cs_a. + img_b = sdata_two_cs.images["img_b"] + transforms = get_transformation(img_b, get_all=True) + assert "cs_a" in transforms, f"alignment didn't register a cs_a transform on img_b; have {list(transforms)}" + assert isinstance(transforms["cs_a"], Affine) + + +def test_align_by_landmarks_affine_model(sdata_two_cs) -> None: + """The 6-DOF affine model fits via skimage and registers a transform.""" + from spatialdata.transformations import Affine, get_transformation + + import squidpy as sq + + # 4 landmarks (>3, since affine has 6 DOF and skimage wants over-determined input). + landmarks_ref = ((0.0, 0.0), (10.0, 0.0), (0.0, 10.0), (10.0, 10.0)) + landmarks_query = ((1.0, 2.0), (11.0, 2.0), (1.0, 12.0), (11.0, 12.0)) + + sq.experimental.tl.align_by_landmarks( + sdata_two_cs, + None, + cs_name_ref="cs_a", + cs_name_query="cs_b", + landmarks_ref=landmarks_ref, + landmarks_query=landmarks_query, + model="affine", + ) + + img_b = sdata_two_cs.images["img_b"] + transforms = get_transformation(img_b, get_all=True) + assert "cs_a" in transforms + assert isinstance(transforms["cs_a"], Affine) + + +# --------------------------------------------------------------------------- +# 9. 
JAX-required flavours fail cleanly without JAX +# --------------------------------------------------------------------------- + + +def test_stalign_without_jax_raises_importerror(monkeypatch, sdata_single) -> None: + import squidpy as sq + + # Block the import: the lazy `import jax` inside _jax.require_jax will hit None. + monkeypatch.setitem(sys.modules, "jax", None) + + with pytest.raises(ImportError, match="JAX is required"): + sq.experimental.tl.align_obs( + sdata_single, + None, + adata_ref_name="tbl_ref", + adata_query_name="tbl_query", + flavour="stalign", + ) diff --git a/tests/experimental/tl/test_align_stalign_integration.py b/tests/experimental/tl/test_align_stalign_integration.py new file mode 100644 index 000000000..40e4dc0fa --- /dev/null +++ b/tests/experimental/tl/test_align_stalign_integration.py @@ -0,0 +1,132 @@ +"""Happy-path integration tests for the stalign backend. + +These exercise the lift from scverse/squidpy#1150 through the +``align_obs`` API and the ``output_mode`` writeback paths. Tiny synthetic +fixtures with ``niter=1`` keep them fast enough to run on every commit; they +verify wiring and shapes only, **not** solver quality. Numeric / visual +end-to-end verification on a rotation-recovery fixture is a separate +follow-up (see the plan file). 
+""" + +from __future__ import annotations + +import numpy as np +import pytest +from anndata import AnnData + +pytest.importorskip("jax") + + +def _make_xy_adata() -> AnnData: + """Five-point synthetic AnnData with an ``obsm['spatial']`` cloud.""" + points_xy = np.array( + [ + [10.0, 1.0], + [12.0, 1.0], + [11.0, 2.0], + [10.0, 3.0], + [12.0, 3.0], + ] + ) + adata = AnnData(np.zeros((points_xy.shape[0], 1))) + adata.obsm["spatial"] = points_xy + return adata + + +def _tiny_config(): + """Single-iteration LDDMM hyperparameters - smallest possible solve.""" + import squidpy as sq + + return sq.experimental.tl.STalignConfig( + preprocess=sq.experimental.tl.STalignPreprocessConfig(dx=0.5, blur=1.0), + registration=sq.experimental.tl.STalignRegistrationConfig( + a=1.0, + expand=1.0, + nt=1, + niter=1, + epV=1.0, + ), + ) + + +def test_align_obs_stalign_return_mode_yields_obs_displacement() -> None: + """Wiring smoke test: dispatch -> stalign LDDMM -> ObsDisplacement.""" + import squidpy as sq + from squidpy.experimental.tl._align._types import AlignResult, ObsDisplacement + + ref = _make_xy_adata() + query = _make_xy_adata() + result = sq.experimental.tl.align_obs( + ref, + query, + flavour="stalign", + output_mode="return", + config=_tiny_config(), + ) + + assert isinstance(result, AlignResult) + assert isinstance(result.transform, ObsDisplacement) + assert result.transform.deltas.shape == query.obsm["spatial"].shape + assert np.all(np.isfinite(result.transform.deltas)) + assert result.metadata["flavour"] == "stalign" + # The escape hatch: the full STalignResult is preserved for power users. 
+ assert "stalign_result" in result.metadata + + +def test_align_obs_stalign_obs_mode_writes_new_anndata() -> None: + """``output_mode='obs'`` materialises a new AnnData in the ref cs.""" + import squidpy as sq + + ref = _make_xy_adata() + query = _make_xy_adata() + aligned = sq.experimental.tl.align_obs( + ref, + query, + flavour="stalign", + output_mode="obs", + inplace=False, + config=_tiny_config(), + ) + + assert isinstance(aligned, AnnData) + assert aligned.obsm["spatial"].shape == query.obsm["spatial"].shape + # The writer stamps `align` metadata on uns so callers can introspect. + assert "align" in aligned.uns + assert aligned.uns["align"]["flavour"] == "stalign" + + +def test_align_obs_stalign_affine_mode_errors_for_non_affine_fit() -> None: + """LDDMM is non-affine; ``output_mode='affine'`` must error helpfully.""" + import squidpy as sq + + ref = _make_xy_adata() + query = _make_xy_adata() + with pytest.raises(TypeError, match="requires the backend to return an AffineTransform"): + sq.experimental.tl.align_obs( + ref, + query, + flavour="stalign", + output_mode="affine", + config=_tiny_config(), + ) + + +def test_align_obs_stalign_with_landmarks() -> None: + """Landmark-guided affine init reaches the solver via flavour_kwargs.""" + import squidpy as sq + + ref = _make_xy_adata() + query = _make_xy_adata() + landmarks_xy = ref.obsm["spatial"][:3] + + result = sq.experimental.tl.align_obs( + ref, + query, + flavour="stalign", + output_mode="return", + config=_tiny_config(), + landmarks_source=landmarks_xy, + landmarks_target=landmarks_xy, + ) + assert "stalign_result" in result.metadata + assert result.transform.deltas.shape == query.obsm["spatial"].shape