diff --git a/pyproject.toml b/pyproject.toml index fa6c66534e..f3524c1152 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,7 +51,7 @@ classifiers = [ ] dynamic = [ "version" ] dependencies = [ - "anndata>=0.10.8", + "anndata>=0.11", "certifi", "fast-array-utils[accel,sparse]>=1.4", "h5py>=3.11", @@ -97,6 +97,7 @@ scanorama = [ "scanorama" ] scrublet = [ "scikit-image>=0.23.1" ] # highly_variable_genes method 'seurat_v3' skmisc = [ "scikit-misc>=0.5.1" ] +illico = [ "illico>=0.5.0rc2" ] scanpy2 = [ "igraph>=0.10.8", "scikit-misc>=0.5.1" ] [dependency-groups] @@ -107,6 +108,7 @@ dev = [ test = [ "scanpy[dask-ml]", "scanpy[dask]", + "scanpy[illico]", "scanpy[leiden]", "scanpy[plotting]", "scanpy[scrublet]", diff --git a/src/scanpy/_settings/presets.py b/src/scanpy/_settings/presets.py index f7a243d61a..5a45e250a2 100644 --- a/src/scanpy/_settings/presets.py +++ b/src/scanpy/_settings/presets.py @@ -31,7 +31,9 @@ ] -type DETest = Literal["logreg", "t-test", "wilcoxon", "t-test_overestim_var"] +type DETest = Literal[ + "logreg", "t-test", "wilcoxon", "wilcoxon_illico", "t-test_overestim_var" +] type HVGFlavor = Literal["seurat", "cell_ranger", "seurat_v3", "seurat_v3_paper"] type LeidenFlavor = Literal["leidenalg", "igraph"] diff --git a/src/scanpy/tools/_rank_genes_groups.py b/src/scanpy/tools/_rank_genes_groups.py index 067795158d..31997ced16 100644 --- a/src/scanpy/tools/_rank_genes_groups.py +++ b/src/scanpy/tools/_rank_genes_groups.py @@ -7,6 +7,7 @@ import numba import numpy as np import pandas as pd +from anndata import AnnData from fast_array_utils.numba import njit from fast_array_utils.stats import mean_var from scipy import sparse @@ -27,7 +28,6 @@ from collections.abc import Generator, Iterable from typing import Literal - from anndata import AnnData from numpy.typing import NDArray @@ -140,6 +140,7 @@ def __init__( self.expm1_func = lambda x: np.expm1(x * np.log(base)) else: self.expm1_func = np.expm1 + self.group_col = adata.obs[groupby].array 
self.groups_order, self.groups_masks_obs = _utils.select_groups( adata, groups, groupby @@ -423,7 +424,7 @@ def logreg( if len(self.groups_order) <= 2: break - def compute_statistics( # noqa: PLR0912 + def compute_statistics( # noqa: PLR0912, PLR0915 self, method: DETest, *, @@ -440,8 +441,63 @@ def compute_statistics( # noqa: PLR0912 if not mean_in_log_space: # If we are not exponentiating after the mean aggregation, we need to recalculate the stats. self._basic_stats(exponentiate_values=True) - elif method == "wilcoxon": - generate_test_results = self.wilcoxon(tie_correct=tie_correct) + elif "wilcoxon" in method: + if "illico" in method: + from illico import asymptotic_wilcoxon + + illico_df = asymptotic_wilcoxon( + AnnData( + X=self.X, + var=pd.DataFrame(index=self.var_names), + obs=pd.DataFrame( + index=pd.RangeIndex(self.X.shape[0]).astype("str"), + # This self.group_col means illico will run tests against *all* data + # instead of what's in self.groups_order as controlled by the `groups` arg. + # TODO: Only run the subset once illico supports a `groups` argument + data={"group": self.group_col}, + ), + ), + reference=self.groups_order[self.ireference] + if self.ireference is not None + else None, + group_keys="group", + return_as_scanpy=False, + is_log1p=True, + tie_correct=tie_correct, + use_continuity=False, + alternative="two-sided", + use_rust=False, + ) + # Generate a lookup of category -> result excluding the reference if it is present. + generate_test_results_map = { + group_cat: ( + group["z_score"].to_numpy(copy=True), + group["p_value"].to_numpy(copy=True), + ) + for (_, group) in illico_df.groupby(level="pert") + if ( + group_cat := np.unique( + group.index.get_level_values("pert").to_numpy(copy=True) + ).item() + ) + != ( + None + if self.ireference is None + else self.groups_order[self.ireference] + ) + } + # Create the iterator that is expected by the other method-branches. 
+ groups_order_list = self.groups_order.tolist() + generate_test_results = ( + ( + groups_order_list.index(group_cat), + *generate_test_results_map[group_cat], + ) + for group_cat in self.groups_order + if group_cat in generate_test_results_map + ) + else: + generate_test_results = self.wilcoxon(tie_correct=tie_correct) # If we're not exponentiating after the mean aggregation, then do it now. self._basic_stats(exponentiate_values=not mean_in_log_space) elif method == "logreg": @@ -450,7 +506,6 @@ def compute_statistics( # noqa: PLR0912 self.stats = None n_genes = self.X.shape[1] - for group_index, scores, pvals in generate_test_results: group_name = str(self.groups_order[group_index]) diff --git a/src/testing/scanpy/_pytest/marks.py b/src/testing/scanpy/_pytest/marks.py index 1e83404614..654340f604 100644 --- a/src/testing/scanpy/_pytest/marks.py +++ b/src/testing/scanpy/_pytest/marks.py @@ -41,6 +41,7 @@ def _generate_next_value_( skimage = "scikit-image" skmisc = "scikit-misc" zarr = auto() + illico = auto() # external bbknn = auto() harmony = "harmonyTS" diff --git a/tests/test_rank_genes_groups.py b/tests/test_rank_genes_groups.py index a60f914e2d..b6c1c461aa 100644 --- a/tests/test_rank_genes_groups.py +++ b/tests/test_rank_genes_groups.py @@ -19,10 +19,11 @@ from scanpy.tools._rank_genes_groups import _RankGenes from testing.scanpy._helpers import random_mask from testing.scanpy._helpers.data import pbmc68k_reduced +from testing.scanpy._pytest.marks import needs from testing.scanpy._pytest.params import ARRAY_TYPES, ARRAY_TYPES_MEM if TYPE_CHECKING: - from collections.abc import Callable + from collections.abc import Callable, Sequence from typing import Any, Literal from numpy.lib.npyio import NpzFile @@ -313,6 +314,79 @@ def test_mask_not_equal(): assert not np.array_equal(no_mask, with_mask) +@pytest.mark.parametrize("corr_method", ["benjamini-hochberg", "bonferroni"]) +@pytest.mark.parametrize("test", ["ovo", "ovr"]) 
+@pytest.mark.parametrize("exp_post_agg", [True, False], ids=["post_exp", "pre_exp"]) +@pytest.mark.parametrize( + "tie_correct", [True, False], ids=["tie_correct", "no_tie_correct"] +) +@pytest.mark.parametrize("groups", [["CD14+ Monocyte", "Dendritic"], "all"]) +@pytest.mark.filterwarnings("ignore:invalid value encountered:RuntimeWarning") +@needs.illico +def test_illico( + test: Literal["ovo", "ovr"], + corr_method: Literal["benjamini-hochberg", "bonferroni"], + subtests: pytest.Subtests, + groups: Literal["all"] | Sequence[str], + *, + exp_post_agg: bool, + tie_correct: bool, +): + + pbmc = pbmc68k_reduced() + pbmc.raw.X.sum_duplicates() + pbmc.raw.X.sort_indices() + pbmc_illico = pbmc.copy() + + reference = pbmc.obs["bulk_labels"].iloc[0] if test == "ovo" else "rest" + sc.tl.rank_genes_groups( + pbmc_illico, + groupby="bulk_labels", + method="wilcoxon_illico", + reference=reference if test == "ovo" else "rest", + n_genes=pbmc.n_vars, + tie_correct=tie_correct, + corr_method=corr_method, + exp_post_agg=exp_post_agg, + groups=groups, + ) + + sc.tl.rank_genes_groups( + pbmc, + groupby="bulk_labels", + method="wilcoxon", + reference=reference if test == "ovo" else "rest", + n_genes=pbmc.n_vars, + tie_correct=tie_correct, + corr_method=corr_method, + exp_post_agg=exp_post_agg, + groups=groups, + ) + scanpy_results = pbmc.uns["rank_genes_groups"] + illico_results = pbmc_illico.uns["rank_genes_groups"] + assert set(illico_results.keys()) == set(scanpy_results.keys()), ( + "Output keys do not match Scanpy's output format." 
+ ) + + for k, ref in scanpy_results.items(): + with subtests.test(k): + if k in ["params", "names"]: + # Skip the names-ordering check: if the ordering were wrong, the other values would mismatch + continue + res = np.array(illico_results[k].tolist()) + ref_arr = np.array(ref.tolist()) + mask = np.isfinite(ref_arr) * np.isfinite( + res + ) # Mask to ignore non-finite (inf/NaN) values in the comparison + np.testing.assert_allclose( + ref_arr[mask], + res[mask], + rtol=0, + atol=1e-6, + err_msg=f"Mismatch in '{k}' values between asymptotic_wilcoxon and Scanpy outputs.", + ) + + @pytest.mark.parametrize( ("mean_in_log_space", "expected_logfc"), [