diff --git a/docs/api.md b/docs/api.md index fc9922a31..6b6a77a35 100644 --- a/docs/api.md +++ b/docs/api.md @@ -82,6 +82,7 @@ import squidpy as sq .. autosummary:: :toctree: api + tl.impute tl.sliding_window tl.var_by_distance ``` diff --git a/src/squidpy/tl/__init__.py b/src/squidpy/tl/__init__.py index 6d5abe98c..31d355fbe 100644 --- a/src/squidpy/tl/__init__.py +++ b/src/squidpy/tl/__init__.py @@ -2,5 +2,6 @@ from __future__ import annotations +from squidpy.tl._impute import impute from squidpy.tl._sliding_window import _calculate_window_corners, sliding_window from squidpy.tl._var_by_distance import var_by_distance diff --git a/src/squidpy/tl/_impute.py b/src/squidpy/tl/_impute.py new file mode 100644 index 000000000..e9d58b3f4 --- /dev/null +++ b/src/squidpy/tl/_impute.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +from collections.abc import Sequence + +import pandas as pd +from anndata import AnnData + +from squidpy._docs import d +from squidpy._validators import assert_one_of + +from ._spage_impute import spage_impute + +__all__ = ["impute"] + +_ALLOWED_METHODS = ("spage",) + + +@d.dedent +def impute( + st_adata: AnnData, + sc_adata: AnnData, + *, + genes: Sequence[str] | None = None, + method: str = "spage", + n_pv: int = 30, + n_neighbors: int = 50, + cosine_threshold: float = 0.3, + use_raw: bool = False, + layer: str | None = None, + key_added: str = "spage", + n_jobs: int | None = None, + copy: bool = False, +) -> pd.DataFrame | None: + """ + Impute spatially unmeasured genes in spatial data using a selected method. + + Parameters + ---------- + st_adata + Spatial AnnData object. + sc_adata + scRNA-seq AnnData object. + genes + Genes to impute. If `None`, uses genes present in `sc_adata` but missing from `st_adata`. + method + Imputation method to use. Valid options are: + + - ``"spage"`` - SpaGE imputation. + n_pv + Number of principal vectors used for alignment. + n_neighbors + Number of nearest neighbors used for imputation. + cosine_threshold + Threshold on cosine similarity to select effective principal vectors. + use_raw + Whether to use `.raw` for expression values. + layer + Layer to use for expression values. + key_added + Key added to `.obsm` for the imputed genes. + n_jobs + Number of parallel jobs for nearest neighbors search. + copy + If `True`, return the imputed dataframe. Otherwise, save it to `.obsm[key_added]`. + Returns + ------- + If ``copy = True``, returns a :class:`pandas.DataFrame` with imputed values. + Otherwise, stores the result in :attr:`anndata.AnnData.obsm` ``[key_added]`` and returns `None`. + """ + assert_one_of(method, _ALLOWED_METHODS, name="method") + + if method == "spage": + return spage_impute( + st_adata, + sc_adata, + genes=genes, + n_pv=n_pv, + n_neighbors=n_neighbors, + cosine_threshold=cosine_threshold, + use_raw=use_raw, + layer=layer, + key_added=key_added, + n_jobs=n_jobs, + copy=copy, + ) + + raise NotImplementedError(f"Method `{method}` is not yet implemented.") diff --git a/src/squidpy/tl/_spage_impute.py b/src/squidpy/tl/_spage_impute.py new file mode 100644 index 000000000..ce7aaf855 --- /dev/null +++ b/src/squidpy/tl/_spage_impute.py @@ -0,0 +1,243 @@ +from __future__ import annotations + +from collections.abc import Sequence + +import numba +import numpy as np +import pandas as pd +from anndata import AnnData +from scanpy import logging as logg +from scipy import linalg +from scipy.sparse import issparse, spmatrix +from sklearn.decomposition import PCA +from sklearn.neighbors import NearestNeighbors +from sklearn.preprocessing import StandardScaler + +from squidpy._docs import d +from squidpy._utils import NDArrayA +from squidpy.gr._utils import _extract_expression, _save_data + +__all__ = ["spage_impute"] + + +@d.dedent +def spage_impute( + st_adata: AnnData, + sc_adata: AnnData, + *, + genes: Sequence[str] | None = None, + n_pv: int = 30, + n_neighbors: int = 50, + cosine_threshold: float = 0.3, + use_raw: bool = False, + layer: str | None = None, + key_added: str = "spage", + n_jobs: int | None = None, + copy: bool = False, +) -> pd.DataFrame | None: + """ + Impute spatially unmeasured genes in spatial data using SpaGE. + + Parameters + ---------- + st_adata + Spatial AnnData object. + sc_adata + scRNA-seq AnnData object. + genes + Genes to impute. If `None`, uses genes present in `sc_adata` but missing from `st_adata`. + n_pv + Number of principal vectors used for alignment. + n_neighbors + Number of nearest neighbors used for imputation. + cosine_threshold + Threshold on cosine similarity to select effective principal vectors. + use_raw + Whether to use `.raw` for expression values. + layer + Layer to use for expression values. + key_added + Key added to `.obsm` for the imputed genes. + n_jobs + Number of parallel jobs for nearest neighbors search. + copy + If `True`, return the imputed dataframe. Otherwise, save it to `.obsm[key_added]`. + + Returns + ------- + If ``copy = True``, returns a :class:`pandas.DataFrame` with imputed values. + Otherwise, stores the result in :attr:`anndata.AnnData.obsm` ``[key_added]`` and returns `None`. + """ + start = logg.info("Running SpaGE imputation") + + if n_pv <= 0: + raise ValueError("`n_pv` must be positive.") + if n_neighbors <= 0: + raise ValueError("`n_neighbors` must be positive.") + if cosine_threshold < 0: + raise ValueError("`cosine_threshold` must be non-negative.") + + genes_to_predict = _resolve_genes_to_predict(sc_adata, genes) + shared_genes = _shared_genes(st_adata, sc_adata) + + if n_pv > len(shared_genes): + raise ValueError(f"`n_pv` must be <= number of shared genes ({len(shared_genes)}), found `{n_pv}`.") + + sc_shared, _ = _extract_expression(sc_adata, genes=shared_genes, use_raw=use_raw, layer=layer) + st_shared, _ = _extract_expression(st_adata, genes=shared_genes, use_raw=use_raw, layer=layer) + sc_target, _ = _extract_expression(sc_adata, genes=genes_to_predict, use_raw=use_raw, layer=layer) + + sc_shared = _standardize(sc_shared) + st_shared = _standardize(st_shared) + + source_components = _fit_components(sc_shared, n_pv) + target_components = _fit_components(st_shared, n_pv) + + source_components = _orthonormalize(source_components) + target_components = _orthonormalize(target_components) + + n_pv_eff = min(n_pv, source_components.shape[0], target_components.shape[0]) + if n_pv_eff <= 0: + raise ValueError("No principal vectors could be computed.") + + source_pv, target_pv, cosine = _compute_principal_vectors(source_components, target_components, n_pv_eff) + + effective_n_pv = int(np.sum(np.diag(cosine) > cosine_threshold)) + if effective_n_pv <= 0: + raise ValueError("No effective principal vectors found. Consider lowering `cosine_threshold` or `n_pv`.") + + S = source_pv[:effective_n_pv].T + + sc_proj = _dot(sc_shared, S) + st_proj = _dot(st_shared, S) + + n_neighbors = min(n_neighbors, sc_proj.shape[0]) + nn = NearestNeighbors( + n_neighbors=n_neighbors, + metric="cosine", + algorithm="auto", + n_jobs=n_jobs, + ) + nn.fit(sc_proj) + distances, indices = nn.kneighbors(st_proj, return_distance=True) + + weights, mask = _compute_weights(distances) + imputed = _impute_from_neighbors(weights, mask, indices, sc_target) + + result = pd.DataFrame(imputed, index=st_adata.obs_names, columns=genes_to_predict) + if copy: + logg.info("Finish", time=start) + return result + + _save_data(st_adata, attr="obsm", key=key_added, data=result, time=start) + return None + + +def _resolve_genes_to_predict( + sc_adata: AnnData, + genes: Sequence[str] | None, +) -> list[str]: + if genes is None: + genes_to_predict = [g for g in sc_adata.var_names] + else: + genes_to_predict = [g for g in genes if g in sc_adata.var_names] + missing = [g for g in genes if g not in sc_adata.var_names] + if missing: + raise ValueError(f"Genes not found in `sc_adata`: {missing}") + genes_to_predict = [g for g in genes_to_predict] + if not genes_to_predict: + raise ValueError("No genes to impute. Ensure `genes` are in `sc_adata`.") + return genes_to_predict + + +def _shared_genes(st_adata: AnnData, sc_adata: AnnData) -> list[str]: + shared = [g for g in st_adata.var_names if g in sc_adata.var_names] + if not shared: + raise ValueError("No shared genes between `st_adata` and `sc_adata`.") + return shared + + +def _standardize(X: NDArrayA | spmatrix) -> NDArrayA | spmatrix: + if issparse(X): + X = X.toarray() + scaler = StandardScaler(with_mean=True, copy=True) + return scaler.fit_transform(X) + + +def _fit_components(X: NDArrayA | spmatrix, n_components: int) -> NDArrayA: + reducer = PCA(n_components=n_components, svd_solver="arpack", random_state=0) + reducer.fit(X) + return reducer.components_ + + +def _orthonormalize(components: NDArrayA) -> NDArrayA: + return linalg.orth(components.T).T + + +def _compute_principal_vectors( + source_factors: NDArrayA, + target_factors: NDArrayA, + n_pv: int, +) -> tuple[NDArrayA, NDArrayA, NDArrayA]: + u, _, v = np.linalg.svd(source_factors @ target_factors.T, full_matrices=False) + source_pv = (u.T @ source_factors)[:n_pv] + target_pv = (v @ target_factors)[:n_pv] + source_pv = _normalize_rows(source_pv) + target_pv = _normalize_rows(target_pv) + cosine = source_pv @ target_pv.T + return source_pv, target_pv, cosine + + +def _normalize_rows(X: NDArrayA) -> NDArrayA: + denom = np.linalg.norm(X, axis=1, keepdims=True) + denom[denom == 0] = 1.0 + return X / denom + + +def _dot(X: NDArrayA | spmatrix, S: NDArrayA) -> NDArrayA: + return X @ S + + +@numba.njit(cache=True) +def _compute_weights(distances: NDArrayA, threshold: float = 1.0) -> tuple[NDArrayA, NDArrayA]: + n_obs, n_neighbors = distances.shape + weights = np.zeros((n_obs, n_neighbors), dtype=np.float64) + mask = distances < threshold + + for i in range(n_obs): + denom = 0.0 + count = 0 + for j in range(n_neighbors): + if mask[i, j]: + denom += distances[i, j] + count += 1 + if count <= 1 or denom == 0.0: + continue + for j in range(n_neighbors): + if mask[i, j]: + weights[i, j] = (1.0 - distances[i, j] / denom) / (count - 1) + + return weights, mask + + +def _impute_from_neighbors( + weights: NDArrayA, + mask: NDArrayA, + indices: NDArrayA, + y_train: NDArrayA | spmatrix, +) -> NDArrayA: + n_obs = weights.shape[0] + n_genes = y_train.shape[1] + result = np.zeros((n_obs, n_genes), dtype=np.float64) + + for i in range(n_obs): + valid = mask[i] + if not np.any(valid): + continue + w = weights[i, valid] + idx = indices[i, valid] + y_sub = y_train[idx] + imputed = w @ y_sub + result[i] = np.asarray(imputed).ravel() + + return result diff --git a/tests/tools/test_impute.py b/tests/tools/test_impute.py new file mode 100644 index 000000000..5c43cd57c --- /dev/null +++ b/tests/tools/test_impute.py @@ -0,0 +1,314 @@ +from __future__ import annotations + +import numpy as np +import pandas as pd +import pytest +from anndata import AnnData +from scipy.sparse import csr_matrix + +from squidpy.tl import impute + + +def _make_adata(n_obs: int, genes: list[str], rng: np.random.Generator) -> AnnData: + x = rng.normal(size=(n_obs, len(genes))) + return AnnData(x, var=pd.DataFrame(index=genes)) + + +class TestSpaGE: + def test_spage_impute_dense(self): + rng = np.random.default_rng(0) + sc_genes = [f"g{i}" for i in range(10)] + st_genes = [f"g{i}" for i in range(5)] + + sc_adata = _make_adata(40, sc_genes, rng) + st_adata = _make_adata(20, st_genes, rng) + + res = impute( + st_adata, + sc_adata, + method="spage", + n_pv=3, + n_neighbors=5, + key_added="spage", + ) + + assert res is None + assert "spage" in st_adata.obsm + + df = st_adata.obsm["spage"] + assert df.shape == (st_adata.n_obs, 10) + assert list(df.columns) == [f"g{i}" for i in range(10)] + assert df.index.equals(st_adata.obs_names) + + def test_spage_impute_sparse(self): + rng = np.random.default_rng(1) + sc_genes = [f"g{i}" for i in range(8)] + st_genes = [f"g{i}" for i in range(4)] + + sc_adata = _make_adata(30, sc_genes, rng) + st_adata = _make_adata(15, st_genes, rng) + sc_adata.X = csr_matrix(sc_adata.X) + st_adata.X = csr_matrix(st_adata.X) + + res = impute( + st_adata, + sc_adata, + method="spage", + n_pv=3, + n_neighbors=4, + key_added="spage", + ) + + assert res is None + df = st_adata.obsm["spage"] + assert df.shape == (st_adata.n_obs, 8) + assert list(df.columns) == [f"g{i}" for i in range(8)] + + def test_spage_impute_in_place_write(self): + rng = np.random.default_rng(5) + sc_genes = [f"g{i}" for i in range(9)] + st_genes = [f"g{i}" for i in range(6)] + + sc_adata = _make_adata(25, sc_genes, rng) + st_adata = _make_adata(12, st_genes, rng) + + res = impute( + st_adata, + sc_adata, + method="spage", + n_pv=3, + n_neighbors=4, + key_added="spage", + ) + + assert res is None + assert "spage" in st_adata.obsm + + def test_spage_impute_copy_returns_dataframe(self): + rng = np.random.default_rng(15) + sc_genes = [f"g{i}" for i in range(9)] + st_genes = [f"g{i}" for i in range(6)] + + sc_adata = _make_adata(25, sc_genes, rng) + st_adata = _make_adata(12, st_genes, rng) + + res = impute( + st_adata, + sc_adata, + method="spage", + n_pv=3, + n_neighbors=4, + key_added="spage", + copy=True, + ) + + assert isinstance(res, pd.DataFrame) + assert res.shape == (st_adata.n_obs, 9) + assert "spage" not in st_adata.obsm + + def test_spage_impute_genes_subset_order(self): + rng = np.random.default_rng(6) + sc_genes = [f"g{i}" for i in range(10)] + st_genes = [f"g{i}" for i in range(5)] + + sc_adata = _make_adata(30, sc_genes, rng) + st_adata = _make_adata(14, st_genes, rng) + + genes = ["g7", "g5"] + res = impute( + st_adata, + sc_adata, + genes=genes, + method="spage", + n_pv=3, + n_neighbors=5, + key_added="spage", + ) + + assert res is None + df = st_adata.obsm["spage"] + assert list(df.columns) == genes + + def test_spage_impute_cosine_threshold_too_strict(self): + rng = np.random.default_rng(7) + sc_genes = [f"g{i}" for i in range(8)] + st_genes = [f"g{i}" for i in range(5)] + + sc_adata = _make_adata(22, sc_genes, rng) + st_adata = _make_adata(11, st_genes, rng) + + with pytest.raises(ValueError, match="No effective principal vectors"): + impute( + st_adata, + sc_adata, + method="spage", + n_pv=3, + n_neighbors=4, + cosine_threshold=1.1, + ) + + def test_spage_impute_n_neighbors_clamped(self): + rng = np.random.default_rng(8) + sc_genes = [f"g{i}" for i in range(7)] + st_genes = [f"g{i}" for i in range(4)] + + sc_adata = _make_adata(6, sc_genes, rng) + st_adata = _make_adata(5, st_genes, rng) + + res = impute( + st_adata, + sc_adata, + method="spage", + n_pv=3, + n_neighbors=50, + key_added="spage", + ) + + assert res is None + df = st_adata.obsm["spage"] + assert df.shape == (st_adata.n_obs, 7) + + def test_spage_impute_use_raw(self): + rng = np.random.default_rng(9) + sc_genes = [f"g{i}" for i in range(8)] + st_genes = [f"g{i}" for i in range(5)] + + sc_adata = _make_adata(18, sc_genes, rng) + st_adata = _make_adata(12, st_genes, rng) + + sc_adata.raw = sc_adata.copy() + st_adata.raw = st_adata.copy() + + res = impute( + st_adata, + sc_adata, + method="spage", + n_pv=3, + n_neighbors=4, + key_added="spage", + use_raw=True, + ) + + assert res is None + assert "spage" in st_adata.obsm + + def test_spage_impute_layer(self): + rng = np.random.default_rng(10) + sc_genes = [f"g{i}" for i in range(8)] + st_genes = [f"g{i}" for i in range(5)] + + sc_adata = _make_adata(18, sc_genes, rng) + st_adata = _make_adata(12, st_genes, rng) + + sc_adata.layers["counts"] = sc_adata.X.copy() + st_adata.layers["counts"] = st_adata.X.copy() + + res = impute( + st_adata, + sc_adata, + method="spage", + n_pv=3, + n_neighbors=4, + key_added="spage", + layer="counts", + ) + + assert res is None + assert "spage" in st_adata.obsm + + def test_spage_impute_shared_genes_are_kept(self): + rng = np.random.default_rng(16) + sc_genes = [f"g{i}" for i in range(10)] + st_genes = [f"g{i}" for i in range(5)] + + sc_adata = _make_adata(40, sc_genes, rng) + st_adata = _make_adata(20, st_genes, rng) + + impute( + st_adata, + sc_adata, + method="spage", + n_pv=3, + n_neighbors=5, + key_added="spage", + ) + + df = st_adata.obsm["spage"] + assert df.shape == (st_adata.n_obs, 10) + assert all(gene in df.columns for gene in st_genes) + + def test_spage_impute_invalid_genes(self): + rng = np.random.default_rng(2) + sc_genes = [f"g{i}" for i in range(6)] + st_genes = [f"g{i}" for i in range(3)] + + sc_adata = _make_adata(20, sc_genes, rng) + st_adata = _make_adata(10, st_genes, rng) + + with pytest.raises(ValueError, match="Genes not found in `sc_adata`"): + impute(st_adata, sc_adata, method="spage", genes=["g4", "gX"], n_pv=2, n_neighbors=3) + + def test_spage_impute_no_shared_genes(self): + rng = np.random.default_rng(3) + sc_genes = [f"h{i}" for i in range(6)] + st_genes = [f"g{i}" for i in range(3)] + + sc_adata = _make_adata(20, sc_genes, rng) + st_adata = _make_adata(10, st_genes, rng) + + with pytest.raises(ValueError, match="No shared genes"): + impute(st_adata, sc_adata, method="spage", n_pv=2, n_neighbors=3) + + def test_spage_impute_no_genes_to_impute(self): + rng = np.random.default_rng(11) + sc_genes = [f"g{i}" for i in range(5)] + st_genes = [f"g{i}" for i in range(5)] + + sc_adata = _make_adata(20, sc_genes, rng) + st_adata = _make_adata(10, st_genes, rng) + + with pytest.raises(ValueError, match="No genes to impute"): + impute(st_adata, sc_adata, method="spage", genes=[], n_pv=2, n_neighbors=3) + + def test_spage_impute_n_pv_too_large(self): + rng = np.random.default_rng(4) + sc_genes = [f"g{i}" for i in range(7)] + st_genes = [f"g{i}" for i in range(5)] + + sc_adata = _make_adata(20, sc_genes, rng) + st_adata = _make_adata(10, st_genes, rng) + + with pytest.raises(ValueError, match="`n_pv` must be <= number of shared genes"): + impute(st_adata, sc_adata, method="spage", n_pv=10, n_neighbors=3) + + +class TestImputeDispatch: + def test_invalid_method_raises(self): + rng = np.random.default_rng(13) + sc_genes = [f"g{i}" for i in range(6)] + st_genes = [f"g{i}" for i in range(4)] + + sc_adata = _make_adata(16, sc_genes, rng) + st_adata = _make_adata(8, st_genes, rng) + + with pytest.raises(ValueError, match="one of"): + impute(st_adata, sc_adata, method="tangram") + + def test_spage_args_are_supported(self): + rng = np.random.default_rng(14) + sc_genes = [f"g{i}" for i in range(8)] + st_genes = [f"g{i}" for i in range(5)] + + sc_adata = _make_adata(20, sc_genes, rng) + st_adata = _make_adata(10, st_genes, rng) + + res = impute( + st_adata, + sc_adata, + method="spage", + n_pv=3, + n_neighbors=4, + ) + + assert res is None + assert "spage" in st_adata.obsm