Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
1431587
init
selmanozleyen Apr 7, 2026
3e385fc
add TODO
selmanozleyen Apr 7, 2026
e4ca9cd
stalign comparison notebooks
selmanozleyen Apr 7, 2026
2040f2b
notebooks
selmanozleyen Apr 7, 2026
50e4b28
docstring
selmanozleyen Apr 7, 2026
763f560
undo todo
selmanozleyen Apr 7, 2026
a35429d
formatting
selmanozleyen Apr 7, 2026
f51f5c1
lazy import + uns serialization
selmanozleyen Apr 7, 2026
d0b3f8f
reverse time
selmanozleyen Apr 7, 2026
310945e
rename helper funcs
selmanozleyen Apr 7, 2026
f4cdebf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 7, 2026
ea5a7d2
add jax import errior
selmanozleyen Apr 7, 2026
1e1f300
init
selmanozleyen Apr 7, 2026
ccd2070
add TODO
selmanozleyen Apr 7, 2026
2ce7aa1
stalign comparison notebooks
selmanozleyen Apr 7, 2026
2c035d9
notebooks
selmanozleyen Apr 7, 2026
a15fb52
docstring
selmanozleyen Apr 7, 2026
f3fe06e
undo todo
selmanozleyen Apr 7, 2026
66f10a1
formatting
selmanozleyen Apr 7, 2026
8ecb713
lazy import + uns serialization
selmanozleyen Apr 7, 2026
c61f3e6
reverse time
selmanozleyen Apr 7, 2026
84db738
rename helper funcs
selmanozleyen Apr 7, 2026
95a27bd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 7, 2026
2dd3a2c
add jax import errior
selmanozleyen Apr 7, 2026
b23eda1
Merge branch 'feat/stalign-points' of https://github.com/selmanozleye…
selmanozleyen Apr 7, 2026
54c6d5a
inplace
selmanozleyen Apr 7, 2026
9696fee
inplace changes
selmanozleyen Apr 7, 2026
29da95f
explicit imports
selmanozleyen Apr 7, 2026
c35a5b2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 7, 2026
d1dc973
config classes
selmanozleyen Apr 9, 2026
828c6f7
move data
selmanozleyen Apr 9, 2026
77396a8
update readme
selmanozleyen Apr 9, 2026
857fcaf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 9, 2026
1a96361
remove the notebooks + data
selmanozleyen Apr 10, 2026
f461e86
Merge branch 'main' into feat/stalign-points
selmanozleyen Apr 19, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,5 @@ data
pixi.lock

