From ada140a37589138995f98562d4a5eda5733bc10d Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Wed, 8 Apr 2026 17:46:46 +0200 Subject: [PATCH 01/23] feat: speed up numba sums --- src/scanpy/get/_aggregated.py | 38 +++++++++++------------ src/scanpy/get/_kernels.py | 58 +++++++++++++++++++++++++++++++++++ 2 files changed, 77 insertions(+), 19 deletions(-) create mode 100644 src/scanpy/get/_kernels.py diff --git a/src/scanpy/get/_aggregated.py b/src/scanpy/get/_aggregated.py index c5750a8193..601bd2a94c 100644 --- a/src/scanpy/get/_aggregated.py +++ b/src/scanpy/get/_aggregated.py @@ -13,6 +13,7 @@ from scanpy._compat import CSBase, CSRBase, DaskArray from .._utils import _resolve_axis, get_literal_vals +from ._kernels import agg_sum_csc, agg_sum_csr from .get import _check_mask if TYPE_CHECKING: @@ -53,7 +54,7 @@ class Aggregate: def __init__( self, groupby: pd.Categorical, - data: Array, + data: np.ndarray | CSBase, *, mask: NDArray[np.bool] | None = None, ) -> None: @@ -64,7 +65,7 @@ def __init__( self.data = data groupby: pd.Categorical - indicator_matrix: sparse.coo_matrix + indicator_matrix: CSRBase data: Array def count_nonzero(self) -> NDArray[np.integer]: @@ -79,7 +80,7 @@ def count_nonzero(self) -> NDArray[np.integer]: # return self.indicator_matrix @ pattern return utils.asarray(self.indicator_matrix @ (self.data != 0)) - def sum(self) -> Array: + def sum(self, *, power_of_2: bool = False) -> Array: """Compute the sum per feature per group of observations. Returns @@ -87,7 +88,14 @@ def sum(self) -> Array: Array of sum. """ - return utils.asarray(self.indicator_matrix @ self.data) + if isinstance(self.data, np.ndarray): + return utils.asarray( + self.indicator_matrix + @ (_power(self.data, 2) if power_of_2 else self.data) + ) + return (agg_sum_csr if isinstance(self.data, CSRBase) else agg_sum_csc)( + self.indicator_matrix, (_power(self.data, 2) if power_of_2 else self.data) + ) def mean(self) -> Array: """Compute the mean per feature per group of observations. @@ -97,10 +105,7 @@ def mean(self) -> Array: Array of mean. """ - return ( - utils.asarray(self.indicator_matrix @ self.data) - / np.bincount(self.groupby.codes)[:, None] - ) + return self.sum() / np.bincount(self.groupby.codes)[:, None] def mean_var(self, dof: int = 1) -> tuple[np.ndarray, np.ndarray]: """Compute the count, as well as mean and variance per feature, per group of observations. @@ -550,18 +555,13 @@ def sparse_indicator( categorical: pd.Categorical, *, mask: NDArray[np.bool] | None = None, - weight: NDArray[np.floating] | None = None, -) -> sparse.coo_matrix: - if mask is not None and weight is None: - weight = mask.astype(np.float32) - elif mask is not None and weight is not None: - weight = mask * weight - elif mask is None and weight is None: - weight = np.broadcast_to(1.0, len(categorical)) +) -> CSRBase: + if mask is None: + mask = np.broadcast_to(True, len(categorical)) # noqa: FBT003 # can’t have -1s in the codes, but (as long as it’s valid), the value is ignored, so set to 0 where masked codes = categorical.codes if mask is None else np.where(mask, categorical.codes, 0) - a = sparse.coo_matrix( - (weight, (codes, np.arange(len(categorical)))), + a = sparse.coo_array( + (mask, (codes, np.arange(len(categorical)))), shape=(len(categorical.categories), len(categorical)), - ) + ).tocsr() return a diff --git a/src/scanpy/get/_kernels.py b/src/scanpy/get/_kernels.py new file mode 100644 index 0000000000..3e3fb0147d --- /dev/null +++ b/src/scanpy/get/_kernels.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numba +import numpy as np +from fast_array_utils.numba import njit + +if TYPE_CHECKING: + from .._compat import CSCBase, CSRBase + + +@njit +def agg_sum_csr( + indicator: CSRBase, + data: CSRBase, +): + out = np.zeros((indicator.shape[0], data.shape[1])) + for cat_num in numba.prange(indicator.shape[0]): + start_cat_idx = indicator.indptr[cat_num] + stop_cat_idx = indicator.indptr[cat_num + 1] + for row_num in range(start_cat_idx, stop_cat_idx): + obs_per_cat = indicator.indices[row_num] + + start_obs = data.indptr[obs_per_cat] + end_obs = data.indptr[obs_per_cat + 1] + + for j in range(start_obs, end_obs): + col = data.indices[j] + out[cat_num, col] += data.data[j] + return out + + +@njit +def agg_sum_csc( + indicator: CSRBase, + data: CSCBase, +): + out = np.zeros((indicator.shape[0], data.shape[1]), dtype=data.data.dtype) + + obs_to_cat = np.full(data.shape[0], -1, dtype=np.int64) + + for cat in range(indicator.shape[0]): + for k in range(indicator.indptr[cat], indicator.indptr[cat + 1]): + obs_to_cat[indicator.indices[k]] = cat + + for col in numba.prange(data.shape[1]): + start = data.indptr[col] + end = data.indptr[col + 1] + + for j in range(start, end): + obs = data.indices[j] + cat = obs_to_cat[obs] + + if cat != -1: + out[cat, col] += data.data[j] + + return out From 7a8aaffc4ddf328ffd42a8d4de0cd0e8a62271ea Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Wed, 8 Apr 2026 19:30:48 +0200 Subject: [PATCH 02/23] fix: no use keeping booleans around --- src/scanpy/get/_aggregated.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/scanpy/get/_aggregated.py b/src/scanpy/get/_aggregated.py index 601bd2a94c..df4309008c 100644 --- a/src/scanpy/get/_aggregated.py +++ b/src/scanpy/get/_aggregated.py @@ -558,6 +558,7 @@ def sparse_indicator( ) -> CSRBase: if mask is None: mask = np.broadcast_to(True, len(categorical)) # noqa: FBT003 + mask = mask.astype("uint8") # can’t have -1s in the codes, but (as long as it’s valid), the value is ignored, so set to 0 where masked codes = categorical.codes if mask is None else np.where(mask, categorical.codes, 0) a = sparse.coo_array( From ce44fbea8c147ed9cef7ee83035e25da8f245679 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Wed, 8 Apr 2026 19:31:45 +0200 Subject: [PATCH 03/23] chore: remove dead code --- src/scanpy/get/_aggregated.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/scanpy/get/_aggregated.py b/src/scanpy/get/_aggregated.py index df4309008c..86540cadf3 100644 --- a/src/scanpy/get/_aggregated.py +++ b/src/scanpy/get/_aggregated.py @@ -560,7 +560,7 @@ def sparse_indicator( mask = np.broadcast_to(True, len(categorical)) # noqa: FBT003 mask = mask.astype("uint8") # can’t have -1s in the codes, but (as long as it’s valid), the value is ignored, so set to 0 where masked - codes = categorical.codes if mask is None else np.where(mask, categorical.codes, 0) + codes = np.where(mask, categorical.codes, 0) a = sparse.coo_array( (mask, (codes, np.arange(len(categorical)))), shape=(len(categorical.categories), len(categorical)), From be47e37e03c6c531b79efdcf6ce17cba093843a6 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Wed, 8 Apr 2026 19:53:40 +0200 Subject: [PATCH 04/23] fix: dtype issue --- src/scanpy/get/_aggregated.py | 23 +++++++++++------------ src/scanpy/get/_kernels.py | 8 ++++---- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/src/scanpy/get/_aggregated.py b/src/scanpy/get/_aggregated.py index 86540cadf3..7f0b78b9c2 100644 --- a/src/scanpy/get/_aggregated.py +++ b/src/scanpy/get/_aggregated.py @@ -76,11 +76,9 @@ def count_nonzero(self) -> NDArray[np.integer]: Array of counts. """ - # pattern = self.data._with_data(np.broadcast_to(1, len(self.data.data))) - # return self.indicator_matrix @ pattern - return utils.asarray(self.indicator_matrix @ (self.data != 0)) + return self._sum(data=(self.data != 0).astype("uint8"), power_of_2=False) - def sum(self, *, power_of_2: bool = False) -> Array: + def sum(self, *, power_of_2: bool = False) -> np.ndarray: """Compute the sum per feature per group of observations. Returns @@ -88,13 +86,17 @@ def sum(self, *, power_of_2: bool = False) -> Array: Array of sum. """ - if isinstance(self.data, np.ndarray): + return self._sum(data=self.data, power_of_2=power_of_2) + + def _sum(self, *, data: np.ndarray | CSBase, power_of_2: bool = False) -> np.ndarray: + + if isinstance(data, np.ndarray): return utils.asarray( self.indicator_matrix - @ (_power(self.data, 2) if power_of_2 else self.data) + @ (_power(data, 2) if power_of_2 else data) ) - return (agg_sum_csr if isinstance(self.data, CSRBase) else agg_sum_csc)( - self.indicator_matrix, (_power(self.data, 2) if power_of_2 else self.data) + return (agg_sum_csr if isinstance(data, CSRBase) else agg_sum_csc)( + self.indicator_matrix, (_power(data, 2) if power_of_2 else data) ) def mean(self) -> Array: @@ -131,10 +133,7 @@ def mean_var(self, dof: int = 1) -> tuple[np.ndarray, np.ndarray]: group_counts = np.bincount(self.groupby.codes) mean_ = self.mean() # sparse matrices do not support ** for elementwise power. - mean_sq = ( - utils.asarray(self.indicator_matrix @ _power(self.data, 2)) - / group_counts[:, None] - ) + mean_sq = self.sum(power_of_2=True) / group_counts[:, None] sq_mean = mean_**2 var_ = mean_sq - sq_mean # TODO: Why these values exactly? Because they are high relative to the datatype? diff --git a/src/scanpy/get/_kernels.py b/src/scanpy/get/_kernels.py index 3e3fb0147d..b619b5d925 100644 --- a/src/scanpy/get/_kernels.py +++ b/src/scanpy/get/_kernels.py @@ -15,7 +15,7 @@ def agg_sum_csr( indicator: CSRBase, data: CSRBase, ): - out = np.zeros((indicator.shape[0], data.shape[1])) + out = np.zeros((indicator.shape[0], data.shape[1]), dtype="float64") for cat_num in numba.prange(indicator.shape[0]): start_cat_idx = indicator.indptr[cat_num] stop_cat_idx = indicator.indptr[cat_num + 1] @@ -27,7 +27,7 @@ def agg_sum_csr( for j in range(start_obs, end_obs): col = data.indices[j] - out[cat_num, col] += data.data[j] + out[cat_num, col] += float(data.data[j]) return out @@ -36,7 +36,7 @@ def agg_sum_csc( indicator: CSRBase, data: CSCBase, ): - out = np.zeros((indicator.shape[0], data.shape[1]), dtype=data.data.dtype) + out = np.zeros((indicator.shape[0], data.shape[1]), dtype="float64") obs_to_cat = np.full(data.shape[0], -1, dtype=np.int64) @@ -53,6 +53,6 @@ def agg_sum_csc( cat = obs_to_cat[obs] if cat != -1: - out[cat, col] += data.data[j] + out[cat, col] += float(data.data[j]) return out From 0d125c50452241bd36cb967bf53339ca52a72db1 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Wed, 8 Apr 2026 20:00:59 +0200 Subject: [PATCH 05/23] fix: preallocate out --- src/scanpy/get/_aggregated.py | 3 ++- src/scanpy/get/_kernels.py | 7 ++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/scanpy/get/_aggregated.py b/src/scanpy/get/_aggregated.py index 7f0b78b9c2..cff90e33a8 100644 --- a/src/scanpy/get/_aggregated.py +++ b/src/scanpy/get/_aggregated.py @@ -95,8 +95,9 @@ def _sum(self, *, data: np.ndarray | CSBase, power_of_2: bool = False) -> np.nda self.indicator_matrix @ (_power(data, 2) if power_of_2 else data) ) + out = np.zeros((self.indicator_matrix.shape[0], data.shape[1]), dtype="int64" if np.issubdtype(data, np.integer) else "float64") return (agg_sum_csr if isinstance(data, CSRBase) else agg_sum_csc)( - self.indicator_matrix, (_power(data, 2) if power_of_2 else data) + self.indicator_matrix, (_power(data, 2) if power_of_2 else data), out ) def mean(self) -> Array: diff --git a/src/scanpy/get/_kernels.py b/src/scanpy/get/_kernels.py index b619b5d925..3bfeab6ed8 100644 --- a/src/scanpy/get/_kernels.py +++ b/src/scanpy/get/_kernels.py @@ -14,6 +14,7 @@ def agg_sum_csr( indicator: CSRBase, data: CSRBase, + out: np.ndarray ): out = np.zeros((indicator.shape[0], data.shape[1]), dtype="float64") for cat_num in numba.prange(indicator.shape[0]): @@ -27,7 +28,7 @@ def agg_sum_csr( for j in range(start_obs, end_obs): col = data.indices[j] - out[cat_num, col] += float(data.data[j]) + out[cat_num, col] += data.data[j] return out @@ -35,8 +36,8 @@ def agg_sum_csr( def agg_sum_csc( indicator: CSRBase, data: CSCBase, + out: np.ndarray ): - out = np.zeros((indicator.shape[0], data.shape[1]), dtype="float64") obs_to_cat = np.full(data.shape[0], -1, dtype=np.int64) @@ -53,6 +54,6 @@ def agg_sum_csc( cat = obs_to_cat[obs] if cat != -1: - out[cat, col] += float(data.data[j]) + out[cat, col] += data.data[j] return out From 7867c536cd7f6873338b5e4d8beac96a5e3fda9f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 8 Apr 2026 17:53:54 +0000 Subject: [PATCH 06/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/scanpy/get/_aggregated.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/scanpy/get/_aggregated.py b/src/scanpy/get/_aggregated.py index cff90e33a8..4f954308a2 100644 --- a/src/scanpy/get/_aggregated.py +++ b/src/scanpy/get/_aggregated.py @@ -88,12 +88,13 @@ def sum(self, *, power_of_2: bool = False) -> np.ndarray: """ return self._sum(data=self.data, power_of_2=power_of_2) - def _sum(self, *, data: np.ndarray | CSBase, power_of_2: bool = False) -> np.ndarray: + def _sum( + self, *, data: np.ndarray | CSBase, power_of_2: bool = False + ) -> np.ndarray: if isinstance(data, np.ndarray): return utils.asarray( - self.indicator_matrix - @ (_power(data, 2) if power_of_2 else data) + self.indicator_matrix @ (_power(data, 2) if power_of_2 else data) ) out = np.zeros((self.indicator_matrix.shape[0], data.shape[1]), dtype="int64" if np.issubdtype(data, np.integer) else "float64") return (agg_sum_csr if isinstance(data, CSRBase) else agg_sum_csc)( From 1feca145f78c8c98928d7797d01a6b1dcd607b40 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Wed, 8 Apr 2026 20:13:47 +0200 Subject: [PATCH 07/23] fix: dtypes --- src/scanpy/get/_aggregated.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/scanpy/get/_aggregated.py b/src/scanpy/get/_aggregated.py index 4f954308a2..0d500aa5f5 100644 --- a/src/scanpy/get/_aggregated.py +++ b/src/scanpy/get/_aggregated.py @@ -5,7 +5,7 @@ import numpy as np import pandas as pd -from anndata import AnnData, utils +from anndata import AnnData from fast_array_utils.stats._power import power as fau_power # TODO: upstream from scipy import sparse from sklearn.utils.sparsefuncs import csc_median_axis_0 @@ -91,12 +91,15 @@ def sum(self, *, power_of_2: bool = False) -> np.ndarray: def _sum( self, *, data: np.ndarray | CSBase, power_of_2: bool = False ) -> np.ndarray: - if isinstance(data, np.ndarray): - return utils.asarray( + res = ( self.indicator_matrix @ (_power(data, 2) if power_of_2 else data) ) - out = np.zeros((self.indicator_matrix.shape[0], data.shape[1]), dtype="int64" if np.issubdtype(data, np.integer) else "float64") + if isinstance(res, CSBase): + return res.toarray() + return res + dtype = np.int64 if np.issubdtype(data.dtype, np.integer) else np.float64 + out = np.zeros((self.indicator_matrix.shape[0], data.shape[1]), dtype=dtype) return (agg_sum_csr if isinstance(data, CSRBase) else agg_sum_csc)( self.indicator_matrix, (_power(data, 2) if power_of_2 else data), out ) @@ -558,8 +561,10 @@ def sparse_indicator( mask: NDArray[np.bool] | None = None, ) -> CSRBase: if mask is None: - mask = np.broadcast_to(True, len(categorical)) # noqa: FBT003 - mask = mask.astype("uint8") + # TODO: why is this float64. This is a scanpy 2.0 problem maybe? + mask = np.broadcast_to(1.0, len(categorical)) # noqa: FBT003 + else: + mask = mask.astype("uint8") # can’t have -1s in the codes, but (as long as it’s valid), the value is ignored, so set to 0 where masked codes = np.where(mask, categorical.codes, 0) a = sparse.coo_array( From 86f40b0cd936b4976cc0540ad5da7c1e4b79db12 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Wed, 8 Apr 2026 20:23:49 +0200 Subject: [PATCH 08/23] pre-commit --- src/scanpy/get/_aggregated.py | 6 ++---- src/scanpy/get/_kernels.py | 12 ++---------- 2 files changed, 4 insertions(+), 14 deletions(-) diff --git a/src/scanpy/get/_aggregated.py b/src/scanpy/get/_aggregated.py index 0d500aa5f5..f26b99e40d 100644 --- a/src/scanpy/get/_aggregated.py +++ b/src/scanpy/get/_aggregated.py @@ -92,9 +92,7 @@ def _sum( self, *, data: np.ndarray | CSBase, power_of_2: bool = False ) -> np.ndarray: if isinstance(data, np.ndarray): - res = ( - self.indicator_matrix @ (_power(data, 2) if power_of_2 else data) - ) + res = self.indicator_matrix @ (_power(data, 2) if power_of_2 else data) if isinstance(res, CSBase): return res.toarray() return res @@ -562,7 +560,7 @@ def sparse_indicator( ) -> CSRBase: if mask is None: # TODO: why is this float64. This is a scanpy 2.0 problem maybe? - mask = np.broadcast_to(1.0, len(categorical)) # noqa: FBT003 + mask = np.broadcast_to(1.0, len(categorical)) else: mask = mask.astype("uint8") # can’t have -1s in the codes, but (as long as it’s valid), the value is ignored, so set to 0 where masked diff --git a/src/scanpy/get/_kernels.py b/src/scanpy/get/_kernels.py index 3bfeab6ed8..971917d263 100644 --- a/src/scanpy/get/_kernels.py +++ b/src/scanpy/get/_kernels.py @@ -11,11 +11,7 @@ @njit -def agg_sum_csr( - indicator: CSRBase, - data: CSRBase, - out: np.ndarray -): +def agg_sum_csr(indicator: CSRBase, data: CSRBase, out: np.ndarray): out = np.zeros((indicator.shape[0], data.shape[1]), dtype="float64") for cat_num in numba.prange(indicator.shape[0]): start_cat_idx = indicator.indptr[cat_num] @@ -33,11 +29,7 @@ def agg_sum_csr( @njit -def agg_sum_csc( - indicator: CSRBase, - data: CSCBase, - out: np.ndarray -): +def agg_sum_csc(indicator: CSRBase, data: CSCBase, out: np.ndarray): obs_to_cat = np.full(data.shape[0], -1, dtype=np.int64) From 53713d5a4c64cbb53a205ee680316c1e391c3f94 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Wed, 8 Apr 2026 21:12:24 +0200 Subject: [PATCH 09/23] fix: use registry --- benchmarks/benchmarks/_utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/benchmarks/benchmarks/_utils.py b/benchmarks/benchmarks/_utils.py index c21f97f54f..3b7a7b9c7e 100644 --- a/benchmarks/benchmarks/_utils.py +++ b/benchmarks/benchmarks/_utils.py @@ -103,10 +103,12 @@ def bmmc(n_obs: int = 400) -> AnnData: @cache def _lung93k() -> AnnData: - path = pooch.retrieve( - url="https://figshare.com/ndownloader/files/45788454", - known_hash="md5:4f28af5ff226052443e7e0b39f3f9212", + registry = pooch.create( + path=pooch.os_cache("pooch"), + base_url="doi:10.6084/m9.figshare.25664775.v1/", ) + registry.load_registry_from_doi() + path = registry.fetch("adata.raw_compressed.h5ad") adata = sc.read_h5ad(path) assert isinstance(adata.X, CSRBase) adata.layers["counts"] = adata.X.astype(np.int32, copy=True) From 337d7983dfdc70e1e5b03c3072c572fe440feee6 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Wed, 8 Apr 2026 21:25:12 +0200 Subject: [PATCH 10/23] fix: timeout --- benchmarks/asv.conf.json | 1 + 1 file changed, 1 insertion(+) diff --git a/benchmarks/asv.conf.json b/benchmarks/asv.conf.json index d19b822178..4aea794c74 100644 --- a/benchmarks/asv.conf.json +++ b/benchmarks/asv.conf.json @@ -85,6 +85,7 @@ "scikit-image": [""], // https://github.com/conda-forge/scikit-misc-feedstock/pull/29 // "scikit-misc": [""], }, + "default_benchmark_timeout": 500, // Combinations of libraries/python versions can be excluded/included // from the set to test. Each entry is a dictionary containing additional From c1ff9082c46d0afa9f135209da0253faecb59693 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Fri, 10 Apr 2026 13:56:35 +0200 Subject: [PATCH 11/23] chore: new sparse kernels --- src/scanpy/get/_aggregated.py | 42 ++++++++++--------- src/scanpy/get/_kernels.py | 76 +++++++++++++++++++++++++++++++++-- 2 files changed, 96 insertions(+), 22 deletions(-) diff --git a/src/scanpy/get/_aggregated.py b/src/scanpy/get/_aggregated.py index f26b99e40d..604832bcf6 100644 --- a/src/scanpy/get/_aggregated.py +++ b/src/scanpy/get/_aggregated.py @@ -13,7 +13,7 @@ from scanpy._compat import CSBase, CSRBase, DaskArray from .._utils import _resolve_axis, get_literal_vals -from ._kernels import agg_sum_csc, agg_sum_csr +from ._kernels import agg_sum_csc, agg_sum_csr, mean_var_csc, mean_var_csr from .get import _check_mask if TYPE_CHECKING: @@ -78,7 +78,7 @@ def count_nonzero(self) -> NDArray[np.integer]: """ return self._sum(data=(self.data != 0).astype("uint8"), power_of_2=False) - def sum(self, *, power_of_2: bool = False) -> np.ndarray: + def sum(self) -> np.ndarray: """Compute the sum per feature per group of observations. Returns @@ -86,21 +86,19 @@ def sum(self, *, power_of_2: bool = False) -> np.ndarray: Array of sum. """ - return self._sum(data=self.data, power_of_2=power_of_2) - - def _sum( - self, *, data: np.ndarray | CSBase, power_of_2: bool = False - ) -> np.ndarray: - if isinstance(data, np.ndarray): - res = self.indicator_matrix @ (_power(data, 2) if power_of_2 else data) + if isinstance(self.data, np.ndarray): + res = self.indicator_matrix @ self.data if isinstance(res, CSBase): return res.toarray() return res - dtype = np.int64 if np.issubdtype(data.dtype, np.integer) else np.float64 - out = np.zeros((self.indicator_matrix.shape[0], data.shape[1]), dtype=dtype) - return (agg_sum_csr if isinstance(data, CSRBase) else agg_sum_csc)( - self.indicator_matrix, (_power(data, 2) if power_of_2 else data), out + dtype = np.int64 if np.issubdtype(self.data.dtype, np.integer) else np.float64 + out = np.zeros( + (self.indicator_matrix.shape[0], self.data.shape[1]), dtype=dtype + ) + (agg_sum_csr if isinstance(self.data, CSRBase) else agg_sum_csc)( + self.indicator_matrix, self.data, out ) + return out def mean(self) -> Array: """Compute the mean per feature per group of observations. @@ -134,11 +132,19 @@ def mean_var(self, dof: int = 1) -> tuple[np.ndarray, np.ndarray]: assert dof >= 0 group_counts = np.bincount(self.groupby.codes) - mean_ = self.mean() - # sparse matrices do not support ** for elementwise power. - mean_sq = self.sum(power_of_2=True) / group_counts[:, None] - sq_mean = mean_**2 - var_ = mean_sq - sq_mean + if isinstance(self.data, np.ndarray): + mean_ = self.mean() + # sparse matrices do not support ** for elementwise power. + mean_sq = (self.indicator_matrix @ _power(self.data, 2)) / group_counts[ + :, None + ] + sq_mean = mean_**2 + var_ = mean_sq - sq_mean + else: + mean_, var_ = ( + mean_var_csr if isinstance(self.data, CSRBase) else mean_var_csc + )(self.indicator_matrix, self.data) + sq_mean = mean_**2 # TODO: Why these values exactly? Because they are high relative to the datatype? # (unchanged from original code: https://github.com/scverse/anndata/pull/564) precision = 2 << (42 if self.data.dtype == np.float64 else 20) diff --git a/src/scanpy/get/_kernels.py b/src/scanpy/get/_kernels.py index 971917d263..d878fbbce0 100644 --- a/src/scanpy/get/_kernels.py +++ b/src/scanpy/get/_kernels.py @@ -7,12 +7,13 @@ from fast_array_utils.numba import njit if TYPE_CHECKING: + from numpy.typing import NDArray + from .._compat import CSCBase, CSRBase @njit -def agg_sum_csr(indicator: CSRBase, data: CSRBase, out: np.ndarray): - out = np.zeros((indicator.shape[0], data.shape[1]), dtype="float64") +def agg_sum_csr(indicator: CSRBase, data: CSRBase, out: NDArray): for cat_num in numba.prange(indicator.shape[0]): start_cat_idx = indicator.indptr[cat_num] stop_cat_idx = indicator.indptr[cat_num + 1] @@ -25,7 +26,6 @@ def agg_sum_csr(indicator: CSRBase, data: CSRBase, out: np.ndarray): for j in range(start_obs, end_obs): col = data.indices[j] out[cat_num, col] += data.data[j] - return out @njit @@ -48,4 +48,72 @@ def agg_sum_csc(indicator: CSRBase, data: CSCBase, out: np.ndarray): if cat != -1: out[cat, col] += data.data[j] - return out + +@njit +def mean_var_csr( + indicator: CSRBase, + data: CSCBase, +) -> tuple[NDArray[np.float64], NDArray[np.float64]]: + + mean = np.zeros((indicator.shape[0], data.shape[1]), dtype="float64") + var = np.zeros((indicator.shape[0], data.shape[1]), dtype="float64") + + for cat_num in numba.prange(indicator.shape[0]): + start_cat_idx = indicator.indptr[cat_num] + stop_cat_idx = indicator.indptr[cat_num + 1] + for row_num in range(start_cat_idx, stop_cat_idx): + obs_per_cat = indicator.indices[row_num] + + start_obs = data.indptr[obs_per_cat] + end_obs = data.indptr[obs_per_cat + 1] + + for j in range(start_obs, end_obs): + col = data.indices[j] + value = np.float64(data.data[j]) + value = data.data[j] + mean[cat_num, col] += value + var[cat_num, col] += value * value + + n_obs = stop_cat_idx - start_cat_idx + mean_cat = mean[cat_num, :] / n_obs + mean[cat_num, :] = mean_cat + var[cat_num, :] = (var[cat_num, :] / n_obs) - (mean_cat * mean_cat) + return mean, var + + +@njit +def mean_var_csc( + indicator: CSRBase, data: CSCBase +) -> tuple[NDArray[np.float64], NDArray[np.float64]]: + + obs_to_cat = np.full(data.shape[0], -1, dtype=np.int64) + + mean = np.zeros((indicator.shape[0], data.shape[1]), dtype="float64") + var = np.zeros((indicator.shape[0], data.shape[1]), dtype="float64") + + for cat in range(indicator.shape[0]): + for k in range(indicator.indptr[cat], indicator.indptr[cat + 1]): + obs_to_cat[indicator.indices[k]] = cat + + for col in numba.prange(data.shape[1]): + start = data.indptr[col] + end = data.indptr[col + 1] + + for j in range(start, end): + obs = data.indices[j] + cat = obs_to_cat[obs] + + if cat != -1: + value = np.float64(data.data[j]) + value = data.data[j] + mean[cat, col] += value + var[cat, col] += value * value + + for cat_num in numba.prange(indicator.shape[0]): + start_cat_idx = indicator.indptr[cat_num] + stop_cat_idx = indicator.indptr[cat_num + 1] + n_obs = stop_cat_idx - start_cat_idx + mean_cat = mean[cat_num, :] / n_obs + mean[cat_num, :] = mean_cat + var[cat_num, :] = (var[cat_num, :] / n_obs) - (mean_cat * mean_cat) + return mean, var From 573242368375ed4559f804b9ec3274b2d9fdd8d0 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Fri, 10 Apr 2026 14:04:11 +0200 Subject: [PATCH 12/23] fix: agg count_nonzero --- src/scanpy/get/_aggregated.py | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/src/scanpy/get/_aggregated.py b/src/scanpy/get/_aggregated.py index 604832bcf6..071a44fcaf 100644 --- a/src/scanpy/get/_aggregated.py +++ b/src/scanpy/get/_aggregated.py @@ -26,7 +26,7 @@ type AggType = ConstantDtypeAgg | Literal["mean", "var"] -class Aggregate: +class Aggregate[ArrayT: np.ndarray | CSBase]: """Functionality for generic grouping and aggregating. There is currently support for count_nonzero, sum, mean, and variance. @@ -54,7 +54,7 @@ class Aggregate: def __init__( self, groupby: pd.Categorical, - data: np.ndarray | CSBase, + data: ArrayT, *, mask: NDArray[np.bool] | None = None, ) -> None: @@ -66,7 +66,7 @@ def __init__( groupby: pd.Categorical indicator_matrix: CSRBase - data: Array + data: ArrayT def count_nonzero(self) -> NDArray[np.integer]: """Count the number of observations in each group. @@ -78,6 +78,19 @@ def count_nonzero(self) -> NDArray[np.integer]: """ return self._sum(data=(self.data != 0).astype("uint8"), power_of_2=False) + def _sum(self, data: ArrayT): + if isinstance(data, np.ndarray): + res = self.indicator_matrix @ data + if isinstance(res, CSBase): + return res.toarray() + return res + dtype = np.int64 if np.issubdtype(data.dtype, np.integer) else np.float64 + out = np.zeros((self.indicator_matrix.shape[0], data.shape[1]), dtype=dtype) + (agg_sum_csr if isinstance(data, CSRBase) else agg_sum_csc)( + self.indicator_matrix, data, out + ) + return out + def sum(self) -> np.ndarray: """Compute the sum per feature per group of observations. @@ -86,19 +99,7 @@ def sum(self) -> np.ndarray: Array of sum. """ - if isinstance(self.data, np.ndarray): - res = self.indicator_matrix @ self.data - if isinstance(res, CSBase): - return res.toarray() - return res - dtype = np.int64 if np.issubdtype(self.data.dtype, np.integer) else np.float64 - out = np.zeros( - (self.indicator_matrix.shape[0], self.data.shape[1]), dtype=dtype - ) - (agg_sum_csr if isinstance(self.data, CSRBase) else agg_sum_csc)( - self.indicator_matrix, self.data, out - ) - return out + return self._sum(self.data) def mean(self) -> Array: """Compute the mean per feature per group of observations. From 116fd07e6c710c06d191230c735c65bf58ac05a2 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Fri, 10 Apr 2026 14:05:38 +0200 Subject: [PATCH 13/23] right, count nonzero --- src/scanpy/get/_aggregated.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/scanpy/get/_aggregated.py b/src/scanpy/get/_aggregated.py index 071a44fcaf..0624b4e654 100644 --- a/src/scanpy/get/_aggregated.py +++ b/src/scanpy/get/_aggregated.py @@ -76,7 +76,7 @@ def count_nonzero(self) -> NDArray[np.integer]: Array of counts. """ - return self._sum(data=(self.data != 0).astype("uint8"), power_of_2=False) + return self._sum(data=(self.data != 0).astype("uint8")) def _sum(self, data: ArrayT): if isinstance(data, np.ndarray): @@ -136,9 +136,7 @@ def mean_var(self, dof: int = 1) -> tuple[np.ndarray, np.ndarray]: if isinstance(self.data, np.ndarray): mean_ = self.mean() # sparse matrices do not support ** for elementwise power. - mean_sq = (self.indicator_matrix @ _power(self.data, 2)) / group_counts[ - :, None - ] + mean_sq = self._sum(_power(self.data, 2)) / group_counts[:, None] sq_mean = mean_**2 var_ = mean_sq - sq_mean else: From 2dfd20ed00d2a0211a1b54d97cf14ddaaad3f2b5 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Fri, 10 Apr 2026 13:16:14 +0200 Subject: [PATCH 14/23] fix: use weights to calculate modularity (#4045) --- docs/release-notes/4045.fix.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 docs/release-notes/4045.fix.md diff --git a/docs/release-notes/4045.fix.md b/docs/release-notes/4045.fix.md new file mode 100644 index 0000000000..fbbac0208f --- /dev/null +++ b/docs/release-notes/4045.fix.md @@ -0,0 +1 @@ +Make {func}`scanpy.metrics.modularity` actually use edge weights {smaller}`P Angerer` From 5cf58ad2253903b405f7be1756d6157328912130 Mon Sep 17 00:00:00 2001 From: "Lumberbot (aka Jack)" <39504233+meeseeksmachine@users.noreply.github.com> Date: Fri, 10 Apr 2026 14:16:06 +0200 Subject: [PATCH 15/23] docs: generate 1.12.1 release notes (#4050) Co-authored-by: Philipp A --- docs/release-notes/4045.fix.md | 1 - 1 file changed, 1 deletion(-) delete mode 100644 docs/release-notes/4045.fix.md diff --git a/docs/release-notes/4045.fix.md b/docs/release-notes/4045.fix.md deleted file mode 100644 index fbbac0208f..0000000000 --- a/docs/release-notes/4045.fix.md +++ /dev/null @@ -1 +0,0 @@ -Make {func}`scanpy.metrics.modularity` actually use edge weights {smaller}`P Angerer` From 407f838c1d2d3f0f1245638f79e1ee512c7af8f7 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Wed, 15 Apr 2026 20:41:06 +0200 Subject: [PATCH 16/23] chore: add benchmark --- benchmarks/benchmarks/preprocessing_counts.py | 24 ++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/benchmarks/benchmarks/preprocessing_counts.py b/benchmarks/benchmarks/preprocessing_counts.py index a4c2fb65ba..bddb8bf903 100644 --- a/benchmarks/benchmarks/preprocessing_counts.py +++ b/benchmarks/benchmarks/preprocessing_counts.py @@ -12,8 +12,10 @@ import anndata as ad import scanpy as sc +from scanpy._utils import get_literal_vals +from scanpy.get._aggregated import AggType -from ._utils import get_count_dataset +from ._utils import get_count_dataset, get_dataset if TYPE_CHECKING: from typing import Any @@ -146,3 +148,23 @@ def time_log1p(self, *_) -> None: def peakmem_log1p(self, *_) -> None: self.adata.uns.pop("log1p", None) sc.pp.log1p(self.adata) + + +class Agg: # noqa: D101 + params: tuple[AggType] = tuple(get_literal_vals(AggType)) + param_names = ("agg_name",) + + def setup_cache(self) -> None: + """Without this caching, asv was running several processes which meant the data was repeatedly downloaded.""" + adata, _ = get_dataset("lung93k") + adata.write_h5ad("lung93k.h5ad") + + def setup(self, agg_name: AggType) -> None: + self.adata = ad.read_h5ad("lung93k.h5ad") + self.agg_name = agg_name + + def time_agg(self, *_) -> None: + sc.get.aggregate(self.adata, by="PatientNumber", func=self.agg_name) + + def peakmem_agg(self, *_) -> None: + sc.get.aggregate(self.adata, by="PatientNumber", func=self.agg_name) From b8ddbbb4c53affb2480e0b4d527209f14c5e65f6 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Wed, 15 Apr 2026 20:44:23 +0200 Subject: [PATCH 17/23] fix: remove timeout --- benchmarks/asv.conf.json | 1 - 1 file changed, 1 deletion(-) diff --git a/benchmarks/asv.conf.json b/benchmarks/asv.conf.json index 4aea794c74..d19b822178 100644 --- a/benchmarks/asv.conf.json +++ b/benchmarks/asv.conf.json @@ -85,7 +85,6 @@ "scikit-image": [""], // https://github.com/conda-forge/scikit-misc-feedstock/pull/29 // "scikit-misc": [""], }, - "default_benchmark_timeout": 500, // Combinations of libraries/python versions can be excluded/included // from the set to test. Each entry is a dictionary containing additional From 981229922bf1043f7b7422826e7d365732863d34 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Wed, 15 Apr 2026 20:59:07 +0200 Subject: [PATCH 18/23] fix: counts --- benchmarks/benchmarks/preprocessing_counts.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/benchmarks/benchmarks/preprocessing_counts.py b/benchmarks/benchmarks/preprocessing_counts.py index bddb8bf903..9a20e7eda3 100644 --- a/benchmarks/benchmarks/preprocessing_counts.py +++ b/benchmarks/benchmarks/preprocessing_counts.py @@ -164,7 +164,11 @@ def setup(self, agg_name: AggType) -> None: self.agg_name = agg_name def time_agg(self, *_) -> None: - sc.get.aggregate(self.adata, by="PatientNumber", func=self.agg_name) + sc.get.aggregate( + self.adata, by="PatientNumber", func=self.agg_name, layer="counts" + ) def peakmem_agg(self, *_) -> None: - sc.get.aggregate(self.adata, by="PatientNumber", func=self.agg_name) + sc.get.aggregate( + self.adata, by="PatientNumber", func=self.agg_name, layer="counts" + ) From 4dc8771ba2d65a4e7931726c1720afb5651308bb Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Wed, 15 Apr 2026 21:51:42 +0200 Subject: [PATCH 19/23] chore: relnote --- docs/release-notes/4062.perf.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 docs/release-notes/4062.perf.md diff --git a/docs/release-notes/4062.perf.md b/docs/release-notes/4062.perf.md new file mode 100644 index 0000000000..006c2ad9a7 --- /dev/null +++ b/docs/release-notes/4062.perf.md @@ -0,0 +1 @@ +Add `numba` kernels for mean/var/count-nonzero/sum arregation of sparse data in {func}`scanpy.get.aggregate` {smaller}`I Gold` From 47130159e402732c09631c00b388330b8bcf2afd Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Wed, 15 Apr 2026 21:54:08 +0200 Subject: [PATCH 20/23] chore: retain old dense behavior --- src/scanpy/get/_aggregated.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/scanpy/get/_aggregated.py b/src/scanpy/get/_aggregated.py index 0624b4e654..505937d5d1 100644 --- a/src/scanpy/get/_aggregated.py +++ b/src/scanpy/get/_aggregated.py @@ -62,10 +62,13 @@ def __init__( if (missing := groupby.isna()).any(): mask = mask & ~missing if mask is not None else ~missing self.indicator_matrix = sparse_indicator(groupby, mask=mask) + if isinstance(self.data, CSBase): + # TODO: Look into if this can be CSR and fast for dense + self.indicator_matrix = self.indicator_matrix.tocsr() self.data = data groupby: pd.Categorical - indicator_matrix: CSRBase + indicator_matrix: CSRBase | sparse.coo_array data: ArrayT def count_nonzero(self) -> NDArray[np.integer]: @@ -562,7 +565,7 @@ def sparse_indicator( categorical: pd.Categorical, *, mask: NDArray[np.bool] | None = None, -) -> CSRBase: +) -> sparse.coo_array: if mask is None: # TODO: why is this float64. This is a scanpy 2.0 problem maybe? mask = np.broadcast_to(1.0, len(categorical)) @@ -573,5 +576,5 @@ def sparse_indicator( a = sparse.coo_array( (mask, (codes, np.arange(len(categorical)))), shape=(len(categorical.categories), len(categorical)), - ).tocsr() + ) return a From 43c679ce5068998545550485ba2c01657675c637 Mon Sep 17 00:00:00 2001 From: Ilan Gold Date: Wed, 15 Apr 2026 22:48:22 +0200 Subject: [PATCH 21/23] Update _aggregated.py --- src/scanpy/get/_aggregated.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/scanpy/get/_aggregated.py b/src/scanpy/get/_aggregated.py index 505937d5d1..4748cfd604 100644 --- a/src/scanpy/get/_aggregated.py +++ b/src/scanpy/get/_aggregated.py @@ -62,7 +62,7 @@ def __init__( if (missing := groupby.isna()).any(): mask = mask & ~missing if mask is not None else ~missing self.indicator_matrix = sparse_indicator(groupby, mask=mask) - if isinstance(self.data, CSBase): + if isinstance(data, CSBase): # TODO: Look into if this can be CSR and fast for dense self.indicator_matrix = self.indicator_matrix.tocsr() self.data = data From df6c0baa16b0a231d1ea5d438a75965fb8b3da6d Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Thu, 16 Apr 2026 16:20:20 +0200 Subject: [PATCH 22/23] style --- src/scanpy/get/_aggregated.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/scanpy/get/_aggregated.py b/src/scanpy/get/_aggregated.py index 4748cfd604..fe5140171b 100644 --- a/src/scanpy/get/_aggregated.py +++ b/src/scanpy/get/_aggregated.py @@ -566,11 +566,10 @@ def sparse_indicator( *, mask: NDArray[np.bool] | None = None, ) -> sparse.coo_array: - if mask is None: - # TODO: why is this float64. This is a scanpy 2.0 problem maybe? - mask = np.broadcast_to(1.0, len(categorical)) - else: - mask = mask.astype("uint8") + # TODO: why is this float64. This is a scanpy 2.0 problem maybe? + mask = ( + np.broadcast_to(1.0, len(categorical)) if mask is None else mask.astype("uint8") + ) # can’t have -1s in the codes, but (as long as it’s valid), the value is ignored, so set to 0 where masked codes = np.where(mask, categorical.codes, 0) a = sparse.coo_array( From 83c06680a6fae53dc63e478b88aa1cb02017f054 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Thu, 16 Apr 2026 16:23:18 +0200 Subject: [PATCH 23/23] style --- src/scanpy/get/_kernels.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/scanpy/get/_kernels.py b/src/scanpy/get/_kernels.py index d878fbbce0..4d25bd06be 100644 --- a/src/scanpy/get/_kernels.py +++ b/src/scanpy/get/_kernels.py @@ -13,7 +13,7 @@ @njit -def agg_sum_csr(indicator: CSRBase, data: CSRBase, out: NDArray): +def agg_sum_csr(indicator: CSRBase, data: CSRBase, out: NDArray) -> None: for cat_num in numba.prange(indicator.shape[0]): start_cat_idx = indicator.indptr[cat_num] stop_cat_idx = indicator.indptr[cat_num + 1] @@ -29,8 +29,7 @@ def agg_sum_csr(indicator: CSRBase, data: CSRBase, out: NDArray): @njit -def agg_sum_csc(indicator: CSRBase, data: CSCBase, out: np.ndarray): - +def agg_sum_csc(indicator: CSRBase, data: CSCBase, out: np.ndarray) -> None: obs_to_cat = np.full(data.shape[0], -1, dtype=np.int64) for cat in range(indicator.shape[0]): @@ -54,7 +53,6 @@ def mean_var_csr( indicator: CSRBase, data: CSCBase, ) -> tuple[NDArray[np.float64], NDArray[np.float64]]: - mean = np.zeros((indicator.shape[0], data.shape[1]), dtype="float64") var = np.zeros((indicator.shape[0], data.shape[1]), dtype="float64") @@ -85,7 +83,6 @@ def mean_var_csr( def mean_var_csc( indicator: CSRBase, data: CSCBase ) -> tuple[NDArray[np.float64], NDArray[np.float64]]: - obs_to_cat = np.full(data.shape[0], -1, dtype=np.int64) mean = np.zeros((indicator.shape[0], data.shape[1]), dtype="float64")