Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions benchmarks/benchmarks/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,12 @@ def bmmc(n_obs: int = 400) -> AnnData:

@cache
def _lung93k() -> AnnData:
path = pooch.retrieve(
url="https://figshare.com/ndownloader/files/45788454",
known_hash="md5:4f28af5ff226052443e7e0b39f3f9212",
registry = pooch.create(
path=pooch.os_cache("pooch"),
base_url="doi:10.6084/m9.figshare.25664775.v1/",
)
registry.load_registry_from_doi()
path = registry.fetch("adata.raw_compressed.h5ad")
adata = sc.read_h5ad(path)
assert isinstance(adata.X, CSRBase)
adata.layers["counts"] = adata.X.astype(np.int32, copy=True)
Expand Down
28 changes: 27 additions & 1 deletion benchmarks/benchmarks/preprocessing_counts.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
import anndata as ad

import scanpy as sc
from scanpy._utils import get_literal_vals
from scanpy.get._aggregated import AggType

from ._utils import get_count_dataset
from ._utils import get_count_dataset, get_dataset

if TYPE_CHECKING:
from typing import Any
Expand Down Expand Up @@ -146,3 +148,27 @@ def time_log1p(self, *_) -> None:
def peakmem_log1p(self, *_) -> None:
self.adata.uns.pop("log1p", None)
sc.pp.log1p(self.adata)


class Agg:
    """Benchmark :func:`scanpy.get.aggregate` for every supported aggregation.

    Parametrized by aggregation name (``count_nonzero``, ``sum``, ``mean``,
    ``var``); each benchmark runs the aggregation on the counts layer of the
    lung93k dataset, grouped by patient.
    """

    # One benchmark instance per aggregation function name.
    params: tuple[AggType, ...] = tuple(get_literal_vals(AggType))
    param_names = ("agg_name",)

    def setup_cache(self) -> None:
        """Without this caching, asv was running several processes which meant the data was repeatedly downloaded."""
        adata, _ = get_dataset("lung93k")
        adata.write_h5ad("lung93k.h5ad")

    def setup(self, agg_name: AggType) -> None:
        """Load the cached dataset freshly for each parametrized run."""
        self.adata = ad.read_h5ad("lung93k.h5ad")
        self.agg_name = agg_name

    def time_agg(self, *_) -> None:
        """Measure wall-clock time of the aggregation."""
        sc.get.aggregate(
            self.adata, by="PatientNumber", func=self.agg_name, layer="counts"
        )

    def peakmem_agg(self, *_) -> None:
        """Measure peak memory of the aggregation."""
        sc.get.aggregate(
            self.adata, by="PatientNumber", func=self.agg_name, layer="counts"
        )
1 change: 1 addition & 0 deletions docs/release-notes/4062.perf.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add `numba` kernels for mean/var/count-nonzero/sum aggregation of sparse data in {func}`scanpy.get.aggregate` {smaller}`I Gold`
78 changes: 45 additions & 33 deletions src/scanpy/get/_aggregated.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@

import numpy as np
import pandas as pd
from anndata import AnnData, utils
from anndata import AnnData
from fast_array_utils.stats._power import power as fau_power # TODO: upstream
from scipy import sparse
from sklearn.utils.sparsefuncs import csc_median_axis_0

from scanpy._compat import CSBase, CSRBase, DaskArray

from .._utils import _resolve_axis, get_literal_vals
from ._kernels import agg_sum_csc, agg_sum_csr, mean_var_csc, mean_var_csr
from .get import _check_mask

if TYPE_CHECKING:
Expand All @@ -25,7 +26,7 @@
type AggType = ConstantDtypeAgg | Literal["mean", "var"]


class Aggregate:
class Aggregate[ArrayT: np.ndarray | CSBase]:
"""Functionality for generic grouping and aggregating.

There is currently support for count_nonzero, sum, mean, and variance.
Expand Down Expand Up @@ -53,19 +54,22 @@ class Aggregate:
def __init__(
self,
groupby: pd.Categorical,
data: Array,
data: ArrayT,
*,
mask: NDArray[np.bool] | None = None,
) -> None:
self.groupby = groupby
if (missing := groupby.isna()).any():
mask = mask & ~missing if mask is not None else ~missing
self.indicator_matrix = sparse_indicator(groupby, mask=mask)
if isinstance(data, CSBase):
# TODO: Look into if this can be CSR and fast for dense
self.indicator_matrix = self.indicator_matrix.tocsr()
self.data = data

groupby: pd.Categorical
indicator_matrix: sparse.coo_matrix
data: Array
indicator_matrix: CSRBase | sparse.coo_array
data: ArrayT

def count_nonzero(self) -> NDArray[np.integer]:
"""Count the number of observations in each group.
Expand All @@ -75,19 +79,30 @@ def count_nonzero(self) -> NDArray[np.integer]:
Array of counts.

"""
# pattern = self.data._with_data(np.broadcast_to(1, len(self.data.data)))
# return self.indicator_matrix @ pattern
return utils.asarray(self.indicator_matrix @ (self.data != 0))
return self._sum(data=(self.data != 0).astype("uint8"))

def _sum(self, data: ArrayT):
if isinstance(data, np.ndarray):
res = self.indicator_matrix @ data
if isinstance(res, CSBase):
return res.toarray()
return res
dtype = np.int64 if np.issubdtype(data.dtype, np.integer) else np.float64
out = np.zeros((self.indicator_matrix.shape[0], data.shape[1]), dtype=dtype)
(agg_sum_csr if isinstance(data, CSRBase) else agg_sum_csc)(
self.indicator_matrix, data, out
)
return out

def sum(self) -> Array:
def sum(self) -> np.ndarray:
"""Compute the sum per feature per group of observations.

Returns
-------
Array of sum.

"""
return utils.asarray(self.indicator_matrix @ self.data)
return self._sum(self.data)

def mean(self) -> Array:
"""Compute the mean per feature per group of observations.
Expand All @@ -97,10 +112,7 @@ def mean(self) -> Array:
Array of mean.

"""
return (
utils.asarray(self.indicator_matrix @ self.data)
/ np.bincount(self.groupby.codes)[:, None]
)
Comment on lines -100 to -103
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you think it would make sense to avoid re-executing sum by having basically _sum_mean and _sum_mean_var and using that in aggregate_array?

Or is sum so fast that re-executing it is fine?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think for now, this is probably fine. But it's a good point, no doubt! This PR was just focused on the current implementation. The perf difference between having 2 sum calls vs 1 in mean_var wasn't that huge so it probably is "very fast" as you say

return self.sum() / np.bincount(self.groupby.codes)[:, None]

def mean_var(self, dof: int = 1) -> tuple[np.ndarray, np.ndarray]:
"""Compute the count, as well as mean and variance per feature, per group of observations.
Expand All @@ -124,14 +136,17 @@ def mean_var(self, dof: int = 1) -> tuple[np.ndarray, np.ndarray]:
assert dof >= 0

group_counts = np.bincount(self.groupby.codes)
mean_ = self.mean()
# sparse matrices do not support ** for elementwise power.
mean_sq = (
utils.asarray(self.indicator_matrix @ _power(self.data, 2))
/ group_counts[:, None]
)
sq_mean = mean_**2
var_ = mean_sq - sq_mean
if isinstance(self.data, np.ndarray):
mean_ = self.mean()
# sparse matrices do not support ** for elementwise power.
mean_sq = self._sum(_power(self.data, 2)) / group_counts[:, None]
sq_mean = mean_**2
var_ = mean_sq - sq_mean
else:
mean_, var_ = (
mean_var_csr if isinstance(self.data, CSRBase) else mean_var_csc
)(self.indicator_matrix, self.data)
sq_mean = mean_**2
# TODO: Why these values exactly? Because they are high relative to the datatype?
# (unchanged from original code: https://github.com/scverse/anndata/pull/564)
precision = 2 << (42 if self.data.dtype == np.float64 else 20)
Expand Down Expand Up @@ -550,18 +565,15 @@ def sparse_indicator(
categorical: pd.Categorical,
*,
mask: NDArray[np.bool] | None = None,
weight: NDArray[np.floating] | None = None,
) -> sparse.coo_matrix:
if mask is not None and weight is None:
weight = mask.astype(np.float32)
elif mask is not None and weight is not None:
weight = mask * weight
elif mask is None and weight is None:
weight = np.broadcast_to(1.0, len(categorical))
) -> sparse.coo_array:
# TODO: why is this float64. This is a scanpy 2.0 problem maybe?
mask = (
np.broadcast_to(1.0, len(categorical)) if mask is None else mask.astype("uint8")
)
# can’t have -1s in the codes, but (as long as it’s valid), the value is ignored, so set to 0 where masked
codes = categorical.codes if mask is None else np.where(mask, categorical.codes, 0)
a = sparse.coo_matrix(
(weight, (codes, np.arange(len(categorical)))),
codes = np.where(mask, categorical.codes, 0)
a = sparse.coo_array(
(mask, (codes, np.arange(len(categorical)))),
shape=(len(categorical.categories), len(categorical)),
)
return a
116 changes: 116 additions & 0 deletions src/scanpy/get/_kernels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import numba
import numpy as np
from fast_array_utils.numba import njit

if TYPE_CHECKING:
from numpy.typing import NDArray

from .._compat import CSCBase, CSRBase


@njit
def agg_sum_csr(indicator: CSRBase, data: CSRBase, out: NDArray) -> None:
    """Accumulate per-group sums of CSR ``data`` into ``out`` in place.

    ``indicator`` is a CSR group-membership matrix (groups x observations):
    the stored column indices of row ``g`` are the observations in group ``g``.
    Each group is handled by one parallel worker, and workers write disjoint
    rows of ``out``, so there are no write conflicts.
    """
    for group in numba.prange(indicator.shape[0]):
        members_lo = indicator.indptr[group]
        members_hi = indicator.indptr[group + 1]
        for member_idx in range(members_lo, members_hi):
            obs = indicator.indices[member_idx]
            # Walk the non-zeros of this observation's row and add them
            # into the group's output row.
            for nz in range(data.indptr[obs], data.indptr[obs + 1]):
                out[group, data.indices[nz]] += data.data[nz]


@njit
def agg_sum_csc(indicator: CSRBase, data: CSCBase, out: np.ndarray) -> None:
    """Accumulate per-group sums of CSC ``data`` into ``out`` in place.

    ``indicator`` is a CSR group-membership matrix (groups x observations);
    observations not stored in any indicator row map to no group and are
    skipped.
    """
    # Invert the indicator: map each observation to its group (-1 = none).
    obs_group = np.full(data.shape[0], -1, dtype=np.int64)
    for group in range(indicator.shape[0]):
        for ptr in range(indicator.indptr[group], indicator.indptr[group + 1]):
            obs_group[indicator.indices[ptr]] = group

    # Columns are independent, so parallelize over them: every worker owns
    # exactly one column of `out`, so there are no write conflicts.
    for col in numba.prange(data.shape[1]):
        for nz in range(data.indptr[col], data.indptr[col + 1]):
            group = obs_group[data.indices[nz]]
            if group != -1:
                out[group, col] += data.data[nz]


@njit
def mean_var_csr(
    indicator: CSRBase,
    data: CSRBase,
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
    """Compute per-group mean and (biased) variance of CSR ``data``.

    Parameters
    ----------
    indicator
        CSR group-membership matrix of shape ``(n_groups, n_obs)``; the
        stored column indices of row ``g`` are the observations in group ``g``.
    data
        CSR observation matrix of shape ``(n_obs, n_features)``.

    Returns
    -------
    ``(mean, var)`` float64 arrays of shape ``(n_groups, n_features)``, where
    ``var`` is E[x**2] - E[x]**2 (no degrees-of-freedom correction here; the
    caller applies it).
    """
    mean = np.zeros((indicator.shape[0], data.shape[1]), dtype="float64")
    var = np.zeros((indicator.shape[0], data.shape[1]), dtype="float64")

    for cat_num in numba.prange(indicator.shape[0]):
        start_cat_idx = indicator.indptr[cat_num]
        stop_cat_idx = indicator.indptr[cat_num + 1]
        for row_num in range(start_cat_idx, stop_cat_idx):
            obs_per_cat = indicator.indices[row_num]

            start_obs = data.indptr[obs_per_cat]
            end_obs = data.indptr[obs_per_cat + 1]

            for j in range(start_obs, end_obs):
                col = data.indices[j]
                # Cast once so that `value * value` is computed in float64
                # even when data.data is float32 — previously the cast was
                # immediately overwritten, squaring in float32 and losing
                # precision before accumulation.
                value = np.float64(data.data[j])
                mean[cat_num, col] += value
                var[cat_num, col] += value * value

        n_obs = stop_cat_idx - start_cat_idx
        mean_cat = mean[cat_num, :] / n_obs
        mean[cat_num, :] = mean_cat
        var[cat_num, :] = (var[cat_num, :] / n_obs) - (mean_cat * mean_cat)
    return mean, var


@njit
def mean_var_csc(
    indicator: CSRBase, data: CSCBase
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
    """Compute per-group mean and (biased) variance of CSC ``data``.

    Parameters
    ----------
    indicator
        CSR group-membership matrix of shape ``(n_groups, n_obs)``;
        observations not stored in any indicator row are skipped.
    data
        CSC observation matrix of shape ``(n_obs, n_features)``.

    Returns
    -------
    ``(mean, var)`` float64 arrays of shape ``(n_groups, n_features)``, where
    ``var`` is E[x**2] - E[x]**2 (no degrees-of-freedom correction here; the
    caller applies it).
    """
    # Invert the indicator: map each observation to its group (-1 = none).
    obs_to_cat = np.full(data.shape[0], -1, dtype=np.int64)

    mean = np.zeros((indicator.shape[0], data.shape[1]), dtype="float64")
    var = np.zeros((indicator.shape[0], data.shape[1]), dtype="float64")

    for cat in range(indicator.shape[0]):
        for k in range(indicator.indptr[cat], indicator.indptr[cat + 1]):
            obs_to_cat[indicator.indices[k]] = cat

    for col in numba.prange(data.shape[1]):
        start = data.indptr[col]
        end = data.indptr[col + 1]

        for j in range(start, end):
            obs = data.indices[j]
            cat = obs_to_cat[obs]

            if cat != -1:
                # Cast once so that `value * value` is computed in float64
                # even when data.data is float32 — previously the cast was
                # immediately overwritten, squaring in float32 and losing
                # precision before accumulation.
                value = np.float64(data.data[j])
                mean[cat, col] += value
                var[cat, col] += value * value

    for cat_num in numba.prange(indicator.shape[0]):
        start_cat_idx = indicator.indptr[cat_num]
        stop_cat_idx = indicator.indptr[cat_num + 1]
        n_obs = stop_cat_idx - start_cat_idx
        mean_cat = mean[cat_num, :] / n_obs
        mean[cat_num, :] = mean_cat
        var[cat_num, :] = (var[cat_num, :] / n_obs) - (mean_cat * mean_cat)
    return mean, var
Loading