perf: numba based aggregations for sparse data
#4062
Merged
24 commits
ada140a feat: speed up numba sums (ilan-gold)
7a8aaff fix: no use keeping booleans around (ilan-gold)
ce44fbe chore: remove dead code (ilan-gold)
be47e37 fix: dtype issue (ilan-gold)
0d125c5 fix: preallocate out (ilan-gold)
7867c53 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
1feca14 fix: dtypes (ilan-gold)
86f40b0 pre-commit (ilan-gold)
53713d5 fix: use registry (ilan-gold)
337d798 fix: timeout (ilan-gold)
c1ff908 chore: new sparse kernels (ilan-gold)
5732423 fix: agg count_nonzero (ilan-gold)
116fd07 right, count nonzero (ilan-gold)
2dfd20e fix: use weights to calculate modularity (#4045) (flying-sheep)
5cf58ad docs: generate 1.12.1 release notes (#4050) (meeseeksmachine)
407f838 chore: add benchmark (ilan-gold)
b8ddbbb fix: remove timeout (ilan-gold)
9812299 fix: counts (ilan-gold)
4dc8771 chore: relnote (ilan-gold)
4713015 chore: retain old dense behavior (ilan-gold)
43c679c Update _aggregated.py (ilan-gold)
128c5e0 Merge branch 'main' into ig/numba_agg_main (ilan-gold)
df6c0ba style (flying-sheep)
83c0668 style (flying-sheep)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| Add `numba` kernels for mean/var/count-nonzero/sum aggregation of sparse data in {func}`scanpy.get.aggregate` {smaller}`I Gold` |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,116 @@ | ||
| from __future__ import annotations | ||
| | ||
| from typing import TYPE_CHECKING | ||
| | ||
| import numba | ||
| import numpy as np | ||
| from fast_array_utils.numba import njit | ||
| | ||
| if TYPE_CHECKING: | ||
|     from numpy.typing import NDArray | ||
| | ||
|     from .._compat import CSCBase, CSRBase | ||
| | ||
| | ||
| @njit | ||
| def agg_sum_csr(indicator: CSRBase, data: CSRBase, out: NDArray) -> None: | ||
|     for cat_num in numba.prange(indicator.shape[0]): | ||
|         start_cat_idx = indicator.indptr[cat_num] | ||
|         stop_cat_idx = indicator.indptr[cat_num + 1] | ||
|         for row_num in range(start_cat_idx, stop_cat_idx): | ||
|             obs_per_cat = indicator.indices[row_num] | ||
| | ||
|             start_obs = data.indptr[obs_per_cat] | ||
|             end_obs = data.indptr[obs_per_cat + 1] | ||
| | ||
|             for j in range(start_obs, end_obs): | ||
|                 col = data.indices[j] | ||
|                 out[cat_num, col] += data.data[j] | ||
| | ||
| | ||
| @njit | ||
| def agg_sum_csc(indicator: CSRBase, data: CSCBase, out: np.ndarray) -> None: | ||
|     obs_to_cat = np.full(data.shape[0], -1, dtype=np.int64) | ||
| | ||
|     for cat in range(indicator.shape[0]): | ||
|         for k in range(indicator.indptr[cat], indicator.indptr[cat + 1]): | ||
|             obs_to_cat[indicator.indices[k]] = cat | ||
| | ||
|     for col in numba.prange(data.shape[1]): | ||
|         start = data.indptr[col] | ||
|         end = data.indptr[col + 1] | ||
| | ||
|         for j in range(start, end): | ||
|             obs = data.indices[j] | ||
|             cat = obs_to_cat[obs] | ||
| | ||
|             if cat != -1: | ||
|                 out[cat, col] += data.data[j] | ||
| | ||
| | ||
| @njit | ||
| def mean_var_csr( | ||
|     indicator: CSRBase, | ||
|     data: CSRBase, | ||
| ) -> tuple[NDArray[np.float64], NDArray[np.float64]]: | ||
|     mean = np.zeros((indicator.shape[0], data.shape[1]), dtype="float64") | ||
|     var = np.zeros((indicator.shape[0], data.shape[1]), dtype="float64") | ||
| | ||
|     for cat_num in numba.prange(indicator.shape[0]): | ||
|         start_cat_idx = indicator.indptr[cat_num] | ||
|         stop_cat_idx = indicator.indptr[cat_num + 1] | ||
|         for row_num in range(start_cat_idx, stop_cat_idx): | ||
|             obs_per_cat = indicator.indices[row_num] | ||
| | ||
|             start_obs = data.indptr[obs_per_cat] | ||
|             end_obs = data.indptr[obs_per_cat + 1] | ||
| | ||
|             for j in range(start_obs, end_obs): | ||
|                 col = data.indices[j] | ||
|                 value = np.float64(data.data[j]) | ||
|                 value = data.data[j] | ||
|                 mean[cat_num, col] += value | ||
|                 var[cat_num, col] += value * value | ||
| | ||
|         n_obs = stop_cat_idx - start_cat_idx | ||
|         mean_cat = mean[cat_num, :] / n_obs | ||
|         mean[cat_num, :] = mean_cat | ||
|         var[cat_num, :] = (var[cat_num, :] / n_obs) - (mean_cat * mean_cat) | ||
|     return mean, var | ||
| | ||
| | ||
| @njit | ||
| def mean_var_csc( | ||
|     indicator: CSRBase, data: CSCBase | ||
| ) -> tuple[NDArray[np.float64], NDArray[np.float64]]: | ||
|     obs_to_cat = np.full(data.shape[0], -1, dtype=np.int64) | ||
| | ||
|     mean = np.zeros((indicator.shape[0], data.shape[1]), dtype="float64") | ||
|     var = np.zeros((indicator.shape[0], data.shape[1]), dtype="float64") | ||
| | ||
|     for cat in range(indicator.shape[0]): | ||
|         for k in range(indicator.indptr[cat], indicator.indptr[cat + 1]): | ||
|             obs_to_cat[indicator.indices[k]] = cat | ||
| | ||
|     for col in numba.prange(data.shape[1]): | ||
|         start = data.indptr[col] | ||
|         end = data.indptr[col + 1] | ||
| | ||
|         for j in range(start, end): | ||
|             obs = data.indices[j] | ||
|             cat = obs_to_cat[obs] | ||
| | ||
|             if cat != -1: | ||
|                 value = np.float64(data.data[j]) | ||
|                 value = data.data[j] | ||
|                 mean[cat, col] += value | ||
|                 var[cat, col] += value * value | ||
| | ||
|     for cat_num in numba.prange(indicator.shape[0]): | ||
|         start_cat_idx = indicator.indptr[cat_num] | ||
|         stop_cat_idx = indicator.indptr[cat_num + 1] | ||
|         n_obs = stop_cat_idx - start_cat_idx | ||
|         mean_cat = mean[cat_num, :] / n_obs | ||
|         mean[cat_num, :] = mean_cat | ||
|         var[cat_num, :] = (var[cat_num, :] / n_obs) - (mean_cat * mean_cat) | ||
|     return mean, var |
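
To make the access pattern of the CSR sum kernel above easier to follow, here is a minimal pure-NumPy sketch of the same loop structure. It is not part of the PR: the function name `agg_sum_csr_reference`, its argument names, and the toy arrays are all hypothetical, and it takes raw `indptr`/`indices`/`data` arrays rather than the sparse matrix objects the kernel receives. It also runs serially, whereas the kernel uses `numba.prange` over categories.

```python
import numpy as np


def agg_sum_csr_reference(ind_indptr, ind_indices, indptr, indices, data, out):
    """Serial, pure-Python mirror of the CSR sum loop (hypothetical helper)."""
    n_cats = len(ind_indptr) - 1
    for cat in range(n_cats):
        # Row `cat` of the indicator matrix lists the observations in this category.
        for k in range(ind_indptr[cat], ind_indptr[cat + 1]):
            obs = ind_indices[k]
            # Walk the stored (nonzero) entries of that observation's row.
            for j in range(indptr[obs], indptr[obs + 1]):
                out[cat, indices[j]] += data[j]


# Toy data: 4 observations x 3 variables, with its CSR arrays written out by hand.
dense = np.array(
    [
        [1.0, 0.0, 2.0],
        [0.0, 3.0, 0.0],
        [4.0, 0.0, 0.0],
        [0.0, 5.0, 6.0],
    ]
)
indptr = np.array([0, 2, 3, 4, 6])
indices = np.array([0, 2, 1, 0, 1, 2])
data = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])

# Indicator (2 categories x 4 observations) in CSR form:
# category 0 holds observations {0, 2}, category 1 holds observations {1, 3}.
ind_indptr = np.array([0, 2, 4])
ind_indices = np.array([0, 2, 1, 3])

# The caller preallocates the output, as in the kernel signature above.
out = np.zeros((2, 3))
agg_sum_csr_reference(ind_indptr, ind_indices, indptr, indices, data, out)
print(out)
```

Because each category writes only to its own row of `out`, iterations of the outer loop never touch the same memory, which is what makes a parallel loop over categories safe in this layout.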
Do you think it would make sense to avoid re-executing sum by having basically `_sum_mean` and `_sum_mean_var` and using that in `aggregate_array`? Or is `sum` so fast that re-executing it is fine?
I think for now, this is probably fine. But it's a good point, no doubt! This PR was just focused on the current implementation. The perf difference between having 2 sum calls vs 1 in `mean_var` wasn't that huge, so it probably is "very fast" as you say.
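
For context on the exchange above, the following is a rough sketch of how a per-group mean and variance can be derived from a running sum and a running sum of squares, which is the normalization step the `mean_var_*` kernels in the diff perform after accumulating. The helper `mean_var_from_sums` and the toy groups are hypothetical and not taken from the PR. The sketch applies no degrees-of-freedom correction, so it yields the population variance.

```python
import numpy as np


def mean_var_from_sums(sum_x, sum_sq, n_obs):
    """Per-group mean and population variance from accumulated sums (hypothetical helper)."""
    # Shape (n_groups, 1) so the counts broadcast across variables.
    n = np.asarray(n_obs, dtype=np.float64)[:, None]
    mean = sum_x / n
    # Var[x] = E[x^2] - E[x]^2, matching the final step of the kernels.
    var = sum_sq / n - mean * mean
    return mean, var


# Two toy groups of observations over two variables.
groups = [
    np.array([[1.0, 2.0], [3.0, 4.0]]),
    np.array([[0.0, 0.0], [2.0, 2.0], [4.0, 1.0]]),
]
sum_x = np.stack([g.sum(axis=0) for g in groups])
sum_sq = np.stack([(g * g).sum(axis=0) for g in groups])
n_obs = [len(g) for g in groups]

mean, var = mean_var_from_sums(sum_x, sum_sq, n_obs)
print(mean)
print(var)
```

Under this formulation a sum that has already been computed could in principle be reused for the mean, which is the saving the reviewer asks about; whether that is worth the extra plumbing is the judgment call discussed in the thread.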