From 39a29de5342434e4b9317adcbf63603010bd763c Mon Sep 17 00:00:00 2001 From: Ilan Gold Date: Thu, 16 Apr 2026 17:00:57 +0200 Subject: [PATCH] Backport PR #4062: perf: `numba` based aggregations for sparse data --- benchmarks/benchmarks/_utils.py | 8 +- benchmarks/benchmarks/preprocessing_counts.py | 28 ++++- docs/release-notes/4062.perf.md | 1 + src/scanpy/get/_aggregated.py | 78 +++++++----- src/scanpy/get/_kernels.py | 116 ++++++++++++++++++ 5 files changed, 194 insertions(+), 37 deletions(-) create mode 100644 docs/release-notes/4062.perf.md create mode 100644 src/scanpy/get/_kernels.py diff --git a/benchmarks/benchmarks/_utils.py b/benchmarks/benchmarks/_utils.py index c21f97f54f..3b7a7b9c7e 100644 --- a/benchmarks/benchmarks/_utils.py +++ b/benchmarks/benchmarks/_utils.py @@ -103,10 +103,12 @@ def bmmc(n_obs: int = 400) -> AnnData: @cache def _lung93k() -> AnnData: - path = pooch.retrieve( - url="https://figshare.com/ndownloader/files/45788454", - known_hash="md5:4f28af5ff226052443e7e0b39f3f9212", + registry = pooch.create( + path=pooch.os_cache("pooch"), + base_url="doi:10.6084/m9.figshare.25664775.v1/", ) + registry.load_registry_from_doi() + path = registry.fetch("adata.raw_compressed.h5ad") adata = sc.read_h5ad(path) assert isinstance(adata.X, CSRBase) adata.layers["counts"] = adata.X.astype(np.int32, copy=True) diff --git a/benchmarks/benchmarks/preprocessing_counts.py b/benchmarks/benchmarks/preprocessing_counts.py index 1715589a7e..1e9c6f61f1 100644 --- a/benchmarks/benchmarks/preprocessing_counts.py +++ b/benchmarks/benchmarks/preprocessing_counts.py @@ -11,8 +11,10 @@ import anndata as ad import scanpy as sc +from scanpy._utils import get_literal_vals +from scanpy.get._aggregated import AggType -from ._utils import get_count_dataset +from ._utils import get_count_dataset, get_dataset if TYPE_CHECKING: from ._utils import Dataset, KeyCount @@ -109,3 +111,27 @@ def time_log1p(self, *_) -> None: def peakmem_log1p(self, *_) -> None: 
self.adata.uns.pop("log1p", None) sc.pp.log1p(self.adata) + + +class Agg: # noqa: D101 + params: tuple[AggType] = tuple(get_literal_vals(AggType)) + param_names = ("agg_name",) + + def setup_cache(self) -> None: + """Without this caching, asv was running several processes which meant the data was repeatedly downloaded.""" + adata, _ = get_dataset("lung93k") + adata.write_h5ad("lung93k.h5ad") + + def setup(self, agg_name: AggType) -> None: + self.adata = ad.read_h5ad("lung93k.h5ad") + self.agg_name = agg_name + + def time_agg(self, *_) -> None: + sc.get.aggregate( + self.adata, by="PatientNumber", func=self.agg_name, layer="counts" + ) + + def peakmem_agg(self, *_) -> None: + sc.get.aggregate( + self.adata, by="PatientNumber", func=self.agg_name, layer="counts" + ) diff --git a/docs/release-notes/4062.perf.md b/docs/release-notes/4062.perf.md new file mode 100644 index 0000000000..006c2ad9a7 --- /dev/null +++ b/docs/release-notes/4062.perf.md @@ -0,0 +1 @@ +Add `numba` kernels for mean/var/count-nonzero/sum aggregation of sparse data in {func}`scanpy.get.aggregate` {smaller}`I Gold` diff --git a/src/scanpy/get/_aggregated.py b/src/scanpy/get/_aggregated.py index c5750a8193..fe5140171b 100644 --- a/src/scanpy/get/_aggregated.py +++ b/src/scanpy/get/_aggregated.py @@ -5,7 +5,7 @@ import numpy as np import pandas as pd -from anndata import AnnData, utils +from anndata import AnnData from fast_array_utils.stats._power import power as fau_power # TODO: upstream from scipy import sparse from sklearn.utils.sparsefuncs import csc_median_axis_0 @@ -13,6 +13,7 @@ from scanpy._compat import CSBase, CSRBase, DaskArray from .._utils import _resolve_axis, get_literal_vals +from ._kernels import agg_sum_csc, agg_sum_csr, mean_var_csc, mean_var_csr from .get import _check_mask if TYPE_CHECKING: @@ -25,7 +26,7 @@ type AggType = ConstantDtypeAgg | Literal["mean", "var"] -class Aggregate: +class Aggregate[ArrayT: np.ndarray | CSBase]: """Functionality for generic grouping and 
aggregating. There is currently support for count_nonzero, sum, mean, and variance. @@ -53,7 +54,7 @@ class Aggregate: def __init__( self, groupby: pd.Categorical, - data: Array, + data: ArrayT, *, mask: NDArray[np.bool] | None = None, ) -> None: @@ -61,11 +62,14 @@ def __init__( if (missing := groupby.isna()).any(): mask = mask & ~missing if mask is not None else ~missing self.indicator_matrix = sparse_indicator(groupby, mask=mask) + if isinstance(data, CSBase): + # TODO: Look into if this can be CSR and fast for dense + self.indicator_matrix = self.indicator_matrix.tocsr() self.data = data groupby: pd.Categorical - indicator_matrix: sparse.coo_matrix - data: Array + indicator_matrix: CSRBase | sparse.coo_array + data: ArrayT def count_nonzero(self) -> NDArray[np.integer]: """Count the number of observations in each group. @@ -75,11 +79,22 @@ def count_nonzero(self) -> NDArray[np.integer]: Array of counts. """ - # pattern = self.data._with_data(np.broadcast_to(1, len(self.data.data))) - # return self.indicator_matrix @ pattern - return utils.asarray(self.indicator_matrix @ (self.data != 0)) + return self._sum(data=(self.data != 0).astype("uint8")) + + def _sum(self, data: ArrayT): + if isinstance(data, np.ndarray): + res = self.indicator_matrix @ data + if isinstance(res, CSBase): + return res.toarray() + return res + dtype = np.int64 if np.issubdtype(data.dtype, np.integer) else np.float64 + out = np.zeros((self.indicator_matrix.shape[0], data.shape[1]), dtype=dtype) + (agg_sum_csr if isinstance(data, CSRBase) else agg_sum_csc)( + self.indicator_matrix, data, out + ) + return out - def sum(self) -> Array: + def sum(self) -> np.ndarray: """Compute the sum per feature per group of observations. Returns @@ -87,7 +102,7 @@ def sum(self) -> Array: Array of sum. """ - return utils.asarray(self.indicator_matrix @ self.data) + return self._sum(self.data) def mean(self) -> Array: """Compute the mean per feature per group of observations. 
@@ -97,10 +112,7 @@ def mean(self) -> Array: Array of mean. """ - return ( - utils.asarray(self.indicator_matrix @ self.data) - / np.bincount(self.groupby.codes)[:, None] - ) + return self.sum() / np.bincount(self.groupby.codes)[:, None] def mean_var(self, dof: int = 1) -> tuple[np.ndarray, np.ndarray]: """Compute the count, as well as mean and variance per feature, per group of observations. @@ -124,14 +136,17 @@ def mean_var(self, dof: int = 1) -> tuple[np.ndarray, np.ndarray]: assert dof >= 0 group_counts = np.bincount(self.groupby.codes) - mean_ = self.mean() - # sparse matrices do not support ** for elementwise power. - mean_sq = ( - utils.asarray(self.indicator_matrix @ _power(self.data, 2)) - / group_counts[:, None] - ) - sq_mean = mean_**2 - var_ = mean_sq - sq_mean + if isinstance(self.data, np.ndarray): + mean_ = self.mean() + # sparse matrices do not support ** for elementwise power. + mean_sq = self._sum(_power(self.data, 2)) / group_counts[:, None] + sq_mean = mean_**2 + var_ = mean_sq - sq_mean + else: + mean_, var_ = ( + mean_var_csr if isinstance(self.data, CSRBase) else mean_var_csc + )(self.indicator_matrix, self.data) + sq_mean = mean_**2 # TODO: Why these values exactly? Because they are high relative to the datatype? # (unchanged from original code: https://github.com/scverse/anndata/pull/564) precision = 2 << (42 if self.data.dtype == np.float64 else 20) @@ -550,18 +565,15 @@ def sparse_indicator( categorical: pd.Categorical, *, mask: NDArray[np.bool] | None = None, - weight: NDArray[np.floating] | None = None, -) -> sparse.coo_matrix: - if mask is not None and weight is None: - weight = mask.astype(np.float32) - elif mask is not None and weight is not None: - weight = mask * weight - elif mask is None and weight is None: - weight = np.broadcast_to(1.0, len(categorical)) +) -> sparse.coo_array: + # TODO: why is this float64. This is a scanpy 2.0 problem maybe? 
+ mask = ( + np.broadcast_to(1.0, len(categorical)) if mask is None else mask.astype("uint8") + ) # can’t have -1s in the codes, but (as long as it’s valid), the value is ignored, so set to 0 where masked - codes = categorical.codes if mask is None else np.where(mask, categorical.codes, 0) - a = sparse.coo_matrix( - (weight, (codes, np.arange(len(categorical)))), + codes = np.where(mask, categorical.codes, 0) + a = sparse.coo_array( + (mask, (codes, np.arange(len(categorical)))), shape=(len(categorical.categories), len(categorical)), ) return a diff --git a/src/scanpy/get/_kernels.py b/src/scanpy/get/_kernels.py new file mode 100644 index 0000000000..4d25bd06be --- /dev/null +++ b/src/scanpy/get/_kernels.py @@ -0,0 +1,116 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numba +import numpy as np +from fast_array_utils.numba import njit + +if TYPE_CHECKING: + from numpy.typing import NDArray + + from .._compat import CSCBase, CSRBase + + +@njit +def agg_sum_csr(indicator: CSRBase, data: CSRBase, out: NDArray) -> None: + for cat_num in numba.prange(indicator.shape[0]): + start_cat_idx = indicator.indptr[cat_num] + stop_cat_idx = indicator.indptr[cat_num + 1] + for row_num in range(start_cat_idx, stop_cat_idx): + obs_per_cat = indicator.indices[row_num] + + start_obs = data.indptr[obs_per_cat] + end_obs = data.indptr[obs_per_cat + 1] + + for j in range(start_obs, end_obs): + col = data.indices[j] + out[cat_num, col] += data.data[j] + + +@njit +def agg_sum_csc(indicator: CSRBase, data: CSCBase, out: np.ndarray) -> None: + obs_to_cat = np.full(data.shape[0], -1, dtype=np.int64) + + for cat in range(indicator.shape[0]): + for k in range(indicator.indptr[cat], indicator.indptr[cat + 1]): + obs_to_cat[indicator.indices[k]] = cat + + for col in numba.prange(data.shape[1]): + start = data.indptr[col] + end = data.indptr[col + 1] + + for j in range(start, end): + obs = data.indices[j] + cat = obs_to_cat[obs] + + if cat != -1: + out[cat, 
col] += data.data[j] + + +@njit +def mean_var_csr( + indicator: CSRBase, + data: CSCBase, +) -> tuple[NDArray[np.float64], NDArray[np.float64]]: + mean = np.zeros((indicator.shape[0], data.shape[1]), dtype="float64") + var = np.zeros((indicator.shape[0], data.shape[1]), dtype="float64") + + for cat_num in numba.prange(indicator.shape[0]): + start_cat_idx = indicator.indptr[cat_num] + stop_cat_idx = indicator.indptr[cat_num + 1] + for row_num in range(start_cat_idx, stop_cat_idx): + obs_per_cat = indicator.indices[row_num] + + start_obs = data.indptr[obs_per_cat] + end_obs = data.indptr[obs_per_cat + 1] + + for j in range(start_obs, end_obs): + col = data.indices[j] + value = np.float64(data.data[j]) + value = data.data[j] + mean[cat_num, col] += value + var[cat_num, col] += value * value + + n_obs = stop_cat_idx - start_cat_idx + mean_cat = mean[cat_num, :] / n_obs + mean[cat_num, :] = mean_cat + var[cat_num, :] = (var[cat_num, :] / n_obs) - (mean_cat * mean_cat) + return mean, var + + +@njit +def mean_var_csc( + indicator: CSRBase, data: CSCBase +) -> tuple[NDArray[np.float64], NDArray[np.float64]]: + obs_to_cat = np.full(data.shape[0], -1, dtype=np.int64) + + mean = np.zeros((indicator.shape[0], data.shape[1]), dtype="float64") + var = np.zeros((indicator.shape[0], data.shape[1]), dtype="float64") + + for cat in range(indicator.shape[0]): + for k in range(indicator.indptr[cat], indicator.indptr[cat + 1]): + obs_to_cat[indicator.indices[k]] = cat + + for col in numba.prange(data.shape[1]): + start = data.indptr[col] + end = data.indptr[col + 1] + + for j in range(start, end): + obs = data.indices[j] + cat = obs_to_cat[obs] + + if cat != -1: + value = np.float64(data.data[j]) + value = data.data[j] + mean[cat, col] += value + var[cat, col] += value * value + + for cat_num in numba.prange(indicator.shape[0]): + start_cat_idx = indicator.indptr[cat_num] + stop_cat_idx = indicator.indptr[cat_num + 1] + n_obs = stop_cat_idx - start_cat_idx + mean_cat = mean[cat_num, :] 
/ n_obs + mean[cat_num, :] = mean_cat + var[cat_num, :] = (var[cat_num, :] / n_obs) - (mean_cat * mean_cat) + return mean, var