From c58f36b372b954119846c553dd76515c74f52ca9 Mon Sep 17 00:00:00 2001 From: Ilay Kavitzky Date: Fri, 17 Apr 2026 16:46:52 +0300 Subject: [PATCH 1/5] feat: Add benchmark for combat.py --- benchmarks/benchmarks/preprocessing_log.py | 8 ++++++ benchmarks/benchmarks/tools.py | 29 +++++++++++++++++++++- 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/benchmarks/benchmarks/preprocessing_log.py b/benchmarks/benchmarks/preprocessing_log.py index 9633c8e208..a3c4b1d42a 100644 --- a/benchmarks/benchmarks/preprocessing_log.py +++ b/benchmarks/benchmarks/preprocessing_log.py @@ -40,6 +40,8 @@ def setup_cache(self) -> None: def setup(self, dataset, layer) -> None: self.adata = ad.read_h5ad(f"{dataset}_{layer}.h5ad") + if "X_pca" not in self.adata.obsm: + sc.pp.pca(self.adata) def time_pca(self, *_) -> None: sc.pp.pca(self.adata, svd_solver="arpack") @@ -67,6 +69,12 @@ def time_regress_out(self, *_) -> None: def peakmem_regress_out(self, *_) -> None: sc.pp.regress_out(self.adata, ["total_counts", "pct_counts_mt"]) + def time_neighbors(self, *_) -> None: + sc.pp.neighbors(self.adata, n_neighbors=10, n_pcs=40) + + def peakmem_neighbors(self, *_) -> None: + sc.pp.neighbors(self.adata, n_neighbors=10, n_pcs=40) + def time_scale(self, *_) -> None: sc.pp.scale(self.adata, max_value=10) diff --git a/benchmarks/benchmarks/tools.py b/benchmarks/benchmarks/tools.py index 75bdeb2086..a92d5eaf8c 100644 --- a/benchmarks/benchmarks/tools.py +++ b/benchmarks/benchmarks/tools.py @@ -9,7 +9,7 @@ import scanpy as sc -from ._utils import pbmc68k_reduced +from ._utils import pbmc3k, pbmc68k_reduced class ToolsSuite: # noqa: D101 @@ -44,3 +44,30 @@ def time_rank_genes_groups(self) -> None: def peakmem_rank_genes_groups(self) -> None: sc.tl.rank_genes_groups(self.adata, "bulk_labels", method="wilcoxon") + + +class CombatSuite: # noqa: D101 + """Benchmark combat batch correction.""" + + def setup_cache(self) -> None: + import numpy as np + + adata = pbmc3k() + 
sc.pp.highly_variable_genes(adata, n_top_genes=500) + adata = adata[:, adata.var["highly_variable"]].copy() + sc.pp.scale(adata, max_value=10) + # assign cells to 3 batches deterministically + np.random.seed(0) + adata.obs["batch"] = np.random.choice( + ["A", "B", "C"], size=adata.n_obs + ) + adata.write_h5ad("adata_combat.h5ad") + + def setup(self) -> None: + self.adata = ad.read_h5ad("adata_combat.h5ad") + + def time_combat(self) -> None: + sc.pp.combat(self.adata, key="batch") + + def peakmem_combat(self) -> None: + sc.pp.combat(self.adata, key="batch") From 3c3c4b174e64d23447f3305d6d02983f0e5b7236 Mon Sep 17 00:00:00 2001 From: Ilay Kavitzky Date: Fri, 17 Apr 2026 16:47:05 +0300 Subject: [PATCH 2/5] fix: Use numpy broadcasting & DataFrame's .values --- src/scanpy/preprocessing/_combat.py | 42 +++++++++++------------------ 1 file changed, 16 insertions(+), 26 deletions(-) diff --git a/src/scanpy/preprocessing/_combat.py b/src/scanpy/preprocessing/_combat.py index 7aa38ccc86..3e0d1170b2 100644 --- a/src/scanpy/preprocessing/_combat.py +++ b/src/scanpy/preprocessing/_combat.py @@ -106,27 +106,26 @@ def _standardize_data( design = _design_matrix(model, batch_key, batch_levels) # compute pooled variance estimator - b_hat = np.dot(np.dot(la.inv(np.dot(design.T, design)), design.T), data.T) + design_arr = design.values + b_hat = np.dot(np.dot(la.inv(np.dot(design_arr.T, design_arr)), design_arr.T), data.values.T) grand_mean = np.dot((n_batches / n_array).T, b_hat[:n_batch, :]) - var_pooled = (data - np.dot(design, b_hat).T) ** 2 - var_pooled = np.dot(var_pooled, np.ones((int(n_array), 1)) / int(n_array)) + var_pooled = (data.values - np.dot(design_arr, b_hat).T) ** 2 + var_pooled = np.mean(var_pooled, axis=1, keepdims=True) # Compute the means if np.sum(var_pooled == 0) > 0: print(f"Found {np.sum(var_pooled == 0)} genes with zero variance.") - stand_mean = np.dot( - grand_mean.T.reshape((len(grand_mean), 1)), np.ones((1, int(n_array))) - ) - tmp = 
np.array(design.copy()) + stand_mean = grand_mean[:, np.newaxis] + tmp = design_arr.copy() tmp[:, :n_batch] = 0 - stand_mean += np.dot(tmp, b_hat).T + stand_mean = stand_mean + np.dot(tmp, b_hat).T # need to be a bit careful with the zero variance genes # just set the zero variance genes to zero in the standardized data s_data = np.where( var_pooled == 0, 0, - ((data - stand_mean) / np.dot(np.sqrt(var_pooled), np.ones((1, int(n_array))))), + (data.values - stand_mean) / np.sqrt(var_pooled), ) s_data = pd.DataFrame(s_data, index=data.index, columns=data.columns) @@ -271,27 +270,23 @@ def combat( # noqa: PLR0915 # we now apply the parametric adjustment to the standardized data from above # loop over all batches in the data + bayesdata_arr = bayesdata.values + batch_design_arr = batch_design.values for j, batch_idxs in enumerate(batch_info.values()): # we basically subtract the additive batch effect, rescale by the ratio # of multiplicative batch effect to pooled variance and add the overall gene # wise mean dsq = np.sqrt(delta_star[j, :]) - dsq = dsq.reshape((len(dsq), 1)) - denom = np.dot(dsq, np.ones((1, n_batches[j]))) - numer = np.array( - bayesdata.iloc[:, batch_idxs] - - np.dot(batch_design.iloc[batch_idxs], gamma_star).T - ) - bayesdata.iloc[:, batch_idxs] = numer / denom + numer = bayesdata_arr[:, batch_idxs] - np.dot(batch_design_arr[batch_idxs], gamma_star).T + bayesdata_arr[:, batch_idxs] = numer / dsq[:, np.newaxis] - vpsq = np.sqrt(var_pooled).reshape((len(var_pooled), 1)) - bayesdata = bayesdata * np.dot(vpsq, np.ones((1, int(n_array)))) + stand_mean + bayesdata_arr = bayesdata_arr * np.sqrt(var_pooled) + stand_mean # put back into the adata object or return if inplace: - adata.X = bayesdata.values.transpose() + adata.X = bayesdata_arr.T else: - return bayesdata.values.transpose() + return bayesdata_arr.T def _it_sol( @@ -347,12 +342,7 @@ def _it_sol( # in the loop, gamma and delta are updated together. they depend on each other. 
we iterate until convergence. while change > conv: g_new = (t2 * n * g_hat + d_old * g_bar) / (t2 * n + d_old) - sum2 = s_data - g_new.reshape((g_new.shape[0], 1)) @ np.ones(( - 1, - s_data.shape[1], - )) - sum2 = sum2**2 - sum2 = sum2.sum(axis=1) + sum2 = ((s_data - g_new[:, np.newaxis]) ** 2).sum(axis=1) d_new = (0.5 * sum2 + b) / (n / 2.0 + a - 1.0) change = max( From b392e1b6c3bd92ec08e5528aac3f08d8f0b36028 Mon Sep 17 00:00:00 2001 From: Ilay Kavitzky Date: Fri, 17 Apr 2026 16:59:36 +0300 Subject: [PATCH 3/5] fix: more comments --- src/scanpy/preprocessing/_combat.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/scanpy/preprocessing/_combat.py b/src/scanpy/preprocessing/_combat.py index 3e0d1170b2..c0d25e3ff2 100644 --- a/src/scanpy/preprocessing/_combat.py +++ b/src/scanpy/preprocessing/_combat.py @@ -105,8 +105,9 @@ def _standardize_data( design = _design_matrix(model, batch_key, batch_levels) - # compute pooled variance estimator + # use numpy .values extraction only once to avoid pandas overhead design_arr = design.values + # compute pooled variance estimator b_hat = np.dot(np.dot(la.inv(np.dot(design_arr.T, design_arr)), design_arr.T), data.values.T) grand_mean = np.dot((n_batches / n_array).T, b_hat[:n_batch, :]) var_pooled = (data.values - np.dot(design_arr, b_hat).T) ** 2 var_pooled = np.mean(var_pooled, axis=1, keepdims=True) From f6436d592e867b7a4a072b338f99448637813498 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 17 Apr 2026 14:17:08 +0000 Subject: [PATCH 4/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- benchmarks/benchmarks/tools.py | 6 ++---- src/scanpy/preprocessing/_combat.py | 9 +++++++-- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/benchmarks/benchmarks/tools.py b/benchmarks/benchmarks/tools.py index a92d5eaf8c..cbd32cb5cf 100644 --- a/benchmarks/benchmarks/tools.py +++ b/benchmarks/benchmarks/tools.py @@ -46,7 +46,7 @@ def 
peakmem_rank_genes_groups(self) -> None: sc.tl.rank_genes_groups(self.adata, "bulk_labels", method="wilcoxon") -class CombatSuite: # noqa: D101 +class CombatSuite: """Benchmark combat batch correction.""" def setup_cache(self) -> None: @@ -58,9 +58,7 @@ def setup_cache(self) -> None: sc.pp.scale(adata, max_value=10) # assign cells to 3 batches deterministically np.random.seed(0) - adata.obs["batch"] = np.random.choice( - ["A", "B", "C"], size=adata.n_obs - ) + adata.obs["batch"] = np.random.choice(["A", "B", "C"], size=adata.n_obs) adata.write_h5ad("adata_combat.h5ad") def setup(self) -> None: diff --git a/src/scanpy/preprocessing/_combat.py b/src/scanpy/preprocessing/_combat.py index c0d25e3ff2..2fe4bf5f6b 100644 --- a/src/scanpy/preprocessing/_combat.py +++ b/src/scanpy/preprocessing/_combat.py @@ -108,7 +108,9 @@ def _standardize_data( # use numpy .values extraction only once to avoid pandas overhead design_arr = design.values # compute pooled variance estimator - b_hat = np.dot(np.dot(la.inv(np.dot(design_arr.T, design_arr)), design_arr.T), data.values.T) + b_hat = np.dot( + np.dot(la.inv(np.dot(design_arr.T, design_arr)), design_arr.T), data.values.T + ) grand_mean = np.dot((n_batches / n_array).T, b_hat[:n_batch, :]) var_pooled = (data.values - np.dot(design_arr, b_hat).T) ** 2 var_pooled = np.mean(var_pooled, axis=1, keepdims=True) @@ -278,7 +280,10 @@ def combat( # noqa: PLR0915 # of multiplicative batch effect to pooled variance and add the overall gene # wise mean dsq = np.sqrt(delta_star[j, :]) - numer = bayesdata_arr[:, batch_idxs] - np.dot(batch_design_arr[batch_idxs], gamma_star).T + numer = ( + bayesdata_arr[:, batch_idxs] + - np.dot(batch_design_arr[batch_idxs], gamma_star).T + ) bayesdata_arr[:, batch_idxs] = numer / dsq[:, np.newaxis] bayesdata_arr = bayesdata_arr * np.sqrt(var_pooled) + stand_mean From c0e36aa86bddfdb048f29d907eee45e8aa3547ee Mon Sep 17 00:00:00 2001 From: Ilay Kavitzky Date: Fri, 17 Apr 2026 17:20:30 +0300 Subject: [PATCH 5/5] 
fix: Removed unnecessary methods, ruff lint --- benchmarks/benchmarks/preprocessing_log.py | 8 -------- src/scanpy/preprocessing/_combat.py | 3 +-- 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/benchmarks/benchmarks/preprocessing_log.py b/benchmarks/benchmarks/preprocessing_log.py index a3c4b1d42a..9633c8e208 100644 --- a/benchmarks/benchmarks/preprocessing_log.py +++ b/benchmarks/benchmarks/preprocessing_log.py @@ -40,8 +40,6 @@ def setup_cache(self) -> None: def setup(self, dataset, layer) -> None: self.adata = ad.read_h5ad(f"{dataset}_{layer}.h5ad") - if "X_pca" not in self.adata.obsm: - sc.pp.pca(self.adata) def time_pca(self, *_) -> None: sc.pp.pca(self.adata, svd_solver="arpack") @@ -69,12 +67,6 @@ def time_regress_out(self, *_) -> None: def peakmem_regress_out(self, *_) -> None: sc.pp.regress_out(self.adata, ["total_counts", "pct_counts_mt"]) - def time_neighbors(self, *_) -> None: - sc.pp.neighbors(self.adata, n_neighbors=10, n_pcs=40) - - def peakmem_neighbors(self, *_) -> None: - sc.pp.neighbors(self.adata, n_neighbors=10, n_pcs=40) - def time_scale(self, *_) -> None: sc.pp.scale(self.adata, max_value=10) diff --git a/src/scanpy/preprocessing/_combat.py b/src/scanpy/preprocessing/_combat.py index 2fe4bf5f6b..eb501c7efd 100644 --- a/src/scanpy/preprocessing/_combat.py +++ b/src/scanpy/preprocessing/_combat.py @@ -220,7 +220,6 @@ def combat( # noqa: PLR0915 "within-batch variance. Filter these batches before running combat." 
) raise ValueError(msg) - n_array = float(sum(n_batches)) # standardize across genes using a pooled variance estimator logg.info("Standardizing Data across genes.\n") @@ -273,7 +272,7 @@ def combat( # noqa: PLR0915 # we now apply the parametric adjustment to the standardized data from above # loop over all batches in the data - bayesdata_arr = bayesdata.values + bayesdata_arr = bayesdata.to_numpy(copy=True) batch_design_arr = batch_design.values for j, batch_idxs in enumerate(batch_info.values()): # we basically subtract the additive batch effect, rescale by the ratio