Vectorize filter_on_target_knockdown#71
Vectorize filter_on_target_knockdown#71LeonHafner wants to merge 3 commits intoArcInstitute:mainfrom
Conversation
There was a problem hiding this comment.
Code Review
This pull request refactors the filter_on_target_knockdown function to utilize vectorized operations with NumPy and Pandas, significantly improving performance over the previous loop-based approach. Additionally, a --verbose flag has been added to the CLI to provide detailed counts during the filtering stages. A review comment correctly identifies a potential memory bottleneck where the densification of the expression submatrix for all matched perturbations could lead to out-of-memory errors on large datasets, suggesting a more memory-efficient approach using sparse indexing.
| if sp.issparse(X): | ||
| X_sub = X[:, pert_positions].toarray() # (n_cells, n_matched) | ||
| else: | ||
| X_sub = np.asarray(X)[:, pert_positions] |
There was a problem hiding this comment.
The current implementation densifies the entire submatrix X_sub for all matched perturbations. For large datasets with many perturbations (e.g., whole-genome screens), this can lead to excessive memory consumption or Out-Of-Memory (OOM) errors. Since X is typically sparse in single-cell data, it is more efficient to avoid full densification and instead use sparse indexing or only densify the specific elements needed for each stage.
Consider keeping X_sub sparse if X is sparse, and then using np.asarray(...).flatten() when extracting specific values (like diag_expr or expr_vals) or computing means to maintain the speedup while significantly reducing the memory footprint.
Summary
filter_on_target_knockdownwith the vectorized reimplementation (measured 40x speedup on large datasets)_mean,is_on_target_knockdown,set_var_index_to_col--verboseflag to the CLI to expose per-stage cell/perturbation count outputHow it works
The new implementation avoids per-perturbation loops by:
X_sub(n_cells × n_matched_perts) for all target genes at onceX_sub[control_mask].mean(axis=0)np.bincount+ fancy indexing for the perturbation-level and cell-level filter stagesadata.copy()to the end, covering only the kept subsetOutputs were verified to be identical to the old implementation on a 385k-cell dataset.