diff --git a/docs/release-notes/4037.feat.md b/docs/release-notes/4037.feat.md new file mode 100644 index 0000000000..2fbcfeb684 --- /dev/null +++ b/docs/release-notes/4037.feat.md @@ -0,0 +1 @@ +Add `mean_in_log_space` argument to {func}`scanpy.tl.rank_genes_groups` for customizing how log-fold-change is calculated {user}`ilan-gold` diff --git a/src/scanpy/_settings/presets.py b/src/scanpy/_settings/presets.py index bef0280b39..f7a243d61a 100644 --- a/src/scanpy/_settings/presets.py +++ b/src/scanpy/_settings/presets.py @@ -81,6 +81,7 @@ class PcaPreset(NamedTuple): class RankGenesGroupsPreset(NamedTuple): method: DETest mask_var: str | None + mean_in_log_space: bool class ScalePreset(NamedTuple): @@ -185,9 +186,11 @@ def pca() -> Mapping[Preset, PcaPreset]: def rank_genes_groups() -> Mapping[Preset, RankGenesGroupsPreset]: """Correlation method for :func:`~scanpy.tl.rank_genes_groups`.""" return { - Preset.ScanpyV1: RankGenesGroupsPreset(method="t-test", mask_var=None), + Preset.ScanpyV1: RankGenesGroupsPreset( + method="t-test", mask_var=None, mean_in_log_space=True + ), Preset.ScanpyV2Preview: RankGenesGroupsPreset( - method="wilcoxon", mask_var=None + method="wilcoxon", mask_var=None, mean_in_log_space=False ), } diff --git a/src/scanpy/tools/_rank_genes_groups.py b/src/scanpy/tools/_rank_genes_groups.py index d1eec5d4a0..067795158d 100644 --- a/src/scanpy/tools/_rank_genes_groups.py +++ b/src/scanpy/tools/_rank_genes_groups.py @@ -201,7 +201,7 @@ def __init__( self.grouping_mask = adata.obs[groupby].isin(self.groups_order) self.grouping = adata.obs.loc[self.grouping_mask, groupby] - def _basic_stats(self) -> None: + def _basic_stats(self, *, exponentiate_values: bool = False) -> None: """Set self.{means,vars,pts}{,_rest} depending on X.""" n_genes = self.X.shape[1] n_groups = self.groups_masks_obs.shape[0] @@ -217,6 +217,8 @@ def _basic_stats(self) -> None: else: mask_rest = self.groups_masks_obs[self.ireference] x_rest = self.X[mask_rest] + if 
exponentiate_values: + x_rest = self.expm1_func(x_rest) self.means[self.ireference], self.vars[self.ireference] = mean_var( x_rest, axis=0, correction=1 ) @@ -230,6 +232,8 @@ def _basic_stats(self) -> None: for group_index, mask_obs in enumerate(self.groups_masks_obs): x_mask = self.X[mask_obs] + if exponentiate_values: + x_mask = self.expm1_func(x_mask) if self.comp_pts: self.pts[group_index] = get_nonzeros(x_mask) / x_mask.shape[0] @@ -244,6 +248,8 @@ def _basic_stats(self) -> None: if self.ireference is None: mask_rest = ~mask_obs x_rest = self.X[mask_rest] + if exponentiate_values: + x_rest = self.expm1_func(x_rest) ( self.means_rest[group_index], self.vars_rest[group_index], @@ -259,8 +265,6 @@ def t_test( ) -> Generator[tuple[int, NDArray[np.floating], NDArray[np.floating]], None, None]: from scipy import stats - self._basic_stats() - for group_index, (mask_obs, mean_group, var_group) in enumerate( zip(self.groups_masks_obs, self.means, self.vars, strict=True) ): @@ -312,8 +316,6 @@ def wilcoxon( ) -> Generator[tuple[int, NDArray[np.floating], NDArray[np.floating]], None, None]: from scipy import stats - self._basic_stats() - n_genes = self.X.shape[1] # First loop: Loop over all genes if self.ireference is not None: @@ -429,12 +431,19 @@ def compute_statistics( # noqa: PLR0912 n_genes_user: int | None = None, rankby_abs: bool = False, tie_correct: bool = False, + mean_in_log_space: bool = True, **kwds, ) -> None: if method in {"t-test", "t-test_overestim_var"}: + self._basic_stats(exponentiate_values=False) - generate_test_results = self.t_test(method) + generate_test_results = list(self.t_test(method)) + if not mean_in_log_space: + # t_test() is a lazy generator reading self.means/self.vars, so its results are materialized above before the stats are recomputed on expm1-transformed values for the fold change. + self._basic_stats(exponentiate_values=True) elif method == "wilcoxon": generate_test_results = self.wilcoxon(tie_correct=tie_correct) + # If we're not exponentiating after the mean aggregation, then do it now. 
+ self._basic_stats(exponentiate_values=not mean_in_log_space) elif method == "logreg": generate_test_results = self.logreg(**kwds) @@ -481,9 +490,12 @@ def compute_statistics( # noqa: PLR0912 mean_rest = self.means_rest[group_index] else: mean_rest = self.means[self.ireference] - foldchanges = (self.expm1_func(mean_group) + 1e-9) / ( - self.expm1_func(mean_rest) + 1e-9 - ) # add small value to remove 0's + foldchanges = ( + (self.expm1_func(mean_group) + 1e-9) + / (self.expm1_func(mean_rest) + 1e-9) + if mean_in_log_space + else (mean_group + 1e-9) / (mean_rest + 1e-9) + ) # add small value to avoid zeros self.stats[group_name, "logfoldchanges"] = np.log2( foldchanges[global_indices] ) @@ -511,9 +523,12 @@ def rank_genes_groups( # noqa: PLR0912, PLR0913, PLR0915 corr_method: _CorrMethod = "benjamini-hochberg", tie_correct: bool = False, layer: str | None = None, + mean_in_log_space: bool | Default = Default( + preset=("rank_genes_groups", "mean_in_log_space") + ), **kwds, ) -> AnnData | None: - """Rank genes for characterizing groups. + r"""Rank genes for characterizing groups. Expects logarithmized data. @@ -574,6 +589,11 @@ def rank_genes_groups( # noqa: PLR0912, PLR0913, PLR0915 The key in `adata.uns` information is saved to. copy Whether to copy `adata` or modify it inplace. + mean_in_log_space + Whether to do :math:`\log(\operatorname{mean}(e^x))` (`False`) + or :math:`\log(e^{\operatorname{mean}(x)})` (`True`). + The former is accurate, while the latter is a faster approximation + that underestimates this accurate result in the presence of many outliers. kwds Are passed to test methods. Currently this affects only parameters that are passed to :class:`sklearn.linear_model.LogisticRegression`. @@ -596,7 +616,7 @@ def rank_genes_groups( # noqa: PLR0912, PLR0913, PLR0915 Structured array to be indexed by group id storing the log2 fold change for each gene for each group. Ordered according to scores. Only provided if method is 't-test' like. 
- Note: this is an approximation calculated from mean-log values. + Note: if `mean_in_log_space=True`, this is an approximation calculated from mean-log values. `adata.uns['rank_genes_groups' | key_added]['pvals']` : structured :class:`numpy.ndarray` (dtype `float`) p-values. `adata.uns['rank_genes_groups' | key_added]['pvals_adj']` : structured :class:`numpy.ndarray` (dtype `float`) @@ -626,6 +646,8 @@ def rank_genes_groups( # noqa: PLR0912, PLR0913, PLR0915 if isinstance(mask_var, Default): mask_var = settings.preset.rank_genes_groups.mask_var + if isinstance(mean_in_log_space, Default): + mean_in_log_space = settings.preset.rank_genes_groups.mean_in_log_space if method is None or isinstance(method, Default): method = settings.preset.rank_genes_groups.method @@ -714,6 +736,7 @@ def rank_genes_groups( # noqa: PLR0912, PLR0913, PLR0915 n_genes_user=n_genes_user, rankby_abs=rankby_abs, tie_correct=tie_correct, + mean_in_log_space=mean_in_log_space, **kwds, ) diff --git a/tests/test_rank_genes_groups.py b/tests/test_rank_genes_groups.py index ba38ffc94d..a60f914e2d 100644 --- a/tests/test_rank_genes_groups.py +++ b/tests/test_rank_genes_groups.py @@ -311,3 +311,44 @@ def test_mask_not_equal(): with_mask = pbmc.uns["rank_genes_groups"]["names"] assert not np.array_equal(no_mask, with_mask) + + +@pytest.mark.parametrize( + ("mean_in_log_space", "expected_logfc"), + [ + # exp after agg: log2(expm1(mean_log_a) / expm1(mean_log_b)) + # = log2(expm1(ln(9) * 5 / 10) / expm1(ln9)) = log2(2 / 8) = -2.0 + (True, -2.0), + # exp before agg: log2(mean(expm1(linear_a)) / mean(expm1(linear_b))) + # = log2(mean([0] * 5 + [8] * 5) / mean([8] * 10)) = log2(4 / 8) = -1.0 + (False, -1.0), + ], +) +@pytest.mark.parametrize("method", ["wilcoxon", "t-test", "t-test_overestim_var"]) +def test_mean_in_log_space( + expected_logfc: float, + method: Literal["wilcoxon", "t-test", "t-test_overestim_var"], + *, + mean_in_log_space: bool, +): + # group_a: 5 cells with log-space value 0, 5 cells 
with log(9) + # group_b: 10 cells all with log(9) (used as reference) + n_genes = 5 + group_a = np.zeros((10, n_genes)) + group_a[5:] = np.log(9) + group_b = np.full((10, n_genes), np.log(9)) + adata = AnnData( + X=np.concatenate([group_a, group_b]), + obs={"bulk_labels": ["a"] * 10 + ["b"] * 10}, + ) + + rank_genes_groups( + adata, + groupby="bulk_labels", + groups=["a"], + reference="b", + method=method, + mean_in_log_space=mean_in_log_space, + ) + logfcs = adata.uns["rank_genes_groups"]["logfoldchanges"]["a"] + np.testing.assert_equal(logfcs, expected_logfc)