diff --git a/benchmarks/benchmarks/tools.py b/benchmarks/benchmarks/tools.py
index 75bdeb2086..cbd32cb5cf 100644
--- a/benchmarks/benchmarks/tools.py
+++ b/benchmarks/benchmarks/tools.py
@@ -9,7 +9,7 @@
 
 import scanpy as sc
 
-from ._utils import pbmc68k_reduced
+from ._utils import pbmc3k, pbmc68k_reduced
 
 
 class ToolsSuite:  # noqa: D101
@@ -44,3 +44,28 @@     def time_rank_genes_groups(self) -> None:
 
     def peakmem_rank_genes_groups(self) -> None:
         sc.tl.rank_genes_groups(self.adata, "bulk_labels", method="wilcoxon")
+
+
+class CombatSuite:
+    """Benchmark combat batch correction."""
+
+    def setup_cache(self) -> None:
+        import numpy as np
+
+        adata = pbmc3k()
+        sc.pp.highly_variable_genes(adata, n_top_genes=500)
+        adata = adata[:, adata.var["highly_variable"]].copy()
+        sc.pp.scale(adata, max_value=10)
+        # assign cells to 3 batches deterministically
+        np.random.seed(0)
+        adata.obs["batch"] = np.random.choice(["A", "B", "C"], size=adata.n_obs)
+        adata.write_h5ad("adata_combat.h5ad")
+
+    def setup(self) -> None:
+        self.adata = ad.read_h5ad("adata_combat.h5ad")
+
+    def time_combat(self) -> None:
+        sc.pp.combat(self.adata, key="batch")
+
+    def peakmem_combat(self) -> None:
+        sc.pp.combat(self.adata, key="batch")
diff --git a/src/scanpy/preprocessing/_combat.py b/src/scanpy/preprocessing/_combat.py
index 4600074cfd..97c8560fd6 100644
--- a/src/scanpy/preprocessing/_combat.py
+++ b/src/scanpy/preprocessing/_combat.py
@@ -105,28 +105,30 @@ def _standardize_data(
 
     design = _design_matrix(model, batch_key, batch_levels)
 
+    # use numpy .values extraction only once to avoid pandas overhead
+    design_arr = design.values
     # compute pooled variance estimator
-    b_hat = np.dot(np.dot(la.inv(np.dot(design.T, design)), design.T), data.T)
+    b_hat = np.dot(
+        np.dot(la.inv(np.dot(design_arr.T, design_arr)), design_arr.T), data.values.T
+    )
     grand_mean = np.dot((n_batches / n_array).T, b_hat[:n_batch, :])
-    var_pooled = (data - np.dot(design, b_hat).T) ** 2
-    var_pooled = np.dot(var_pooled, np.ones((int(n_array), 1)) / int(n_array))
+    var_pooled = (data.values - np.dot(design_arr, b_hat).T) ** 2
+    var_pooled = np.mean(var_pooled, axis=1, keepdims=True)
 
     # Compute the means
     if np.sum(var_pooled == 0) > 0:
         print(f"Found {np.sum(var_pooled == 0)} genes with zero variance.")
-    stand_mean = np.dot(
-        grand_mean.T.reshape((len(grand_mean), 1)), np.ones((1, int(n_array)))
-    )
-    tmp = np.array(design.copy())
+    stand_mean = grand_mean[:, np.newaxis]
+    tmp = design_arr.copy()
     tmp[:, :n_batch] = 0
-    stand_mean += np.dot(tmp, b_hat).T
+    stand_mean = stand_mean + np.dot(tmp, b_hat).T
 
     # need to be a bit careful with the zero variance genes
     # just set the zero variance genes to zero in the standardized data
     s_data = np.where(
         var_pooled == 0,
         0,
-        ((data - stand_mean) / np.dot(np.sqrt(var_pooled), np.ones((1, int(n_array))))),
+        (data.values - stand_mean) / np.sqrt(var_pooled),
     )
     s_data = pd.DataFrame(s_data, index=data.index, columns=data.columns)
 
@@ -218,7 +220,6 @@ def combat(  # noqa: PLR0915
             "within-batch variance. Filter these batches before running combat."
         )
        raise ValueError(msg)
-    n_array = float(sum(n_batches))
 
     # standardize across genes using a pooled variance estimator
     logg.info("Standardizing Data across genes.\n")
@@ -271,28 +272,27 @@
 
     # we now apply the parametric adjustment to the standardized data from above
     # loop over all batches in the data
+    bayesdata_arr = bayesdata.to_numpy(copy=True)
+    batch_design_arr = batch_design.values
     for j, batch_idxs in enumerate(batch_info.values()):
         # we basically subtract the additive batch effect, rescale by the ratio
         # of multiplicative batch effect to pooled variance and add the overall gene
         # wise mean
         dsq = np.sqrt(delta_star[j, :])
-        dsq = dsq.reshape((len(dsq), 1))
-        denom = np.dot(dsq, np.ones((1, n_batches[j])))
-        numer = np.array(
-            bayesdata.iloc[:, batch_idxs]
-            - np.dot(batch_design.iloc[batch_idxs], gamma_star).T
+        numer = (
+            bayesdata_arr[:, batch_idxs]
+            - np.dot(batch_design_arr[batch_idxs], gamma_star).T
         )
-        bayesdata.iloc[:, batch_idxs] = numer / denom
+        bayesdata_arr[:, batch_idxs] = numer / dsq[:, np.newaxis]
 
-    vpsq = np.sqrt(var_pooled).reshape((len(var_pooled), 1))
-    bayesdata = bayesdata * np.dot(vpsq, np.ones((1, int(n_array)))) + stand_mean
+    bayesdata_arr = bayesdata_arr * np.sqrt(var_pooled) + stand_mean
 
     # put back into the adata object or return
     x = bayesdata.to_numpy().transpose()
     if inplace:
-        adata.X = x
+        adata.X = bayesdata_arr.T
         return None
-    return x
+    return bayesdata_arr.T
 
 
 def _it_sol(
@@ -348,12 +348,7 @@
 
     # in the loop, gamma and delta are updated together. they depend on each other. we iterate until convergence.
     while change > conv:
         g_new = (t2 * n * g_hat + d_old * g_bar) / (t2 * n + d_old)
-        sum2 = s_data - g_new.reshape((g_new.shape[0], 1)) @ np.ones((
-            1,
-            s_data.shape[1],
-        ))
-        sum2 = sum2**2
-        sum2 = sum2.sum(axis=1)
+        sum2 = ((s_data - g_new[:, np.newaxis]) ** 2).sum(axis=1)
         d_new = (0.5 * sum2 + b) / (n / 2.0 + a - 1.0)
         change = max(