From 6bce2c06bb4457ccf26ef8178866a8f107ddeb82 Mon Sep 17 00:00:00 2001 From: MDLDan Date: Thu, 16 Apr 2026 09:22:13 +0000 Subject: [PATCH 01/13] add impute functions --- src/squidpy/tl/__init__.py | 1 + src/squidpy/tl/_impute.py | 86 +++++++++++ src/squidpy/tl/_spage_impute.py | 251 ++++++++++++++++++++++++++++++++ 3 files changed, 338 insertions(+) create mode 100644 src/squidpy/tl/_impute.py create mode 100644 src/squidpy/tl/_spage_impute.py diff --git a/src/squidpy/tl/__init__.py b/src/squidpy/tl/__init__.py index 6d5abe98c..f2fa10858 100644 --- a/src/squidpy/tl/__init__.py +++ b/src/squidpy/tl/__init__.py @@ -4,3 +4,4 @@ from squidpy.tl._sliding_window import _calculate_window_corners, sliding_window from squidpy.tl._var_by_distance import var_by_distance +from squidpy.tl._impute import impute diff --git a/src/squidpy/tl/_impute.py b/src/squidpy/tl/_impute.py new file mode 100644 index 000000000..c40f4f2b9 --- /dev/null +++ b/src/squidpy/tl/_impute.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +from collections.abc import Mapping, Sequence +from typing import Any + +from anndata import AnnData + +from squidpy._docs import d +from squidpy._validators import assert_one_of + +from ._spage_impute import SpaGEParams, spage_impute + +__all__ = ["impute"] + +_ALLOWED_METHODS = ("spage",) + + +@d.dedent +def impute( + st_adata: AnnData, + sc_adata: AnnData, + *, + genes: Sequence[str] | None = None, + method: str = "spage", + method_params: SpaGEParams | Mapping[str, Any] | None = None, + n_pv: int = 30, + n_neighbors: int = 50, + cosine_threshold: float = 0.3, + use_raw: bool = False, + layer: str | None = None, + key_added: str = "spage", + n_jobs: int | None = None, + copy: bool = False, +) -> AnnData: + """ + Impute spatially unmeasured genes in spatial data using a selected method. + + Parameters + ---------- + st_adata + Spatial AnnData object. + sc_adata + scRNA-seq AnnData object. + genes + Genes to impute. If `None`, uses genes present in `sc_adata` but missing from `st_adata`. + method + Imputation method to use. Valid options are: + + - ``"spage"`` - SpaGE imputation. + method_params + Optional method-specific parameters. For ``method="spage"``, provide :class:`SpaGEParams` + or a mapping with matching field names. + key_added + Key added to `.obsm` for the imputed genes. + copy + Whether to return a copy of `st_adata`. + + Returns + ------- + AnnData with imputed genes stored in `.obsm[key_added]`. + """ + assert_one_of(method, _ALLOWED_METHODS, name="method") + + if method == "spage": + if method_params is None: + method_params = SpaGEParams( + n_pv=n_pv, + n_neighbors=n_neighbors, + cosine_threshold=cosine_threshold, + use_raw=use_raw, + layer=layer, + n_jobs=n_jobs, + ) + elif isinstance(method_params, Mapping): + method_params = SpaGEParams.from_mapping(method_params) + + return spage_impute( + st_adata, + sc_adata, + genes=genes, + params=method_params, + key_added=key_added, + copy=copy, + ) + + raise NotImplementedError(f"Method `{method}` is not yet implemented.") \ No newline at end of file diff --git a/src/squidpy/tl/_spage_impute.py b/src/squidpy/tl/_spage_impute.py new file mode 100644 index 000000000..e937b156b --- /dev/null +++ b/src/squidpy/tl/_spage_impute.py @@ -0,0 +1,251 @@ +from __future__ import annotations + +from collections.abc import Mapping, Sequence +from dataclasses import dataclass + +import numba +import numpy as np +import pandas as pd +from anndata import AnnData +from scanpy import logging as logg +from scipy import linalg +from scipy.sparse import issparse, spmatrix +from sklearn.decomposition import PCA +from sklearn.neighbors import NearestNeighbors +from sklearn.preprocessing import StandardScaler + +from squidpy._docs import d +from squidpy._utils import NDArrayA +from squidpy.gr._utils import _extract_expression, _save_data + +__all__ = ["SpaGEParams", "spage_impute"] + + +@dataclass(slots=True, frozen=True) +class SpaGEParams: + n_pv: int = 30 + n_neighbors: int = 50 + cosine_threshold: float = 0.3 + use_raw: bool = False + layer: str | None = None + n_jobs: int | None = None + + @classmethod + def from_mapping(cls, params: Mapping[str, object]) -> SpaGEParams: + return cls(**params) + + +@d.dedent +def spage_impute( + st_adata: AnnData, + sc_adata: AnnData, + *, + genes: Sequence[str] | None = None, + params: SpaGEParams | Mapping[str, object] | None = None, + key_added: str = "spage", + copy: bool = False, +) -> AnnData: + """ + Impute spatially unmeasured genes in spatial data using SpaGE. + + Parameters + ---------- + st_adata + Spatial AnnData object. + sc_adata + scRNA-seq AnnData object. + genes + Genes to impute. If `None`, uses genes present in `sc_adata` but missing from `st_adata`. + params + SpaGE-specific parameters. + key_added + Key added to `.obsm` for the imputed genes. + copy + Whether to return a copy of `st_adata`. + + Returns + ------- + AnnData with imputed genes stored in `.obsm[key_added]`. + """ + start = logg.info("Running SpaGE imputation") + + if copy: + st_adata = st_adata.copy() + + if params is None: + params = SpaGEParams() + elif isinstance(params, Mapping): + params = SpaGEParams.from_mapping(params) + + if params.n_pv <= 0: + raise ValueError("`n_pv` must be positive.") + if params.n_neighbors <= 0: + raise ValueError("`n_neighbors` must be positive.") + if params.cosine_threshold < 0: + raise ValueError("`cosine_threshold` must be non-negative.") + + genes_to_predict = _resolve_genes_to_predict(st_adata, sc_adata, genes) + shared_genes = _shared_genes(st_adata, sc_adata) + + if params.n_pv > len(shared_genes): + raise ValueError( + f"`n_pv` must be <= number of shared genes ({len(shared_genes)}), found `{params.n_pv}`." + ) + + sc_shared, _ = _extract_expression(sc_adata, genes=shared_genes, use_raw=params.use_raw, layer=params.layer) + st_shared, _ = _extract_expression(st_adata, genes=shared_genes, use_raw=params.use_raw, layer=params.layer) + sc_target, _ = _extract_expression(sc_adata, genes=genes_to_predict, use_raw=params.use_raw, layer=params.layer) + + sc_shared = _standardize(sc_shared) + st_shared = _standardize(st_shared) + + source_components = _fit_components(sc_shared, params.n_pv) + target_components = _fit_components(st_shared, params.n_pv) + + source_components = _orthonormalize(source_components) + target_components = _orthonormalize(target_components) + + n_pv_eff = min(params.n_pv, source_components.shape[0], target_components.shape[0]) + if n_pv_eff <= 0: + raise ValueError("No principal vectors could be computed.") + + source_pv, target_pv, cosine = _compute_principal_vectors(source_components, target_components, n_pv_eff) + + effective_n_pv = int(np.sum(np.diag(cosine) > params.cosine_threshold)) + if effective_n_pv <= 0: + raise ValueError( + "No effective principal vectors found. Consider lowering `cosine_threshold` or `n_pv`." + ) + + S = source_pv[:effective_n_pv].T + + sc_proj = _dot(sc_shared, S) + st_proj = _dot(st_shared, S) + + n_neighbors = min(params.n_neighbors, sc_proj.shape[0]) + nn = NearestNeighbors( + n_neighbors=n_neighbors, + metric="cosine", + algorithm="auto", + n_jobs=params.n_jobs, + ) + nn.fit(sc_proj) + distances, indices = nn.kneighbors(st_proj, return_distance=True) + + weights, mask = _compute_weights(distances) + imputed = _impute_from_neighbors(weights, mask, indices, sc_target) + + result = pd.DataFrame(imputed, index=st_adata.obs_names, columns=genes_to_predict) + _save_data(st_adata, attr="obsm", key=key_added, data=result, time=start) + return st_adata + + +def _resolve_genes_to_predict( + st_adata: AnnData, + sc_adata: AnnData, + genes: Sequence[str] | None, +) -> list[str]: + if genes is None: + genes_to_predict = [g for g in sc_adata.var_names if g not in st_adata.var_names] + else: + genes_to_predict = [g for g in genes if g in sc_adata.var_names] + missing = [g for g in genes if g not in sc_adata.var_names] + if missing: + raise ValueError(f"Genes not found in `sc_adata`: {missing}") + genes_to_predict = [g for g in genes_to_predict if g not in st_adata.var_names] + if not genes_to_predict: + raise ValueError("No genes to impute. Ensure `genes` are in `sc_adata` and absent from `st_adata`.") + return genes_to_predict + + +def _shared_genes(st_adata: AnnData, sc_adata: AnnData) -> list[str]: + shared = [g for g in st_adata.var_names if g in sc_adata.var_names] + if not shared: + raise ValueError("No shared genes between `st_adata` and `sc_adata`.") + return shared + + +def _standardize(X: NDArrayA | spmatrix) -> NDArrayA | spmatrix: + if issparse(X): + X = X.toarray() + scaler = StandardScaler(with_mean=True, copy=True) + return scaler.fit_transform(X) + + +def _fit_components(X: NDArrayA | spmatrix, n_components: int) -> NDArrayA: + reducer = PCA(n_components=n_components, svd_solver="arpack", random_state=0) + reducer.fit(X) + return reducer.components_ + + +def _orthonormalize(components: NDArrayA) -> NDArrayA: + return linalg.orth(components.T).T + + +def _compute_principal_vectors( + source_factors: NDArrayA, + target_factors: NDArrayA, + n_pv: int, +) -> tuple[NDArrayA, NDArrayA, NDArrayA]: + u, _, v = np.linalg.svd(source_factors @ target_factors.T, full_matrices=False) + source_pv = (u.T @ source_factors)[:n_pv] + target_pv = (v @ target_factors)[:n_pv] + source_pv = _normalize_rows(source_pv) + target_pv = _normalize_rows(target_pv) + cosine = source_pv @ target_pv.T + return source_pv, target_pv, cosine + + +def _normalize_rows(X: NDArrayA) -> NDArrayA: + denom = np.linalg.norm(X, axis=1, keepdims=True) + denom[denom == 0] = 1.0 + return X / denom + + +def _dot(X: NDArrayA | spmatrix, S: NDArrayA) -> NDArrayA: + return X @ S + + +@numba.njit(cache=True) +def _compute_weights(distances: NDArrayA, threshold: float = 1.0) -> tuple[NDArrayA, NDArrayA]: + n_obs, n_neighbors = distances.shape + weights = np.zeros((n_obs, n_neighbors), dtype=np.float64) + mask = distances < threshold + + for i in range(n_obs): + denom = 0.0 + count = 0 + for j in range(n_neighbors): + if mask[i, j]: + denom += distances[i, j] + count += 1 + if count <= 1 or denom == 0.0: + continue + for j in range(n_neighbors): + if mask[i, j]: + weights[i, j] = (1.0 - distances[i, j] / denom) / (count - 1) + + return weights, mask + + +def _impute_from_neighbors( + weights: NDArrayA, + mask: NDArrayA, + indices: NDArrayA, + y_train: NDArrayA | spmatrix, +) -> NDArrayA: + n_obs = weights.shape[0] + n_genes = y_train.shape[1] + result = np.zeros((n_obs, n_genes), dtype=np.float64) + + for i in range(n_obs): + valid = mask[i] + if not np.any(valid): + continue + w = weights[i, valid] + idx = indices[i, valid] + y_sub = y_train[idx] + imputed = w @ y_sub + result[i] = np.asarray(imputed).ravel() + + return result \ No newline at end of file From 622338357321f701e470b55869d3c4c5730d1694 Mon Sep 17 00:00:00 2001 From: MDLDan Date: Thu, 16 Apr 2026 09:22:27 +0000 Subject: [PATCH 02/13] add impute testing --- tests/tools/test_impute.py | 272 +++++++++++++++++++++++++++++++++++++ 1 file changed, 272 insertions(+) create mode 100644 tests/tools/test_impute.py diff --git a/tests/tools/test_impute.py b/tests/tools/test_impute.py new file mode 100644 index 000000000..caf499b5d --- /dev/null +++ b/tests/tools/test_impute.py @@ -0,0 +1,272 @@ +from __future__ import annotations + +import numpy as np +import pandas as pd +import pytest +from anndata import AnnData +from scipy.sparse import csr_matrix + +from squidpy.tl import impute +from squidpy.tl._spage_impute import SpaGEParams + + +def _make_adata(n_obs: int, genes: list[str], rng: np.random.Generator) -> AnnData: + x = rng.normal(size=(n_obs, len(genes))) + return AnnData(x, var=pd.DataFrame(index=genes)) + + +class TestSpaGE: + def test_spage_impute_dense_copy(self): + rng = np.random.default_rng(0) + sc_genes = [f"g{i}" for i in range(10)] + st_genes = [f"g{i}" for i in range(5)] + + sc_adata = _make_adata(40, sc_genes, rng) + st_adata = _make_adata(20, st_genes, rng) + + res = impute( + st_adata, + sc_adata, + method="spage", + method_params=SpaGEParams(n_pv=3, n_neighbors=5), + key_added="spage", + copy=True, + ) + + assert "spage" in res.obsm + assert "spage" not in st_adata.obsm + + df = res.obsm["spage"] + assert df.shape == (st_adata.n_obs, 5) + assert list(df.columns) == [f"g{i}" for i in range(5, 10)] + assert df.index.equals(st_adata.obs_names) + + def test_spage_impute_sparse(self): + rng = np.random.default_rng(1) + sc_genes = [f"g{i}" for i in range(8)] + st_genes = [f"g{i}" for i in range(4)] + + sc_adata = _make_adata(30, sc_genes, rng) + st_adata = _make_adata(15, st_genes, rng) + sc_adata.X = csr_matrix(sc_adata.X) + st_adata.X = csr_matrix(st_adata.X) + + res = impute( + st_adata, + sc_adata, + method="spage", + n_pv=3, + n_neighbors=4, + key_added="spage", + copy=True, + ) + + df = res.obsm["spage"] + assert df.shape == (st_adata.n_obs, 4) + assert list(df.columns) == [f"g{i}" for i in range(4, 8)] + + def test_spage_impute_copy_false(self): + rng = np.random.default_rng(5) + sc_genes = [f"g{i}" for i in range(9)] + st_genes = [f"g{i}" for i in range(6)] + + sc_adata = _make_adata(25, sc_genes, rng) + st_adata = _make_adata(12, st_genes, rng) + + res = impute( + st_adata, + sc_adata, + method="spage", + n_pv=3, + n_neighbors=4, + key_added="spage", + copy=False, + ) + + assert res is st_adata + assert "spage" in st_adata.obsm + + def test_spage_impute_genes_subset_order(self): + rng = np.random.default_rng(6) + sc_genes = [f"g{i}" for i in range(10)] + st_genes = [f"g{i}" for i in range(5)] + + sc_adata = _make_adata(30, sc_genes, rng) + st_adata = _make_adata(14, st_genes, rng) + + genes = ["g7", "g5"] + res = impute( + st_adata, + sc_adata, + genes=genes, + method="spage", + n_pv=3, + n_neighbors=5, + key_added="spage", + copy=True, + ) + + df = res.obsm["spage"] + assert list(df.columns) == genes + + def test_spage_impute_cosine_threshold_too_strict(self): + rng = np.random.default_rng(7) + sc_genes = [f"g{i}" for i in range(8)] + st_genes = [f"g{i}" for i in range(5)] + + sc_adata = _make_adata(22, sc_genes, rng) + st_adata = _make_adata(11, st_genes, rng) + + with pytest.raises(ValueError, match="No effective principal vectors"): + impute( + st_adata, + sc_adata, + method="spage", + n_pv=3, + n_neighbors=4, + cosine_threshold=1.1, + ) + + def test_spage_impute_n_neighbors_clamped(self): + rng = np.random.default_rng(8) + sc_genes = [f"g{i}" for i in range(7)] + st_genes = [f"g{i}" for i in range(4)] + + sc_adata = _make_adata(6, sc_genes, rng) + st_adata = _make_adata(5, st_genes, rng) + + res = impute( + st_adata, + sc_adata, + method="spage", + n_pv=3, + n_neighbors=50, + key_added="spage", + copy=True, + ) + + df = res.obsm["spage"] + assert df.shape == (st_adata.n_obs, 3) + + def test_spage_impute_use_raw(self): + rng = np.random.default_rng(9) + sc_genes = [f"g{i}" for i in range(8)] + st_genes = [f"g{i}" for i in range(5)] + + sc_adata = _make_adata(18, sc_genes, rng) + st_adata = _make_adata(12, st_genes, rng) + + sc_adata.raw = sc_adata.copy() + st_adata.raw = st_adata.copy() + + res = impute( + st_adata, + sc_adata, + method="spage", + n_pv=3, + n_neighbors=4, + key_added="spage", + use_raw=True, + copy=True, + ) + + assert "spage" in res.obsm + + def test_spage_impute_layer(self): + rng = np.random.default_rng(10) + sc_genes = [f"g{i}" for i in range(8)] + st_genes = [f"g{i}" for i in range(5)] + + sc_adata = _make_adata(18, sc_genes, rng) + st_adata = _make_adata(12, st_genes, rng) + + sc_adata.layers["counts"] = sc_adata.X.copy() + st_adata.layers["counts"] = st_adata.X.copy() + + res = impute( + st_adata, + sc_adata, + method="spage", + n_pv=3, + n_neighbors=4, + key_added="spage", + layer="counts", + copy=True, + ) + + assert "spage" in res.obsm + + def test_spage_impute_invalid_genes(self): + rng = np.random.default_rng(2) + sc_genes = [f"g{i}" for i in range(6)] + st_genes = [f"g{i}" for i in range(3)] + + sc_adata = _make_adata(20, sc_genes, rng) + st_adata = _make_adata(10, st_genes, rng) + + with pytest.raises(ValueError, match="Genes not found in `sc_adata`"): + impute(st_adata, sc_adata, method="spage", genes=["g4", "gX"], n_pv=2, n_neighbors=3) + + def test_spage_impute_no_shared_genes(self): + rng = np.random.default_rng(3) + sc_genes = [f"h{i}" for i in range(6)] + st_genes = [f"g{i}" for i in range(3)] + + sc_adata = _make_adata(20, sc_genes, rng) + st_adata = _make_adata(10, st_genes, rng) + + with pytest.raises(ValueError, match="No shared genes"): + impute(st_adata, sc_adata, method="spage", n_pv=2, n_neighbors=3) + + def test_spage_impute_no_genes_to_impute(self): + rng = np.random.default_rng(11) + sc_genes = [f"g{i}" for i in range(5)] + st_genes = [f"g{i}" for i in range(5)] + + sc_adata = _make_adata(20, sc_genes, rng) + st_adata = _make_adata(10, st_genes, rng) + + with pytest.raises(ValueError, match="No genes to impute"): + impute(st_adata, sc_adata, method="spage", n_pv=2, n_neighbors=3) + + def test_spage_impute_n_pv_too_large(self): + rng = np.random.default_rng(4) + sc_genes = [f"g{i}" for i in range(7)] + st_genes = [f"g{i}" for i in range(5)] + + sc_adata = _make_adata(20, sc_genes, rng) + st_adata = _make_adata(10, st_genes, rng) + + with pytest.raises(ValueError, match="`n_pv` must be <= number of shared genes"): + impute(st_adata, sc_adata, method="spage", n_pv=10, n_neighbors=3) + + +class TestImputeDispatch: + def test_invalid_method_raises(self): + rng = np.random.default_rng(13) + sc_genes = [f"g{i}" for i in range(6)] + st_genes = [f"g{i}" for i in range(4)] + + sc_adata = _make_adata(16, sc_genes, rng) + st_adata = _make_adata(8, st_genes, rng) + + with pytest.raises(ValueError, match="one of"): + impute(st_adata, sc_adata, method="tangram") + + def test_method_params_object_is_supported(self): + rng = np.random.default_rng(14) + sc_genes = [f"g{i}" for i in range(8)] + st_genes = [f"g{i}" for i in range(5)] + + sc_adata = _make_adata(20, sc_genes, rng) + st_adata = _make_adata(10, st_genes, rng) + + res = impute( + st_adata, + sc_adata, + method="spage", + method_params=SpaGEParams(n_pv=3, n_neighbors=4), + copy=True, + ) + + assert "spage" in res.obsm From 2f36d23f63fe2f3c2bc12e26111987941989ee7a Mon Sep 17 00:00:00 2001 From: MDLDan Date: Thu, 16 Apr 2026 09:23:02 +0000 Subject: [PATCH 03/13] include impute func in tools --- docs/api.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/api.md b/docs/api.md index fc9922a31..6b6a77a35 100644 --- a/docs/api.md +++ b/docs/api.md @@ -82,6 +82,7 @@ import squidpy as sq .. autosummary:: :toctree: api + tl.impute tl.sliding_window tl.var_by_distance ``` From 66b5716804ba58de5a3abf78a1fcca75697ea6d2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 16 Apr 2026 09:40:19 +0000 Subject: [PATCH 04/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/squidpy/tl/__init__.py | 2 +- src/squidpy/tl/_impute.py | 2 +- src/squidpy/tl/_spage_impute.py | 10 +++------- tests/tools/test_impute.py | 2 +- 4 files changed, 6 insertions(+), 10 deletions(-) diff --git a/src/squidpy/tl/__init__.py b/src/squidpy/tl/__init__.py index f2fa10858..31d355fbe 100644 --- a/src/squidpy/tl/__init__.py +++ b/src/squidpy/tl/__init__.py @@ -2,6 +2,6 @@ from __future__ import annotations +from squidpy.tl._impute import impute from squidpy.tl._sliding_window import _calculate_window_corners, sliding_window from squidpy.tl._var_by_distance import var_by_distance -from squidpy.tl._impute import impute diff --git a/src/squidpy/tl/_impute.py b/src/squidpy/tl/_impute.py index c40f4f2b9..657a556d0 100644 --- a/src/squidpy/tl/_impute.py +++ b/src/squidpy/tl/_impute.py @@ -83,4 +83,4 @@ def impute( copy=copy, ) - raise NotImplementedError(f"Method `{method}` is not yet implemented.") \ No newline at end of file + raise NotImplementedError(f"Method `{method}` is not yet implemented.") diff --git a/src/squidpy/tl/_spage_impute.py b/src/squidpy/tl/_spage_impute.py index e937b156b..bccd43f24 100644 --- a/src/squidpy/tl/_spage_impute.py +++ b/src/squidpy/tl/_spage_impute.py @@ -88,9 +88,7 @@ def spage_impute( shared_genes = _shared_genes(st_adata, sc_adata) if params.n_pv > len(shared_genes): - raise ValueError( - f"`n_pv` must be <= number of shared genes ({len(shared_genes)}), found `{params.n_pv}`." - ) + raise ValueError(f"`n_pv` must be <= number of shared genes ({len(shared_genes)}), found `{params.n_pv}`.") sc_shared, _ = _extract_expression(sc_adata, genes=shared_genes, use_raw=params.use_raw, layer=params.layer) st_shared, _ = _extract_expression(st_adata, genes=shared_genes, use_raw=params.use_raw, layer=params.layer) @@ -113,9 +111,7 @@ def spage_impute( effective_n_pv = int(np.sum(np.diag(cosine) > params.cosine_threshold)) if effective_n_pv <= 0: - raise ValueError( - "No effective principal vectors found. Consider lowering `cosine_threshold` or `n_pv`." - ) + raise ValueError("No effective principal vectors found. Consider lowering `cosine_threshold` or `n_pv`.") S = source_pv[:effective_n_pv].T @@ -248,4 +244,4 @@ def _impute_from_neighbors( imputed = w @ y_sub result[i] = np.asarray(imputed).ravel() - return result \ No newline at end of file + return result diff --git a/tests/tools/test_impute.py b/tests/tools/test_impute.py index caf499b5d..3a471b38b 100644 --- a/tests/tools/test_impute.py +++ b/tests/tools/test_impute.py @@ -95,7 +95,7 @@ def test_spage_impute_genes_subset_order(self): st_adata = _make_adata(14, st_genes, rng) genes = ["g7", "g5"] - res = impute( + res = impute( st_adata, sc_adata, genes=genes, From 47dfa5a457c25e6d93e13ab2837274718da17d41 Mon Sep 17 00:00:00 2001 From: MDLDan Date: Thu, 16 Apr 2026 10:50:04 +0000 Subject: [PATCH 05/13] remove copy --- src/squidpy/tl/_impute.py | 5 ----- src/squidpy/tl/_spage_impute.py | 4 ---- 2 files changed, 9 deletions(-) diff --git a/src/squidpy/tl/_impute.py b/src/squidpy/tl/_impute.py index 657a556d0..ff7910f7a 100644 --- a/src/squidpy/tl/_impute.py +++ b/src/squidpy/tl/_impute.py @@ -30,7 +30,6 @@ def impute( layer: str | None = None, key_added: str = "spage", n_jobs: int | None = None, - copy: bool = False, ) -> AnnData: """ Impute spatially unmeasured genes in spatial data using a selected method. @@ -52,9 +51,6 @@ def impute( or a mapping with matching field names. key_added Key added to `.obsm` for the imputed genes. - copy - Whether to return a copy of `st_adata`. - Returns ------- AnnData with imputed genes stored in `.obsm[key_added]`. @@ -80,7 +76,6 @@ def impute( genes=genes, params=method_params, key_added=key_added, - copy=copy, ) raise NotImplementedError(f"Method `{method}` is not yet implemented.") diff --git a/src/squidpy/tl/_spage_impute.py b/src/squidpy/tl/_spage_impute.py index bccd43f24..b349eba14 100644 --- a/src/squidpy/tl/_spage_impute.py +++ b/src/squidpy/tl/_spage_impute.py @@ -43,7 +43,6 @@ def spage_impute( genes: Sequence[str] | None = None, params: SpaGEParams | Mapping[str, object] | None = None, key_added: str = "spage", - copy: bool = False, ) -> AnnData: """ Impute spatially unmeasured genes in spatial data using SpaGE. @@ -69,9 +68,6 @@ def spage_impute( """ start = logg.info("Running SpaGE imputation") - if copy: - st_adata = st_adata.copy() - if params is None: params = SpaGEParams() elif isinstance(params, Mapping): From d9b92e3db15147ba450933739711beedbcb6f0c6 Mon Sep 17 00:00:00 2001 From: MDLDan Date: Thu, 16 Apr 2026 11:00:38 +0000 Subject: [PATCH 06/13] remove spageparams --- src/squidpy/tl/_impute.py | 44 ++++++++++--------- src/squidpy/tl/_spage_impute.py | 75 ++++++++++++++++----------------- tests/tools/test_impute.py | 9 ++-- 3 files changed, 66 insertions(+), 62 deletions(-) diff --git a/src/squidpy/tl/_impute.py b/src/squidpy/tl/_impute.py index ff7910f7a..25443c372 100644 --- a/src/squidpy/tl/_impute.py +++ b/src/squidpy/tl/_impute.py @@ -1,14 +1,13 @@ from __future__ import annotations -from collections.abc import Mapping, Sequence -from typing import Any +from collections.abc import Sequence from anndata import AnnData from squidpy._docs import d from squidpy._validators import assert_one_of -from ._spage_impute import SpaGEParams, spage_impute +from ._spage_impute import spage_impute __all__ = ["impute"] @@ -22,7 +21,6 @@ def impute( *, genes: Sequence[str] | None = None, method: str = "spage", - method_params: SpaGEParams | Mapping[str, Any] | None = None, n_pv: int = 30, n_neighbors: int = 50, cosine_threshold: float = 0.3, @@ -30,6 +28,7 @@ def impute( layer: str | None = None, key_added: str = "spage", n_jobs: int | None = None, + copy: bool = False, ) -> AnnData: """ Impute spatially unmeasured genes in spatial data using a selected method. @@ -46,11 +45,22 @@ def impute( Imputation method to use. Valid options are: - ``"spage"`` - SpaGE imputation. - method_params - Optional method-specific parameters. For ``method="spage"``, provide :class:`SpaGEParams` - or a mapping with matching field names. + n_pv + Number of principal vectors used for alignment. + n_neighbors + Number of nearest neighbors used for imputation. + cosine_threshold + Threshold on cosine similarity to select effective principal vectors. + use_raw + Whether to use `.raw` for expression values. + layer + Layer to use for expression values. key_added Key added to `.obsm` for the imputed genes. + n_jobs + Number of parallel jobs for nearest neighbors search. + copy + Whether to return a copy of `st_adata`. Returns ------- AnnData with imputed genes stored in `.obsm[key_added]`. @@ -58,24 +68,18 @@ def impute( assert_one_of(method, _ALLOWED_METHODS, name="method") if method == "spage": - if method_params is None: - method_params = SpaGEParams( - n_pv=n_pv, - n_neighbors=n_neighbors, - cosine_threshold=cosine_threshold, - use_raw=use_raw, - layer=layer, - n_jobs=n_jobs, - ) - elif isinstance(method_params, Mapping): - method_params = SpaGEParams.from_mapping(method_params) - return spage_impute( st_adata, sc_adata, genes=genes, - params=method_params, + n_pv=n_pv, + n_neighbors=n_neighbors, + cosine_threshold=cosine_threshold, + use_raw=use_raw, + layer=layer, key_added=key_added, + n_jobs=n_jobs, + copy=copy, ) raise NotImplementedError(f"Method `{method}` is not yet implemented.") diff --git a/src/squidpy/tl/_spage_impute.py b/src/squidpy/tl/_spage_impute.py index b349eba14..78d05f044 100644 --- a/src/squidpy/tl/_spage_impute.py +++ b/src/squidpy/tl/_spage_impute.py @@ -1,7 +1,6 @@ from __future__ import annotations -from collections.abc import Mapping, Sequence -from dataclasses import dataclass +from collections.abc import Sequence import numba import numpy as np @@ -18,21 +17,7 @@ from squidpy._utils import NDArrayA from squidpy.gr._utils import _extract_expression, _save_data -__all__ = ["SpaGEParams", "spage_impute"] - - -@dataclass(slots=True, frozen=True) -class SpaGEParams: - n_pv: int = 30 - n_neighbors: int = 50 - cosine_threshold: float = 0.3 - use_raw: bool = False - layer: str | None = None - n_jobs: int | None = None - - @classmethod - def from_mapping(cls, params: Mapping[str, object]) -> SpaGEParams: - return cls(**params) +__all__ = ["spage_impute"] @d.dedent @@ -41,8 +26,14 @@ def spage_impute( sc_adata: AnnData, *, genes: Sequence[str] | None = None, - params: SpaGEParams | Mapping[str, object] | None = None, + n_pv: int = 30, + n_neighbors: int = 50, + cosine_threshold: float = 0.3, + use_raw: bool = False, + layer: str | None = None, key_added: str = "spage", + n_jobs: int | None = None, + copy: bool = False, ) -> AnnData: """ Impute spatially unmeasured genes in spatial data using SpaGE. @@ -55,10 +46,20 @@ def spage_impute( scRNA-seq AnnData object. genes Genes to impute. If `None`, uses genes present in `sc_adata` but missing from `st_adata`. - params - SpaGE-specific parameters. + n_pv + Number of principal vectors used for alignment. + n_neighbors + Number of nearest neighbors used for imputation. + cosine_threshold + Threshold on cosine similarity to select effective principal vectors. + use_raw + Whether to use `.raw` for expression values. + layer + Layer to use for expression values. key_added Key added to `.obsm` for the imputed genes. + n_jobs + Number of parallel jobs for nearest neighbors search. copy Whether to return a copy of `st_adata`. @@ -68,44 +69,42 @@ def spage_impute( """ start = logg.info("Running SpaGE imputation") - if params is None: - params = SpaGEParams() - elif isinstance(params, Mapping): - params = SpaGEParams.from_mapping(params) + if copy: + st_adata = st_adata.copy() - if params.n_pv <= 0: + if n_pv <= 0: raise ValueError("`n_pv` must be positive.") - if params.n_neighbors <= 0: + if n_neighbors <= 0: raise ValueError("`n_neighbors` must be positive.") - if params.cosine_threshold < 0: + if cosine_threshold < 0: raise ValueError("`cosine_threshold` must be non-negative.") genes_to_predict = _resolve_genes_to_predict(st_adata, sc_adata, genes) shared_genes = _shared_genes(st_adata, sc_adata) - if params.n_pv > len(shared_genes): - raise ValueError(f"`n_pv` must be <= number of shared genes ({len(shared_genes)}), found `{params.n_pv}`.") + if n_pv > len(shared_genes): + raise ValueError(f"`n_pv` must be <= number of shared genes ({len(shared_genes)}), found `{n_pv}`.") - sc_shared, _ = _extract_expression(sc_adata, genes=shared_genes, use_raw=params.use_raw, layer=params.layer) - st_shared, _ = _extract_expression(st_adata, genes=shared_genes, use_raw=params.use_raw, layer=params.layer) - sc_target, _ = _extract_expression(sc_adata, genes=genes_to_predict, use_raw=params.use_raw, layer=params.layer) + sc_shared, _ = _extract_expression(sc_adata, genes=shared_genes, use_raw=use_raw, layer=layer) + st_shared, _ = _extract_expression(st_adata, genes=shared_genes, use_raw=use_raw, layer=layer) + sc_target, _ = _extract_expression(sc_adata, genes=genes_to_predict, use_raw=use_raw, layer=layer) sc_shared = _standardize(sc_shared) st_shared = _standardize(st_shared) - source_components = _fit_components(sc_shared, params.n_pv) - target_components = _fit_components(st_shared, params.n_pv) + source_components = _fit_components(sc_shared, n_pv) + target_components = _fit_components(st_shared, n_pv) source_components = _orthonormalize(source_components) target_components = _orthonormalize(target_components) - n_pv_eff = min(params.n_pv, source_components.shape[0], target_components.shape[0]) + n_pv_eff = min(n_pv, source_components.shape[0], target_components.shape[0]) if n_pv_eff <= 0: raise ValueError("No principal vectors could be computed.") source_pv, target_pv, cosine = _compute_principal_vectors(source_components, target_components, n_pv_eff) - effective_n_pv = int(np.sum(np.diag(cosine) > params.cosine_threshold)) + effective_n_pv = int(np.sum(np.diag(cosine) > cosine_threshold)) if effective_n_pv <= 0: raise ValueError("No effective principal vectors found. Consider lowering `cosine_threshold` or `n_pv`.") @@ -114,12 +113,12 @@ def spage_impute( sc_proj = _dot(sc_shared, S) st_proj = _dot(st_shared, S) - n_neighbors = min(params.n_neighbors, sc_proj.shape[0]) + n_neighbors = min(n_neighbors, sc_proj.shape[0]) nn = NearestNeighbors( n_neighbors=n_neighbors, metric="cosine", algorithm="auto", - n_jobs=params.n_jobs, + n_jobs=n_jobs, ) nn.fit(sc_proj) distances, indices = nn.kneighbors(st_proj, return_distance=True) diff --git a/tests/tools/test_impute.py b/tests/tools/test_impute.py index 3a471b38b..4880aa84f 100644 --- a/tests/tools/test_impute.py +++ b/tests/tools/test_impute.py @@ -7,7 +7,6 @@ from scipy.sparse import csr_matrix from squidpy.tl import impute -from squidpy.tl._spage_impute import SpaGEParams def _make_adata(n_obs: int, genes: list[str], rng: np.random.Generator) -> AnnData: @@ -28,7 +27,8 @@ def test_spage_impute_dense_copy(self): st_adata, sc_adata, method="spage", - method_params=SpaGEParams(n_pv=3, n_neighbors=5), + n_pv=3, + n_neighbors=5, key_added="spage", copy=True, ) @@ -253,7 +253,7 @@ def test_invalid_method_raises(self): with pytest.raises(ValueError, match="one of"): impute(st_adata, sc_adata, method="tangram") - def test_method_params_object_is_supported(self): + def test_spage_args_are_supported(self): rng = np.random.default_rng(14) sc_genes = [f"g{i}" for i in range(8)] st_genes = [f"g{i}" for i in range(5)] @@ -265,7 +265,8 @@ def test_method_params_object_is_supported(self): st_adata, sc_adata, method="spage", - method_params=SpaGEParams(n_pv=3, n_neighbors=4), + n_pv=3, + n_neighbors=4, copy=True, ) From 41fb28629bc45874652cd06f5195abda4ffdddb2 Mon Sep 17 00:00:00 2001 From: MDLDan Date: Thu, 16 Apr 2026 11:02:43 +0000 Subject: [PATCH 07/13] remove copy --- src/squidpy/tl/_impute.py | 3 --- src/squidpy/tl/_spage_impute.py | 5 ----- 2 files changed, 8 deletions(-) diff --git a/src/squidpy/tl/_impute.py b/src/squidpy/tl/_impute.py index 25443c372..c7b7fe0f9 100644 --- a/src/squidpy/tl/_impute.py +++ b/src/squidpy/tl/_impute.py @@ -28,7 +28,6 @@ def impute( layer: str | None = None, key_added: str = "spage", n_jobs: int | None = None, - copy: bool = False, ) -> AnnData: """ Impute spatially unmeasured genes in spatial data using a selected method. @@ -59,8 +58,6 @@ def impute( Key added to `.obsm` for the imputed genes. n_jobs Number of parallel jobs for nearest neighbors search. - copy - Whether to return a copy of `st_adata`. Returns ------- AnnData with imputed genes stored in `.obsm[key_added]`. diff --git a/src/squidpy/tl/_spage_impute.py b/src/squidpy/tl/_spage_impute.py index 78d05f044..f5cabee59 100644 --- a/src/squidpy/tl/_spage_impute.py +++ b/src/squidpy/tl/_spage_impute.py @@ -33,7 +33,6 @@ def spage_impute( layer: str | None = None, key_added: str = "spage", n_jobs: int | None = None, - copy: bool = False, ) -> AnnData: """ Impute spatially unmeasured genes in spatial data using SpaGE. @@ -60,8 +59,6 @@ def spage_impute( Key added to `.obsm` for the imputed genes. n_jobs Number of parallel jobs for nearest neighbors search. - copy - Whether to return a copy of `st_adata`. Returns ------- @@ -69,8 +66,6 @@ def spage_impute( """ start = logg.info("Running SpaGE imputation") - if copy: - st_adata = st_adata.copy() if n_pv <= 0: raise ValueError("`n_pv` must be positive.") From 0ff04cd8d599b58f63cf4ab8a4716ae4afc15853 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 16 Apr 2026 11:03:40 +0000 Subject: [PATCH 08/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/squidpy/tl/_spage_impute.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/squidpy/tl/_spage_impute.py b/src/squidpy/tl/_spage_impute.py index f5cabee59..ecd07bfde 100644 --- a/src/squidpy/tl/_spage_impute.py +++ b/src/squidpy/tl/_spage_impute.py @@ -66,7 +66,6 @@ def spage_impute( """ start = logg.info("Running SpaGE imputation") - if n_pv <= 0: raise ValueError("`n_pv` must be positive.") if n_neighbors <= 0: From 1c2e63ab45ba2c9df97fd1410f2100f5fd3f285d Mon Sep 17 00:00:00 2001 From: MDLDan Date: Thu, 16 Apr 2026 12:00:20 +0000 Subject: [PATCH 09/13] allow overlap imputation --- src/squidpy/tl/_spage_impute.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/squidpy/tl/_spage_impute.py b/src/squidpy/tl/_spage_impute.py index ecd07bfde..7eb05951f 100644 --- a/src/squidpy/tl/_spage_impute.py +++ b/src/squidpy/tl/_spage_impute.py @@ -129,17 +129,20 @@ def _resolve_genes_to_predict( st_adata: AnnData, sc_adata: AnnData, genes: Sequence[str] | None, + remove_shared: bool = False, ) -> list[str]: if genes is None: - genes_to_predict = [g for g in sc_adata.var_names if g not in st_adata.var_names] + genes_to_predict = [g for g in sc_adata.var_names] else: genes_to_predict = [g for g in genes if g in sc_adata.var_names] missing = [g for g in genes if g not in sc_adata.var_names] if missing: raise ValueError(f"Genes not found in `sc_adata`: {missing}") - genes_to_predict = [g for g in genes_to_predict if g not in st_adata.var_names] + genes_to_predict = [g for g in genes_to_predict] if not genes_to_predict: raise ValueError("No genes to impute. Ensure `genes` are in `sc_adata` and absent from `st_adata`.") + if remove_shared: + genes_to_predict = [g for g in genes_to_predict if g not in st_adata.var_names] return genes_to_predict From 3a9d5f98a44dbbe31b88775038e3295045ee2c68 Mon Sep 17 00:00:00 2001 From: MDLDan Date: Thu, 16 Apr 2026 12:02:33 +0000 Subject: [PATCH 10/13] allow overlap imputation --- src/squidpy/tl/_impute.py | 5 ++++- src/squidpy/tl/_spage_impute.py | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/squidpy/tl/_impute.py b/src/squidpy/tl/_impute.py index c7b7fe0f9..53fc33c75 100644 --- a/src/squidpy/tl/_impute.py +++ b/src/squidpy/tl/_impute.py @@ -28,6 +28,7 @@ def impute( layer: str | None = None, key_added: str = "spage", n_jobs: int | None = None, + remove_shared: bool = True, ) -> AnnData: """ Impute spatially unmeasured genes in spatial data using a selected method. @@ -58,6 +59,8 @@ def impute( Key added to `.obsm` for the imputed genes. n_jobs Number of parallel jobs for nearest neighbors search. + remove_shared + Whether to remove shared genes from the imputed gene set. By default, only genes that Returns ------- AnnData with imputed genes stored in `.obsm[key_added]`. @@ -76,7 +79,7 @@ def impute( layer=layer, key_added=key_added, n_jobs=n_jobs, - copy=copy, + remove_shared=remove_shared, ) raise NotImplementedError(f"Method `{method}` is not yet implemented.") diff --git a/src/squidpy/tl/_spage_impute.py b/src/squidpy/tl/_spage_impute.py index 7eb05951f..5b28e5714 100644 --- a/src/squidpy/tl/_spage_impute.py +++ b/src/squidpy/tl/_spage_impute.py @@ -33,6 +33,7 @@ def spage_impute( layer: str | None = None, key_added: str = "spage", n_jobs: int | None = None, + remove_shared: bool = True, ) -> AnnData: """ Impute spatially unmeasured genes in spatial data using SpaGE. @@ -59,6 +60,8 @@ def spage_impute( Key added to `.obsm` for the imputed genes. n_jobs Number of parallel jobs for nearest neighbors search. + remove_shared + Whether to remove shared genes from the imputed gene set. By default, only genes that are present in `sc_adata` but absent from `st_adata` are imputed. Returns ------- @@ -73,7 +76,7 @@ def spage_impute( if cosine_threshold < 0: raise ValueError("`cosine_threshold` must be non-negative.") - genes_to_predict = _resolve_genes_to_predict(st_adata, sc_adata, genes) + genes_to_predict = _resolve_genes_to_predict(st_adata, sc_adata, genes, remove_shared) shared_genes = _shared_genes(st_adata, sc_adata) if n_pv > len(shared_genes): From 18bc8fe3b7d86dca7de59277bf708aeebb1513a8 Mon Sep 17 00:00:00 2001 From: MDLDan Date: Thu, 16 Apr 2026 12:53:42 +0000 Subject: [PATCH 11/13] update tests --- tests/tools/test_impute.py | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/tests/tools/test_impute.py b/tests/tools/test_impute.py index 4880aa84f..b0835bbc8 100644 --- a/tests/tools/test_impute.py +++ b/tests/tools/test_impute.py @@ -15,7 +15,7 @@ def _make_adata(n_obs: int, genes: list[str], rng: np.random.Generator) -> AnnDa class TestSpaGE: - def test_spage_impute_dense_copy(self): + def test_spage_impute_dense(self): rng = np.random.default_rng(0) sc_genes = [f"g{i}" for i in range(10)] st_genes = [f"g{i}" for i in range(5)] @@ -30,15 +30,15 @@ def test_spage_impute_dense_copy(self): n_pv=3, n_neighbors=5, key_added="spage", - copy=True, + remove_shared=False, ) - assert "spage" in res.obsm - assert "spage" not in st_adata.obsm + assert res is st_adata + assert "spage" in st_adata.obsm - df = res.obsm["spage"] - assert df.shape == (st_adata.n_obs, 5) - assert list(df.columns) == [f"g{i}" for i in range(5, 10)] + df = st_adata.obsm["spage"] + assert df.shape == (st_adata.n_obs, 10) + assert list(df.columns) == [f"g{i}" for i in range(10)] assert df.index.equals(st_adata.obs_names) def test_spage_impute_sparse(self): @@ -58,14 +58,14 @@ def test_spage_impute_sparse(self): n_pv=3, n_neighbors=4, key_added="spage", - copy=True, + remove_shared=False, ) df = res.obsm["spage"] - assert df.shape == (st_adata.n_obs, 4) - assert list(df.columns) == [f"g{i}" for i in range(4, 8)] + assert df.shape == (st_adata.n_obs, 8) + assert list(df.columns) == [f"g{i}" for i in range(8)] - def test_spage_impute_copy_false(self): + def test_spage_impute_returns_input(self): rng = np.random.default_rng(5) sc_genes = [f"g{i}" for i in range(9)] st_genes = [f"g{i}" for i in range(6)] @@ -80,7 +80,7 @@ def test_spage_impute_copy_false(self): n_pv=3, n_neighbors=4, key_added="spage", - copy=False, + remove_shared=False, ) assert res is st_adata @@ -103,7 +103,7 @@ def test_spage_impute_genes_subset_order(self): n_pv=3, n_neighbors=5, key_added="spage", - copy=True, + remove_shared=False, ) df = res.obsm["spage"] @@ -142,11 +142,11 @@ def test_spage_impute_n_neighbors_clamped(self): n_pv=3, n_neighbors=50, key_added="spage", - copy=True, + remove_shared=False, ) df = res.obsm["spage"] - assert df.shape == (st_adata.n_obs, 3) + assert df.shape == (st_adata.n_obs, 7) def test_spage_impute_use_raw(self): rng = np.random.default_rng(9) @@ -167,7 +167,7 @@ def test_spage_impute_use_raw(self): n_neighbors=4, key_added="spage", use_raw=True, - copy=True, + remove_shared=False, ) assert "spage" in res.obsm @@ -191,7 +191,7 @@ def test_spage_impute_layer(self): n_neighbors=4, key_added="spage", layer="counts", - copy=True, + remove_shared=False, ) assert "spage" in res.obsm @@ -227,7 +227,7 @@ def test_spage_impute_no_genes_to_impute(self): st_adata = _make_adata(10, st_genes, rng) with pytest.raises(ValueError, match="No genes to impute"): - impute(st_adata, sc_adata, method="spage", n_pv=2, n_neighbors=3) + impute(st_adata, sc_adata, method="spage", genes=[], n_pv=2, n_neighbors=3) def test_spage_impute_n_pv_too_large(self): rng = np.random.default_rng(4) @@ -267,7 +267,7 @@ def test_spage_args_are_supported(self): method="spage", n_pv=3, n_neighbors=4, - copy=True, + remove_shared=False, ) assert "spage" in res.obsm From 48b8e32ceca4b292a81a8eb9d75cc9f50c7c4a20 Mon Sep 17 00:00:00 2001 From: MDLDan Date: Fri, 17 Apr 2026 12:54:53 +0000 Subject: [PATCH 12/13] add copy --- src/squidpy/tl/_impute.py | 14 +++--- src/squidpy/tl/_spage_impute.py | 25 +++++------ tests/tools/test_impute.py | 75 +++++++++++++++++++++++++-------- 3 files changed, 79 insertions(+), 35 deletions(-) diff --git a/src/squidpy/tl/_impute.py b/src/squidpy/tl/_impute.py index 53fc33c75..d99a1deb4 100644 --- a/src/squidpy/tl/_impute.py +++ b/src/squidpy/tl/_impute.py @@ -3,6 +3,7 @@ from collections.abc import Sequence from anndata import AnnData +import pandas as pd from squidpy._docs import d from squidpy._validators import assert_one_of @@ -28,8 +29,8 @@ def impute( layer: str | None = None, key_added: str = "spage", n_jobs: int | None = None, - remove_shared: bool = True, -) -> AnnData: + copy: bool = False, +) -> pd.DataFrame | None: """ Impute spatially unmeasured genes in spatial data using a selected method. @@ -59,11 +60,12 @@ def impute( Key added to `.obsm` for the imputed genes. n_jobs Number of parallel jobs for nearest neighbors search. - remove_shared - Whether to remove shared genes from the imputed gene set. By default, only genes that + copy + If `True`, return the imputed dataframe. Otherwise, save it to `.obsm[key_added]`. Returns ------- - AnnData with imputed genes stored in `.obsm[key_added]`. + If ``copy = True``, returns a :class:`pandas.DataFrame` with imputed values. + Otherwise, stores the result in :attr:`anndata.AnnData.obsm` ``[key_added]`` and returns `None`. """ assert_one_of(method, _ALLOWED_METHODS, name="method") @@ -79,7 +81,7 @@ def impute( layer=layer, key_added=key_added, n_jobs=n_jobs, - remove_shared=remove_shared, + copy=copy, ) raise NotImplementedError(f"Method `{method}` is not yet implemented.") diff --git a/src/squidpy/tl/_spage_impute.py b/src/squidpy/tl/_spage_impute.py index 5b28e5714..ce7aaf855 100644 --- a/src/squidpy/tl/_spage_impute.py +++ b/src/squidpy/tl/_spage_impute.py @@ -33,8 +33,8 @@ def spage_impute( layer: str | None = None, key_added: str = "spage", n_jobs: int | None = None, - remove_shared: bool = True, -) -> AnnData: + copy: bool = False, +) -> pd.DataFrame | None: """ Impute spatially unmeasured genes in spatial data using SpaGE. @@ -60,12 +60,13 @@ def spage_impute( Key added to `.obsm` for the imputed genes. n_jobs Number of parallel jobs for nearest neighbors search. - remove_shared - Whether to remove shared genes from the imputed gene set. By default, only genes that are present in `sc_adata` but absent from `st_adata` are imputed. + copy + If `True`, return the imputed dataframe. Otherwise, save it to `.obsm[key_added]`. Returns ------- - AnnData with imputed genes stored in `.obsm[key_added]`. + If ``copy = True``, returns a :class:`pandas.DataFrame` with imputed values. + Otherwise, stores the result in :attr:`anndata.AnnData.obsm` ``[key_added]`` and returns `None`. """ start = logg.info("Running SpaGE imputation") @@ -76,7 +77,7 @@ def spage_impute( if cosine_threshold < 0: raise ValueError("`cosine_threshold` must be non-negative.") - genes_to_predict = _resolve_genes_to_predict(st_adata, sc_adata, genes, remove_shared) + genes_to_predict = _resolve_genes_to_predict(sc_adata, genes) shared_genes = _shared_genes(st_adata, sc_adata) if n_pv > len(shared_genes): @@ -124,15 +125,17 @@ def spage_impute( imputed = _impute_from_neighbors(weights, mask, indices, sc_target) result = pd.DataFrame(imputed, index=st_adata.obs_names, columns=genes_to_predict) + if copy: + logg.info("Finish", time=start) + return result + _save_data(st_adata, attr="obsm", key=key_added, data=result, time=start) - return st_adata + return None def _resolve_genes_to_predict( - st_adata: AnnData, sc_adata: AnnData, genes: Sequence[str] | None, - remove_shared: bool = False, ) -> list[str]: if genes is None: genes_to_predict = [g for g in sc_adata.var_names] @@ -143,9 +146,7 @@ def _resolve_genes_to_predict( raise ValueError(f"Genes not found in `sc_adata`: {missing}") genes_to_predict = [g for g in genes_to_predict] if not genes_to_predict: - raise ValueError("No genes to impute. Ensure `genes` are in `sc_adata` and absent from `st_adata`.") - if remove_shared: - genes_to_predict = [g for g in genes_to_predict if g not in st_adata.var_names] + raise ValueError("No genes to impute. Ensure `genes` are in `sc_adata`.") return genes_to_predict diff --git a/tests/tools/test_impute.py b/tests/tools/test_impute.py index b0835bbc8..5c43cd57c 100644 --- a/tests/tools/test_impute.py +++ b/tests/tools/test_impute.py @@ -30,10 +30,9 @@ def test_spage_impute_dense(self): n_pv=3, n_neighbors=5, key_added="spage", - remove_shared=False, ) - assert res is st_adata + assert res is None assert "spage" in st_adata.obsm df = st_adata.obsm["spage"] @@ -58,14 +57,14 @@ def test_spage_impute_sparse(self): n_pv=3, n_neighbors=4, key_added="spage", - remove_shared=False, ) - df = res.obsm["spage"] + assert res is None + df = st_adata.obsm["spage"] assert df.shape == (st_adata.n_obs, 8) assert list(df.columns) == [f"g{i}" for i in range(8)] - def test_spage_impute_returns_input(self): + def test_spage_impute_in_place_write(self): rng = np.random.default_rng(5) sc_genes = [f"g{i}" for i in range(9)] st_genes = [f"g{i}" for i in range(6)] @@ -80,12 +79,33 @@ def test_spage_impute_returns_input(self): n_pv=3, n_neighbors=4, key_added="spage", - remove_shared=False, ) - assert res is st_adata + assert res is None assert "spage" in st_adata.obsm + def test_spage_impute_copy_returns_dataframe(self): + rng = np.random.default_rng(15) + sc_genes = [f"g{i}" for i in range(9)] + st_genes = [f"g{i}" for i in range(6)] + + sc_adata = _make_adata(25, sc_genes, rng) + st_adata = _make_adata(12, st_genes, rng) + + res = impute( + st_adata, + sc_adata, + method="spage", + n_pv=3, + n_neighbors=4, + key_added="spage", + copy=True, + ) + + assert isinstance(res, pd.DataFrame) + assert res.shape == (st_adata.n_obs, 9) + assert "spage" not in st_adata.obsm + def test_spage_impute_genes_subset_order(self): rng = np.random.default_rng(6) sc_genes = [f"g{i}" for i in range(10)] @@ -103,10 +123,10 @@ def test_spage_impute_genes_subset_order(self): n_pv=3, n_neighbors=5, key_added="spage", - remove_shared=False, ) - df = res.obsm["spage"] + assert res is None + df = st_adata.obsm["spage"] assert list(df.columns) == genes def test_spage_impute_cosine_threshold_too_strict(self): @@ -142,10 +162,10 @@ def test_spage_impute_n_neighbors_clamped(self): n_pv=3, n_neighbors=50, key_added="spage", - remove_shared=False, ) - df = res.obsm["spage"] + assert res is None + df = st_adata.obsm["spage"] assert df.shape == (st_adata.n_obs, 7) def test_spage_impute_use_raw(self): @@ -167,10 +187,10 @@ def test_spage_impute_use_raw(self): n_neighbors=4, key_added="spage", use_raw=True, - remove_shared=False, ) - assert "spage" in res.obsm + assert res is None + assert "spage" in st_adata.obsm def test_spage_impute_layer(self): rng = np.random.default_rng(10) @@ -191,10 +211,31 @@ def test_spage_impute_layer(self): n_neighbors=4, key_added="spage", layer="counts", - remove_shared=False, ) - assert "spage" in res.obsm + assert res is None + assert "spage" in st_adata.obsm + + def test_spage_impute_shared_genes_are_kept(self): + rng = np.random.default_rng(16) + sc_genes = [f"g{i}" for i in range(10)] + st_genes = [f"g{i}" for i in range(5)] + + sc_adata = _make_adata(40, sc_genes, rng) + st_adata = _make_adata(20, st_genes, rng) + + impute( + st_adata, + sc_adata, + method="spage", + n_pv=3, + n_neighbors=5, + key_added="spage", + ) + + df = st_adata.obsm["spage"] + assert df.shape == (st_adata.n_obs, 10) + assert all(gene in df.columns for gene in st_genes) def test_spage_impute_invalid_genes(self): rng = np.random.default_rng(2) @@ -267,7 +308,7 @@ def test_spage_args_are_supported(self): method="spage", n_pv=3, n_neighbors=4, - remove_shared=False, ) - assert "spage" in res.obsm + assert res is None + assert "spage" in st_adata.obsm From 772ce91abb10f112156c253086c660445d286ffc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 17 Apr 2026 12:55:08 +0000 Subject: [PATCH 13/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/squidpy/tl/_impute.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/squidpy/tl/_impute.py b/src/squidpy/tl/_impute.py index d99a1deb4..e9d58b3f4 100644 --- a/src/squidpy/tl/_impute.py +++ b/src/squidpy/tl/_impute.py @@ -2,8 +2,8 @@ from collections.abc import Sequence -from anndata import AnnData import pandas as pd +from anndata import AnnData from squidpy._docs import d from squidpy._validators import assert_one_of