_version.py
# TODO: no idea hwy I needed this
.mplconfig/*
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,9 @@ optional-dependencies.docs = [
"sphinxcontrib-bibtex>=2.3",
"sphinxcontrib-spelling>=7.6.2",
]
optional-dependencies.jax = [
"jax",
]
optional-dependencies.leiden = [
"leidenalg",
"spatialleiden>=0.4",
Expand Down
4 changes: 2 additions & 2 deletions src/squidpy/experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@

from __future__ import annotations

from . import im, pl
from . import im, pl, tl

__all__ = ["im", "pl"]
__all__ = ["im", "pl", "tl"]
30 changes: 30 additions & 0 deletions src/squidpy/experimental/tl/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from __future__ import annotations

from importlib import import_module

__all__ = ["stalign", "stalign_tools"]


def _import_stalign_module(module_name: str):
try:
return import_module(module_name)
except ModuleNotFoundError as e:
if e.name == "jax":
raise ImportError(
'STalign requires the optional dependency `jax`. Install it with `pip install "squidpy[jax]"`.'
) from e
raise


def __getattr__(name: str):
# Module-level lazy imports are a common scientific Python pattern for
# optional or heavy dependencies.
if name == "stalign":
return _import_stalign_module("squidpy.experimental.tl._stalign").stalign
if name == "stalign_tools":
return _import_stalign_module("squidpy.experimental.tl.stalign_tools")
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


def __dir__() -> list[str]:
return sorted(__all__)
114 changes: 114 additions & 0 deletions src/squidpy/experimental/tl/_stalign.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
"""AnnData-facing wrappers for experimental STalign."""

from __future__ import annotations

import numpy as np
from anndata import AnnData

from squidpy.experimental.tl._stalign_helpers import extract_landmarks, extract_points
from squidpy.experimental.tl.stalign_tools import STalignConfig, STalignResult, stalign_points

__all__ = ["stalign"]


def stalign(
adata_src: AnnData,
adata_tgt: AnnData,
*,
src_key: str = "spatial",
tgt_key: str = "spatial",
src_landmarks_key: str | None = None,
tgt_landmarks_key: str | None = None,
config: STalignConfig | None = None,
inplace: bool = True,
) -> STalignResult:
"""
Align point coordinates stored on two AnnData objects.

This is the high-level experimental wrapper around
:func:`squidpy.experimental.tl.stalign_tools.stalign_points`.
It reads source and target coordinates from ``adata_src.obsm[src_key]``
and ``adata_tgt.obsm[tgt_key]``, optionally reads landmark coordinates
from ``.obsm`` or ``.uns`` via ``src_landmarks_key`` and
``tgt_landmarks_key``, and runs point-cloud registration.

Parameters
----------
adata_src
Source AnnData containing the point cloud to transform.
adata_tgt
Target AnnData containing the reference point cloud.
src_key
Key in ``adata_src.obsm`` holding source coordinates in ``(x, y)``
order.
tgt_key
Key in ``adata_tgt.obsm`` holding target coordinates in ``(x, y)``
order.
src_landmarks_key
Optional key in ``adata_src.obsm`` or ``adata_src.uns`` containing
source landmark coordinates in ``(x, y)`` order.
tgt_landmarks_key
Optional key in ``adata_tgt.obsm`` or ``adata_tgt.uns`` containing
target landmark coordinates in ``(x, y)`` order.
config
Optional STalign hyperparameter bundle. ``config.preprocess`` controls
rasterization and ``config.registration`` controls LDDMM fitting.
inplace
If ``True``, store a serializable summary of the fitted result under
``adata_src.uns["stalign"]``. The fitted result object is returned in
all cases.

Returns
-------
STalignResult
Fitted registration result. The returned object exposes
``transform_points(...)`` and ``transform_adata(...)`` helpers.
"""
source_points_xy = extract_points(adata_src, key=src_key)
target_points_xy = extract_points(adata_tgt, key=tgt_key)
source_points = source_points_xy[:, [1, 0]]
target_points = target_points_xy[:, [1, 0]]

if (src_landmarks_key is None) != (tgt_landmarks_key is None):
raise ValueError("Expected both landmark keys to be provided together.")

if src_landmarks_key is None:
landmarks_source = None
landmarks_target = None
else:
landmarks_source = extract_landmarks(adata_src, key=src_landmarks_key)[:, [1, 0]]
landmarks_target = extract_landmarks(adata_tgt, key=tgt_landmarks_key)[:, [1, 0]]

result = stalign_points(
source_points,
target_points,
config=config,
landmarks_source=landmarks_source,
landmarks_target=landmarks_target,
)
result.point_order = "xy"
result.aligned_points = np.asarray(result.aligned_points)[:, [1, 0]]

if inplace:
adata_src.uns["stalign"] = {
"result": _result_to_uns(result),
"src_key": src_key,
"tgt_key": tgt_key,
"src_landmarks_key": src_landmarks_key,
"tgt_landmarks_key": tgt_landmarks_key,
}

return result


def _result_to_uns(result: STalignResult) -> dict[str, object]:
return {
"affine": np.asarray(result.affine),
"velocity": np.asarray(result.velocity),
"velocity_grid": {
"row": np.asarray(result.velocity_grid[0]),
"col": np.asarray(result.velocity_grid[1]),
},
"aligned_points": np.asarray(result.aligned_points),
"point_order": result.point_order,
}
Loading
Loading