From eecce50ba68e6274f90b7ad07e586fd125bf5ea7 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Mon, 26 Jan 2026 10:48:16 +0100 Subject: [PATCH 01/14] first draft --- src/scanpy/external/pp/__init__.py | 2 +- src/scanpy/external/pp/_harmony_integrate.py | 100 ---- src/scanpy/preprocessing/__init__.py | 2 + src/scanpy/preprocessing/_harmony/__init__.py | 520 ++++++++++++++++++ .../preprocessing/_harmony_integrate.py | 100 ++++ tests/external/test_harmony_integrate.py | 23 - tests/test_harmony.py | 162 ++++++ 7 files changed, 785 insertions(+), 124 deletions(-) delete mode 100644 src/scanpy/external/pp/_harmony_integrate.py create mode 100644 src/scanpy/preprocessing/_harmony/__init__.py create mode 100644 src/scanpy/preprocessing/_harmony_integrate.py delete mode 100644 tests/external/test_harmony_integrate.py create mode 100644 tests/test_harmony.py diff --git a/src/scanpy/external/pp/__init__.py b/src/scanpy/external/pp/__init__.py index a8b09f725d..4367a6fdfb 100644 --- a/src/scanpy/external/pp/__init__.py +++ b/src/scanpy/external/pp/__init__.py @@ -4,9 +4,9 @@ from ..._compat import deprecated from ...preprocessing import _scrublet +from ...preprocessing._harmony_integrate import harmony_integrate from ._bbknn import bbknn from ._dca import dca -from ._harmony_integrate import harmony_integrate from ._hashsolo import hashsolo from ._magic import magic from ._mnn_correct import mnn_correct diff --git a/src/scanpy/external/pp/_harmony_integrate.py b/src/scanpy/external/pp/_harmony_integrate.py deleted file mode 100644 index 4fd908d955..0000000000 --- a/src/scanpy/external/pp/_harmony_integrate.py +++ /dev/null @@ -1,100 +0,0 @@ -"""Use harmony to integrate cells from different experiments.""" - -from __future__ import annotations - -from typing import TYPE_CHECKING - -import numpy as np - -from ..._compat import old_positionals -from ..._utils._doctests import doctest_needs - -if TYPE_CHECKING: - from collections.abc import Sequence - - from anndata import AnnData - - 
-@old_positionals("basis", "adjusted_basis") -@doctest_needs("harmonypy") -def harmony_integrate( - adata: AnnData, - key: str | Sequence[str], - *, - basis: str = "X_pca", - adjusted_basis: str = "X_pca_harmony", - **kwargs, -): - """Use harmonypy :cite:p:`Korsunsky2019` to integrate different experiments. - - Harmony :cite:p:`Korsunsky2019` is an algorithm for integrating single-cell - data from multiple experiments. This function uses the python - port of Harmony, ``harmonypy``, to integrate single-cell data - stored in an AnnData object. As Harmony works by adjusting the - principal components, this function should be run after performing - PCA but before computing the neighbor graph, as illustrated in the - example below. - - Parameters - ---------- - adata - The annotated data matrix. - key - The name of the column in ``adata.obs`` that differentiates - among experiments/batches. To integrate over two or more covariates, - you can pass multiple column names as a list. See ``vars_use`` - parameter of the ``harmonypy`` package for more details. - basis - The name of the field in ``adata.obsm`` where the PCA table is - stored. Defaults to ``'X_pca'``, which is the default for - ``sc.pp.pca()``. - adjusted_basis - The name of the field in ``adata.obsm`` where the adjusted PCA - table will be stored after running this function. Defaults to - ``X_pca_harmony``. - kwargs - Any additional arguments will be passed to - ``harmonypy.run_harmony()``. - - Returns - ------- - Updates adata with the field ``adata.obsm[obsm_out_field]``, - containing principal components adjusted by Harmony such that - different experiments are integrated. - - Example - ------- - First, load libraries and example dataset, and preprocess. 
- - >>> import scanpy as sc - >>> import scanpy.external as sce - >>> adata = sc.datasets.pbmc3k() - >>> sc.pp.recipe_zheng17(adata) - >>> sc.pp.pca(adata) - - We now arbitrarily assign a batch metadata variable to each cell - for the sake of example, but during real usage there would already - be a column in ``adata.obs`` giving the experiment each cell came - from. - - >>> adata.obs["batch"] = 1350 * ["a"] + 1350 * ["b"] - - Finally, run harmony. Afterwards, there will be a new table in - ``adata.obsm`` containing the adjusted PC's. - - >>> sce.pp.harmony_integrate(adata, "batch") - >>> "X_pca_harmony" in adata.obsm - True - - """ - try: - import harmonypy - except ImportError as e: - msg = "\nplease install harmonypy:\n\n\tpip install harmonypy" - raise ImportError(msg) from e - - x = adata.obsm[basis].astype(np.float64) - - harmony_out = harmonypy.run_harmony(x, adata.obs, key, **kwargs) - - adata.obsm[adjusted_basis] = harmony_out.Z_corr.T diff --git a/src/scanpy/preprocessing/__init__.py b/src/scanpy/preprocessing/__init__.py index d80412860d..95cc13a57b 100644 --- a/src/scanpy/preprocessing/__init__.py +++ b/src/scanpy/preprocessing/__init__.py @@ -6,6 +6,7 @@ from ._combat import combat from ._deprecated.highly_variable_genes import filter_genes_dispersion from ._deprecated.sampling import subsample +from ._harmony_integrate import harmony_integrate from ._highly_variable_genes import highly_variable_genes from ._normalization import normalize_total from ._pca import pca @@ -31,6 +32,7 @@ "filter_cells", "filter_genes", "filter_genes_dispersion", + "harmony_integrate", "highly_variable_genes", "log1p", "neighbors", diff --git a/src/scanpy/preprocessing/_harmony/__init__.py b/src/scanpy/preprocessing/_harmony/__init__.py new file mode 100644 index 0000000000..282f6589a2 --- /dev/null +++ b/src/scanpy/preprocessing/_harmony/__init__.py @@ -0,0 +1,520 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np +from 
scipy.sparse import csr_matrix # noqa: TID251 +from sklearn.cluster import KMeans +from tqdm.auto import tqdm + +if TYPE_CHECKING: + import pandas as pd + + +def harmonize( # noqa: PLR0913, PLR0912 + x: np.ndarray, + batch_df: pd.DataFrame, + batch_key: str | list[str], + *, + theta: float | list[float] | None = None, + sigma: float = 0.1, + n_clusters: int | None = None, + max_iter_harmony: int = 10, + max_iter_clustering: int = 200, + tol_harmony: float = 1e-4, + tol_clustering: float = 1e-5, + ridge_lambda: float = 1.0, + correction_method: str = "original", + block_proportion: float = 0.05, + random_state: int | None = 0, + verbose: bool = False, + sparse: bool = False, +) -> np.ndarray: + """ + Run Harmony batch correction algorithm. + + Parameters + ---------- + x + Data matrix (n_cells x d) - typically PCA embeddings. + batch_df + DataFrame containing batch information. + batch_key + Column name(s) in batch_df containing batch labels. + theta + Diversity penalty weight(s). Default is 2 for each batch variable. + sigma + Width of soft clustering kernel. Default 0.1. + n_clusters + Number of clusters. Default is min(100, n_cells/30). + max_iter_harmony + Maximum Harmony iterations. Default 10. + max_iter_clustering + Maximum clustering iterations per Harmony round. Default 200. + tol_harmony + Convergence tolerance for Harmony. Default 1e-4. + tol_clustering + Convergence tolerance for clustering. Default 1e-5. + ridge_lambda + Ridge regression regularization. Default 1.0. + correction_method + 'original' or 'fast'. Default 'original'. + block_proportion + Fraction of cells processed per clustering iteration. Default 0.05. + random_state + Random seed for reproducibility. + verbose + Print progress information. + sparse + Use sparse matrices for phi. Reduces memory for large datasets. + + Returns + ------- + z_corr + Batch-corrected embedding matrix (n_cells x d). 
+ """ + if random_state is not None: + np.random.seed(random_state) + + # Ensure input is C-contiguous float array (infer dtype from x) + x = np.ascontiguousarray(x) + dtype = x.dtype + n_cells = x.shape[0] + + # Normalize input for clustering + z_norm = _normalize_rows_l2(x) + + # Process batch keys + batch_codes, n_batches = _get_batch_codes(batch_df, batch_key) + + # Build phi matrix (one-hot encoding of batches) + if sparse: + phi = _one_hot_encode_sparse(batch_codes, n_batches, dtype) + n_b = np.asarray(phi.sum(axis=0)).ravel() + else: + phi = _one_hot_encode(batch_codes, n_batches, dtype) + n_b = phi.sum(axis=0) + pr_b = (n_b / n_cells).reshape(-1, 1) + + # Set default theta + if theta is None: + theta_arr = np.ones(n_batches, dtype=dtype) * 2.0 + elif isinstance(theta, (int, float)): + theta_arr = np.ones(n_batches, dtype=dtype) * float(theta) + else: + theta_arr = np.array(theta, dtype=dtype) + theta_arr = theta_arr.reshape(1, -1) + + # Set default n_clusters + if n_clusters is None: + n_clusters = int(min(100, n_cells / 30)) + n_clusters = max(n_clusters, 2) + + # Initialize centroids and state arrays + r, e, o, objectives_harmony = _initialize_centroids( + z_norm, + phi, + pr_b, + n_clusters=n_clusters, + sigma=sigma, + theta=theta_arr, + random_state=random_state, + ) + + # Main Harmony loop + converged = False + z_hat = x.copy() + + for i in tqdm(range(max_iter_harmony), disable=not verbose): + # Clustering step + _clustering( + z_norm, + batch_codes, + n_batches, + pr_b, + r=r, + e=e, + o=o, + theta=theta_arr, + sigma=sigma, + max_iter=max_iter_clustering, + tol=tol_clustering, + block_proportion=block_proportion, + objectives_harmony=objectives_harmony, + ) + + # Correction step + if correction_method == "fast": + z_hat = _correction_fast( + x, batch_codes, n_batches, r, o, ridge_lambda=ridge_lambda + ) + else: + z_hat = _correction_original( + x, batch_codes, n_batches, r, ridge_lambda=ridge_lambda + ) + + # Normalize corrected data for next 
iteration + z_norm = _normalize_rows_l2(z_hat) + + # Check convergence + if _is_convergent_harmony(objectives_harmony, tol_harmony): + converged = True + if verbose: + print(f"Harmony converged in {i + 1} iterations") + break + + if not converged and verbose: + print(f"Harmony did not converge after {max_iter_harmony} iterations.") + + return z_hat + + +def _get_batch_codes( + batch_df: pd.DataFrame, + batch_key: str | list[str], +) -> tuple[np.ndarray, int]: + """Get batch codes from DataFrame.""" + if isinstance(batch_key, str): + batch_vec = batch_df[batch_key] + elif len(batch_key) == 1: + batch_vec = batch_df[batch_key[0]] + else: + df = batch_df[batch_key].astype("str") + batch_vec = df.apply(lambda row: ",".join(row), axis=1) + + batch_cat = batch_vec.astype("category") + codes = batch_cat.cat.codes.values.copy() + n_batches = len(batch_cat.cat.categories) + + return codes.astype(np.int32), n_batches + + +def _one_hot_encode( + codes: np.ndarray, + n_categories: int, + dtype: np.dtype, +) -> np.ndarray: + """One-hot encode category codes.""" + n = len(codes) + phi = np.zeros((n, n_categories), dtype=dtype) + phi[np.arange(n), codes] = 1.0 + return phi + + +def _one_hot_encode_sparse( + codes: np.ndarray, + n_categories: int, + dtype: np.dtype, +): + """One-hot encode category codes as sparse CSR matrix.""" + n = len(codes) + data = np.ones(n, dtype=dtype) + indices = codes.astype(np.int32) + indptr = np.arange(n + 1, dtype=np.int32) + return csr_matrix((data, indices, indptr), shape=(n, n_categories)) + + +def _normalize_rows_l2(x: np.ndarray) -> np.ndarray: + """L2 normalize each row of x.""" + norms = np.linalg.norm(x, axis=1, keepdims=True) + norms = np.maximum(norms, 1e-12) + return x / norms + + +def _normalize_rows_l1(r: np.ndarray) -> None: + """L1 normalize each row of r in-place (rows sum to 1).""" + row_sums = r.sum(axis=1, keepdims=True) + row_sums = np.maximum(row_sums, 1e-12) + r /= row_sums + + +def _initialize_centroids( + z_norm: np.ndarray, 
+ phi: np.ndarray, + pr_b: np.ndarray, + *, + n_clusters: int, + sigma: float, + theta: np.ndarray, + random_state: int | None, +) -> tuple[np.ndarray, np.ndarray, np.ndarray, list]: + """Initialize cluster centroids using K-means.""" + kmeans = KMeans( + n_clusters=n_clusters, random_state=random_state, n_init=10, max_iter=25 + ) + kmeans.fit(z_norm) + + # Centroids + y = kmeans.cluster_centers_.copy() + y_norm = _normalize_rows_l2(y) + + # Compute soft cluster assignments r + term = -2.0 / sigma + r = _compute_r(z_norm, y_norm, term) + _normalize_rows_l1(r) + + # Initialize e (expected) and o (observed) + r_sum = r.sum(axis=0) + e = pr_b @ r_sum.reshape(1, -1) + o = phi.T @ r + + # Compute initial objective + objectives_harmony: list = [] + obj = _compute_objective(y_norm, z_norm, r, theta=theta, sigma=sigma, o=o, e=e) + objectives_harmony.append(obj) + + return r, e, o, objectives_harmony + + +def _compute_r( + z: np.ndarray, + y: np.ndarray, + term: float, +) -> np.ndarray: + """Compute soft cluster assignments using NumPy dot.""" + dots = z @ y.T + return np.exp(term * (1.0 - dots)) + + +def _clustering( # noqa: PLR0913 + z_norm: np.ndarray, + batch_codes: np.ndarray, + n_batches: int, + pr_b: np.ndarray, + *, + r: np.ndarray, + e: np.ndarray, + o: np.ndarray, + theta: np.ndarray, + sigma: float, + max_iter: int, + tol: float, + block_proportion: float, + objectives_harmony: list, +) -> None: + """Run clustering iterations (modifies r, e, o in-place).""" + n_cells = z_norm.shape[0] + k = r.shape[1] + block_size = max(1, int(n_cells * block_proportion)) + term = -2.0 / sigma + + objectives_clustering = [] + + # Pre-allocate work arrays + y = np.empty((k, z_norm.shape[1]), dtype=z_norm.dtype) + y_norm = np.empty_like(y) + + for _ in range(max_iter): + # Compute cluster centroids: y = r.T @ z_norm, then normalize + np.dot(r.T, z_norm, out=y) + norms = np.linalg.norm(y, axis=1, keepdims=True) + norms = np.maximum(norms, 1e-12) + np.divide(y, norms, out=y_norm) + + 
# Randomly shuffle cell indices + idx_list = np.random.permutation(n_cells) + + # Process blocks + pos = 0 + while pos < n_cells: + end_pos = min(pos + block_size, n_cells) + block_idx = idx_list[pos:end_pos] + + for b in range(n_batches): + mask = batch_codes[block_idx] == b + if not np.any(mask): + continue + + cell_idx = block_idx[mask] + + # Remove old r contribution from o and e + r_old = r[cell_idx, :] + r_old_sum = r_old.sum(axis=0) + o[b, :] -= r_old_sum + e -= pr_b * r_old_sum + + # Compute new r values + dots = z_norm[cell_idx, :] @ y_norm.T + r_new = np.exp(term * (1.0 - dots)) + + # Apply penalty + penalty = ((e[b, :] + 1.0) / (o[b, :] + 1.0)) ** theta[0, b] + r_new *= penalty + + # Normalize rows to sum to 1 + row_sums = r_new.sum(axis=1, keepdims=True) + row_sums = np.maximum(row_sums, 1e-12) + r_new /= row_sums + + # Store back + r[cell_idx, :] = r_new + + # Add new r contribution to o and e + r_new_sum = r_new.sum(axis=0) + o[b, :] += r_new_sum + e += pr_b * r_new_sum + + pos = end_pos + + # Compute objective + obj = _compute_objective(y_norm, z_norm, r, theta=theta, sigma=sigma, o=o, e=e) + objectives_clustering.append(obj) + + # Check convergence + if _is_convergent_clustering(objectives_clustering, tol): + objectives_harmony.append(objectives_clustering[-1]) + break + + +def _correction_original( + x: np.ndarray, + batch_codes: np.ndarray, + n_batches: int, + r: np.ndarray, + *, + ridge_lambda: float, +) -> np.ndarray: + """Original correction method - per-cluster ridge regression.""" + _, d = x.shape + k = r.shape[1] + + # Ridge regularization matrix (don't penalize intercept) + id_mat = np.eye(n_batches + 1) + id_mat[0, 0] = 0 + lambda_mat = ridge_lambda * id_mat + + z = x.copy() + + for k_idx in range(k): + r_k = r[:, k_idx] + + r_sum_total = r_k.sum() + r_sum_per_batch = np.zeros(n_batches, dtype=x.dtype) + for b in range(n_batches): + r_sum_per_batch[b] = r_k[batch_codes == b].sum() + + phi_t_phi = np.zeros((n_batches + 1, n_batches + 1), 
dtype=x.dtype) + phi_t_phi[0, 0] = r_sum_total + phi_t_phi[0, 1:] = r_sum_per_batch + phi_t_phi[1:, 0] = r_sum_per_batch + phi_t_phi[1:, 1:] = np.diag(r_sum_per_batch) + phi_t_phi += lambda_mat + + phi_t_x = np.zeros((n_batches + 1, d), dtype=x.dtype) + phi_t_x[0, :] = r_k @ x + for b in range(n_batches): + mask = batch_codes == b + phi_t_x[b + 1, :] = r_k[mask] @ x[mask] + + try: + w = np.linalg.solve(phi_t_phi, phi_t_x) + except np.linalg.LinAlgError: + w = np.linalg.lstsq(phi_t_phi, phi_t_x, rcond=None)[0] + + w[0, :] = 0 + w_batch = w[batch_codes + 1, :] + z -= r_k[:, np.newaxis] * w_batch + + return z + + +def _correction_fast( + x: np.ndarray, + batch_codes: np.ndarray, + n_batches: int, + r: np.ndarray, + o: np.ndarray, + *, + ridge_lambda: float, +) -> np.ndarray: + """Fast correction method using precomputed factors.""" + _, d = x.shape + k = r.shape[1] + + z = x.copy() + p = np.eye(n_batches + 1) + + for k_idx in range(k): + o_k = o[:, k_idx] + n_k = np.sum(o_k) + + factor = 1.0 / (o_k + ridge_lambda) + c = n_k + np.sum(-factor * o_k**2) + c_inv = 1.0 / c + + p[0, 1:] = -factor * o_k + + p_t_b_inv = np.zeros((n_batches + 1, n_batches + 1)) + p_t_b_inv[0, 0] = c_inv + p_t_b_inv[1:, 1:] = np.diag(factor) + p_t_b_inv[1:, 0] = p[0, 1:] * c_inv + + inv_mat = p_t_b_inv @ p + + r_k = r[:, k_idx] + phi_t_x = np.zeros((n_batches + 1, d), dtype=x.dtype) + phi_t_x[0, :] = r_k @ x + for b in range(n_batches): + mask = batch_codes == b + phi_t_x[b + 1, :] = r_k[mask] @ x[mask] + + w = inv_mat @ phi_t_x + w[0, :] = 0 + + w_batch = w[batch_codes + 1, :] + z -= r_k[:, np.newaxis] * w_batch + + return z + + +def _compute_objective( + y_norm: np.ndarray, + z_norm: np.ndarray, + r: np.ndarray, + *, + theta: np.ndarray, + sigma: float, + o: np.ndarray, + e: np.ndarray, +) -> float: + """Compute Harmony objective function.""" + zy = z_norm @ y_norm.T + kmeans_error = np.sum(r * 2.0 * (1.0 - zy)) + + r_row_sums = r.sum(axis=1, keepdims=True) + r_normalized = r / 
np.clip(r_row_sums, 1e-12, None) + entropy = sigma * np.sum(r_normalized * np.log(r_normalized + 1e-12)) + + log_ratio = np.log((o + 1) / (e + 1)) + diversity_penalty = sigma * np.sum(theta @ (o * log_ratio)) + + return kmeans_error + entropy + diversity_penalty + + +def _is_convergent_harmony( + objectives: list, + tol: float, +) -> bool: + """Check Harmony convergence.""" + if len(objectives) < 2: + return False + + obj_old = objectives[-2] + obj_new = objectives[-1] + + return (obj_old - obj_new) < tol * abs(obj_old) + + +def _is_convergent_clustering( + objectives: list, + tol: float, + window_size: int = 3, +) -> bool: + """Check clustering convergence using window.""" + if len(objectives) < window_size + 1: + return False + + obj_old = sum(objectives[-window_size - 1 : -1]) + obj_new = sum(objectives[-window_size:]) + + return (obj_old - obj_new) < tol * abs(obj_old) diff --git a/src/scanpy/preprocessing/_harmony_integrate.py b/src/scanpy/preprocessing/_harmony_integrate.py new file mode 100644 index 0000000000..78249e645e --- /dev/null +++ b/src/scanpy/preprocessing/_harmony_integrate.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np + +if TYPE_CHECKING: + from typing import Literal + + from anndata import AnnData + + +def harmony_integrate( + adata: AnnData, + key: str | list[str], + *, + basis: str = "X_pca", + adjusted_basis: str = "X_pca_harmony", + dtype: type = np.float64, + correction_method: Literal["fast", "original"] = "original", + sparse: bool = False, + **kwargs, +) -> None: + """ + Integrate different experiments using the Harmony algorithm. + + This CPU implementation is based on the harmony-pytorch & rapids_singlecell version, + using NumPy for efficient computation. + + Parameters + ---------- + adata + The annotated data matrix. + key + The key(s) of the column(s) in ``adata.obs`` that differentiates + among experiments/batches. 
+ basis + The name of the field in ``adata.obsm`` where the PCA table is + stored. Defaults to ``'X_pca'``. + adjusted_basis + The name of the field in ``adata.obsm`` where the adjusted PCA + table will be stored. Defaults to ``X_pca_harmony``. + dtype + The data type to use for Harmony computation. + correction_method + Choose which method for the correction step: ``original`` for + original method, ``fast`` for improved method. + sparse + Use sparse matrices for batch encoding. Reduces memory for large datasets. + **kwargs + Additional arguments passed to ``harmonize()``. + + Returns + ------- + Updates adata with the field ``adata.obsm[adjusted_basis]``, + containing principal components adjusted by Harmony. + """ + from ._harmony import harmonize + + # Ensure the basis exists in adata.obsm + if basis not in adata.obsm: + msg = ( + f"The specified basis '{basis}' is not available in adata.obsm. " + f"Available bases: {list(adata.obsm.keys())}" + ) + raise ValueError(msg) + + # Get the input data + input_data = adata.obsm[basis] + + # Convert to numpy array with specified dtype + try: + x = np.ascontiguousarray(input_data, dtype=dtype) + except Exception as e: + msg = ( + f"Could not convert input of type {type(input_data).__name__} " + "to NumPy array." + ) + raise TypeError(msg) from e + + # Check for NaN values + if np.isnan(x).any(): + msg = ( + "Input data contains NaN values. Please handle these before " + "running harmony_integrate." 
+ ) + raise ValueError(msg) + + # Run Harmony + harmony_out = harmonize( + x, + adata.obs, + key, + correction_method=correction_method, + sparse=sparse, + **kwargs, + ) + + # Store result + adata.obsm[adjusted_basis] = harmony_out diff --git a/tests/external/test_harmony_integrate.py b/tests/external/test_harmony_integrate.py deleted file mode 100644 index 2844354a2f..0000000000 --- a/tests/external/test_harmony_integrate.py +++ /dev/null @@ -1,23 +0,0 @@ -from __future__ import annotations - -import scanpy as sc -import scanpy.external as sce -from testing.scanpy._helpers.data import pbmc3k -from testing.scanpy._pytest.marks import needs - -pytestmark = [needs.harmonypy] - - -def test_harmony_integrate(): - """Test that Harmony integrate works. - - This is a very simple test that just checks to see if the Harmony - integrate wrapper succesfully added a new field to ``adata.obsm`` - and makes sure it has the same dimensions as the original PCA table. - """ - adata = pbmc3k() - sc.pp.recipe_zheng17(adata) - sc.pp.pca(adata) - adata.obs["batch"] = 1350 * ["a"] + 1350 * ["b"] - sce.pp.harmony_integrate(adata, "batch") - assert adata.obsm["X_pca_harmony"].shape == adata.obsm["X_pca"].shape diff --git a/tests/test_harmony.py b/tests/test_harmony.py new file mode 100644 index 0000000000..fcecacd0c5 --- /dev/null +++ b/tests/test_harmony.py @@ -0,0 +1,162 @@ +from __future__ import annotations + +import anndata as ad +import numpy as np +import pandas as pd +import pooch +import pytest +from scipy.stats import pearsonr + +import scanpy as sc +from scanpy.preprocessing import harmony_integrate + + +def _get_measure(x, base, norm): + """Compute correlation or L2 distance between arrays.""" + assert norm in ["r", "L2"] + + if norm == "r": + # Compute per-column correlation + if x.ndim == 1: + corr, _ = pearsonr(x, base) + return corr + else: + corrs = [] + for i in range(x.shape[1]): + corr, _ = pearsonr(x[:, i], base[:, i]) + corrs.append(corr) + return np.array(corrs) + # 
L2 distance normalized by base norm + elif x.ndim == 1: + return np.linalg.norm(x - base) / np.linalg.norm(base) + else: + dists = [] + for i in range(x.shape[1]): + dist = np.linalg.norm(x[:, i] - base[:, i]) / np.linalg.norm(base[:, i]) + dists.append(dist) + return np.array(dists) + + +@pytest.fixture +def adata_reference(): + """Load reference data from harmonypy repository.""" + x_pca_file = pooch.retrieve( + "https://github.com/slowkow/harmonypy/raw/refs/heads/master/data/pbmc_3500_pcs.tsv.gz", + known_hash="md5:27e319b3ddcc0c00d98e70aa8e677b10", + ) + x_pca = pd.read_csv(x_pca_file, delimiter="\t") + x_pca_harmony_file = pooch.retrieve( + "https://github.com/slowkow/harmonypy/raw/refs/heads/master/data/pbmc_3500_pcs_harmonized.tsv.gz", + known_hash="md5:a7c4ce4b98c390997c66d63d48e09221", + ) + x_pca_harmony = pd.read_csv(x_pca_harmony_file, delimiter="\t") + meta_file = pooch.retrieve( + "https://github.com/slowkow/harmonypy/raw/refs/heads/master/data/pbmc_3500_meta.tsv.gz", + known_hash="md5:8c7ca20e926513da7cf0def1211baecb", + ) + meta = pd.read_csv(meta_file, delimiter="\t") + # Create unique index using row number + cell name + meta.index = [f"{i}_{cell}" for i, cell in enumerate(meta["cell"])] + adata = ad.AnnData( + X=None, + obs=meta, + obsm={"X_pca": x_pca.values, "harmony_org": x_pca_harmony.values}, + ) + return adata + + +@pytest.mark.parametrize("correction_method", ["fast", "original"]) +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_harmony_integrate(correction_method, dtype): + """Test that Harmony integrate works.""" + adata = sc.datasets.pbmc68k_reduced() + harmony_integrate( + adata, "bulk_labels", correction_method=correction_method, dtype=dtype + ) + assert adata.obsm["X_pca_harmony"].shape == adata.obsm["X_pca"].shape + + +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_harmony_integrate_algos(dtype): + """Test that both correction methods produce similar results.""" + adata = 
sc.datasets.pbmc68k_reduced() + harmony_integrate(adata, "bulk_labels", correction_method="fast", dtype=dtype) + fast = adata.obsm["X_pca_harmony"].copy() + harmony_integrate(adata, "bulk_labels", correction_method="original", dtype=dtype) + slow = adata.obsm["X_pca_harmony"].copy() + assert _get_measure(fast, slow, "r").min() > 0.99 + assert _get_measure(fast, slow, "L2").max() < 0.1 + + +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +@pytest.mark.parametrize("correction_method", ["fast", "original"]) +def test_harmony_integrate_reference(adata_reference, *, dtype, correction_method): + """Test that Harmony produces results similar to the reference implementation.""" + harmony_integrate( + adata_reference, + "donor", + correction_method=correction_method, + dtype=dtype, + max_iter_harmony=20, + ) + + assert ( + _get_measure( + adata_reference.obsm["harmony_org"], + adata_reference.obsm["X_pca_harmony"], + "L2", + ).max() + < 0.05 + ) + assert ( + _get_measure( + adata_reference.obsm["harmony_org"], + adata_reference.obsm["X_pca_harmony"], + "r", + ).min() + > 0.95 + ) + + +def test_harmony_multiple_keys(): + """Test Harmony with multiple batch keys.""" + adata = sc.datasets.pbmc68k_reduced() + # Create a second batch key + adata.obs["batch2"] = np.random.choice(["A", "B", "C"], size=adata.n_obs) + harmony_integrate(adata, ["bulk_labels", "batch2"], correction_method="original") + assert adata.obsm["X_pca_harmony"].shape == adata.obsm["X_pca"].shape + + +def test_harmony_custom_parameters(): + """Test Harmony with custom parameters.""" + adata = sc.datasets.pbmc68k_reduced() + harmony_integrate( + adata, + "bulk_labels", + theta=1.5, + sigma=0.15, + n_clusters=50, + max_iter_harmony=5, + ridge_lambda=0.5, + ) + assert adata.obsm["X_pca_harmony"].shape == adata.obsm["X_pca"].shape + + +def test_harmony_no_nan_output(): + """Test that Harmony output contains no NaN values.""" + adata = sc.datasets.pbmc68k_reduced() + harmony_integrate(adata, 
"bulk_labels") + assert not np.isnan(adata.obsm["X_pca_harmony"]).any() + + +def test_harmony_input_validation(): + """Test that Harmony raises errors for invalid inputs.""" + adata = sc.datasets.pbmc68k_reduced() + + # Test missing basis + with pytest.raises(ValueError, match="not available"): + harmony_integrate(adata, "bulk_labels", basis="nonexistent") + + # Test missing key + with pytest.raises(KeyError): + harmony_integrate(adata, "nonexistent_key") From f27954354d4d0850d8fff4f83949c282cdf51b17 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Mon, 26 Jan 2026 11:00:26 +0100 Subject: [PATCH 02/14] add pooch --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index c61961606c..8e52640212 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,6 +96,7 @@ test-min = [ "pytest-randomly", "pytest-rerunfailures", "tuna", + "pooch", "dependency-groups", # for CI scripts doctests ] test = [ @@ -142,7 +143,6 @@ leiden = [ "igraph>=0.10.8", "leidenalg>=0.10.1" ] # Leiden community detect bbknn = [ "bbknn" ] # Batch balanced KNN (batch correction) magic = [ "magic-impute>=2.0.4" ] # MAGIC imputation method skmisc = [ "scikit-misc>=0.5.1" ] # highly_variable_genes method 'seurat_v3' -harmony = [ "harmonypy" ] # Harmony dataset integration scanorama = [ "scanorama" ] # Scanorama dataset integration scrublet = [ "scikit-image>=0.23.1" ] # Doublet detection with automatic thresholds # Plotting From b5527796ffe054a6078f7bd8267b56b5dae2da64 Mon Sep 17 00:00:00 2001 From: "Philipp A." 
Date: Fri, 20 Feb 2026 10:38:24 +0100 Subject: [PATCH 03/14] improve deprecation --- src/scanpy/external/pp/__init__.py | 29 +++++++++++++++++++++-------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/src/scanpy/external/pp/__init__.py b/src/scanpy/external/pp/__init__.py index 4367a6fdfb..aa3135f95f 100644 --- a/src/scanpy/external/pp/__init__.py +++ b/src/scanpy/external/pp/__init__.py @@ -3,8 +3,6 @@ from __future__ import annotations from ..._compat import deprecated -from ...preprocessing import _scrublet -from ...preprocessing._harmony_integrate import harmony_integrate from ._bbknn import bbknn from ._dca import dca from ._hashsolo import hashsolo @@ -12,17 +10,32 @@ from ._mnn_correct import mnn_correct from ._scanorama_integrate import scanorama_integrate -scrublet = deprecated("Import from sc.pp instead")(_scrublet.scrublet) -scrublet_simulate_doublets = deprecated("Import from sc.pp instead")( - _scrublet.scrublet_simulate_doublets -) - __all__ = [ "bbknn", "dca", - "harmony_integrate", "hashsolo", "magic", "mnn_correct", "scanorama_integrate", ] + + +@deprecated("Import from sc.pp instead") +def harmony_integrate(*args, **kwargs): + from ...preprocessing import harmony_integrate + + return harmony_integrate(*args, **kwargs) + + +@deprecated("Import from sc.pp instead") +def scrublet(*args, **kwargs): + from ...preprocessing import scrublet + + return scrublet(*args, **kwargs) + + +@deprecated("Import from sc.pp instead") +def scrublet_simulate_doublets(*args, **kwargs): + from ...preprocessing import scrublet_simulate_doublets + + return scrublet_simulate_doublets(*args, **kwargs) From a59be404b04c3d33bf839e1a1b3ed914566da008 Mon Sep 17 00:00:00 2001 From: "Philipp A." 
Date: Fri, 20 Feb 2026 10:45:23 +0100 Subject: [PATCH 04/14] docs --- docs/external/preprocessing.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/docs/external/preprocessing.md b/docs/external/preprocessing.md index 8cbf6449a6..6702e17b22 100644 --- a/docs/external/preprocessing.md +++ b/docs/external/preprocessing.md @@ -5,6 +5,11 @@ .. currentmodule:: scanpy.external ``` +Previously found here, but now part of scanpy’s main API: +- {func}`scanpy.pp.harmony_integrate` +- {func}`scanpy.pp.scrublet` +- {func}`scanpy.pp.scrublet_simulate_doublets` + (external-data-integration)= ## Data integration @@ -14,10 +19,8 @@ :toctree: ../generated/ pp.bbknn - pp.harmony_integrate pp.mnn_correct pp.scanorama_integrate - ``` ## Sample demultiplexing @@ -39,5 +42,4 @@ Note that the fundamental limitations of imputation are still under [debate](htt pp.dca pp.magic - ``` From 94710d56c18a9b9ee194a885d8fa888989b506ce Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Fri, 20 Feb 2026 12:16:30 +0100 Subject: [PATCH 05/14] docs --- docs/api/preprocessing.md | 10 ++++++++-- docs/release-notes/1.10.0.md | 2 +- docs/release-notes/1.11.0.md | 2 +- src/scanpy/preprocessing/_harmony/__init__.py | 20 +++++++++---------- .../preprocessing/_harmony_integrate.py | 2 +- 5 files changed, 21 insertions(+), 15 deletions(-) diff --git a/docs/api/preprocessing.md b/docs/api/preprocessing.md index 0950b9b296..5a3be50767 100644 --- a/docs/api/preprocessing.md +++ b/docs/api/preprocessing.md @@ -47,9 +47,12 @@ For visual quality control, see {func}`~scanpy.pl.highest_expr_genes` and pp.recipe_seurat ``` -## Batch effect correction +(pp-data-integration)= -Also see {ref}`data-integration`. Note that a simple batch correction method is available via {func}`pp.regress_out`. Checkout {mod}`scanpy.external` for more. +## Data integration + +Batch effect correction and other data integration. +Note that a simple batch correction method is available via {func}`pp.regress_out`. 
```{eval-rst} .. autosummary:: @@ -57,8 +60,11 @@ Also see {ref}`data-integration`. Note that a simple batch correction method is :toctree: generated/ pp.combat + pp.harmony_integrate ``` +Also see {ref}`data integration tools <data-integration>` and external {ref}`external data integration <external-data-integration>`. + ## Doublet detection ```{eval-rst} diff --git a/docs/release-notes/1.10.0.md b/docs/release-notes/1.10.0.md index 969633db0b..b12dd11e06 100644 --- a/docs/release-notes/1.10.0.md +++ b/docs/release-notes/1.10.0.md @@ -25,7 +25,7 @@ Some highlights: * {func}`scanpy.datasets.blobs` now accepts a `random_state` argument {pr}`2683` {smaller}`E Roellin` * {func}`scanpy.pp.pca` and {func}`scanpy.pp.regress_out` now accept a layer argument {pr}`2588` {smaller}`S Dicks` * {func}`scanpy.pp.subsample` with `copy=True` can now be called in backed mode {pr}`2624` {smaller}`E Roellin` -* {func}`scanpy.external.pp.harmony_integrate` now runs with 64 bit floats improving reproducibility {pr}`2655` {smaller}`S Dicks` +* {func}`scanpy.pp.harmony_integrate` now runs with 64 bit floats improving reproducibility {pr}`2655` {smaller}`S Dicks` * {func}`scanpy.tl.rank_genes_groups` no longer warns that it's default was changed from t-test_overestim_var to t-test {pr}`2798` {smaller}`L Heumos` * `scanpy.pp.calculate_qc_metrics` now allows `qc_vars` to be passed as a string {pr}`2859` {smaller}`N Teyssier` * {func}`scanpy.tl.leiden` and {func}`scanpy.tl.louvain` now store clustering parameters in the key provided by the `key_added` parameter instead of always writing to (or overwriting) a default key {pr}`2864` {smaller}`J Fan` diff --git a/docs/release-notes/1.11.0.md b/docs/release-notes/1.11.0.md index 875b32f364..70ceab0cdf 100644 --- a/docs/release-notes/1.11.0.md +++ b/docs/release-notes/1.11.0.md @@ -30,7 +30,7 @@ Release candidates: #### Documentation -- {guilabel}`rc1` Improve 
{func}`~scanpy.pp.harmony_integrate` docs {smaller}`D Kühl` ({pr}`3362`) - {guilabel}`rc1` Raise {exc}`FutureWarning` when calling deprecated {mod}`scanpy.pp` functions {smaller}`P Angerer` ({pr}`3380`) - {guilabel}`rc1` {smaller}`P Angerer` ({pr}`3407`) diff --git a/src/scanpy/preprocessing/_harmony/__init__.py b/src/scanpy/preprocessing/_harmony/__init__.py index 282f6589a2..4a01df6b5a 100644 --- a/src/scanpy/preprocessing/_harmony/__init__.py +++ b/src/scanpy/preprocessing/_harmony/__init__.py @@ -7,11 +7,15 @@ from sklearn.cluster import KMeans from tqdm.auto import tqdm +from ... import logging as log +from ..._settings import settings +from ..._settings.verbosity import Verbosity + if TYPE_CHECKING: import pandas as pd -def harmonize( # noqa: PLR0913, PLR0912 +def harmonize( # noqa: PLR0913 x: np.ndarray, batch_df: pd.DataFrame, batch_key: str | list[str], @@ -27,7 +31,6 @@ def harmonize( # noqa: PLR0913, PLR0912 correction_method: str = "original", block_proportion: float = 0.05, random_state: int | None = 0, - verbose: bool = False, sparse: bool = False, ) -> np.ndarray: """ @@ -63,8 +66,6 @@ def harmonize( # noqa: PLR0913, PLR0912 Fraction of cells processed per clustering iteration. Default 0.05. random_state Random seed for reproducibility. - verbose - Print progress information. sparse Use sparse matrices for phi. Reduces memory for large datasets. 
@@ -125,7 +126,7 @@ def harmonize( # noqa: PLR0913, PLR0912 converged = False z_hat = x.copy() - for i in tqdm(range(max_iter_harmony), disable=not verbose): + for i in tqdm(range(max_iter_harmony), disable=settings.verbosity < Verbosity.info): # Clustering step _clustering( z_norm, @@ -159,12 +160,11 @@ def harmonize( # noqa: PLR0913, PLR0912 # Check convergence if _is_convergent_harmony(objectives_harmony, tol_harmony): converged = True - if verbose: - print(f"Harmony converged in {i + 1} iterations") + log.info(f"Harmony converged in {i + 1} iterations") break - if not converged and verbose: - print(f"Harmony did not converge after {max_iter_harmony} iterations.") + if not converged: + log.info(f"Harmony did not converge after {max_iter_harmony} iterations.") return z_hat @@ -180,7 +180,7 @@ def _get_batch_codes( batch_vec = batch_df[batch_key[0]] else: df = batch_df[batch_key].astype("str") - batch_vec = df.apply(lambda row: ",".join(row), axis=1) + batch_vec = df.apply(",".join, axis=1) batch_cat = batch_vec.astype("category") codes = batch_cat.cat.codes.values.copy() diff --git a/src/scanpy/preprocessing/_harmony_integrate.py b/src/scanpy/preprocessing/_harmony_integrate.py index 78249e645e..081264ec6e 100644 --- a/src/scanpy/preprocessing/_harmony_integrate.py +++ b/src/scanpy/preprocessing/_harmony_integrate.py @@ -60,7 +60,7 @@ def harmony_integrate( # Ensure the basis exists in adata.obsm if basis not in adata.obsm: msg = ( - f"The specified basis '{basis}' is not available in adata.obsm. " + f"The specified basis {basis!r} is not available in `adata.obsm`. " f"Available bases: {list(adata.obsm.keys())}" ) raise ValueError(msg) From 4b4d827b0484eec031439d3288ef275d6309c90e Mon Sep 17 00:00:00 2001 From: "Philipp A." 
Date: Fri, 20 Feb 2026 13:01:58 +0100 Subject: [PATCH 06/14] move API over --- src/scanpy/preprocessing/_harmony/__init__.py | 50 +++----- .../preprocessing/_harmony_integrate.py | 49 +++++++- src/testing/scanpy/_pytest/__init__.py | 5 + tests/test_harmony.py | 107 ++++++++++-------- 4 files changed, 120 insertions(+), 91 deletions(-) diff --git a/src/scanpy/preprocessing/_harmony/__init__.py b/src/scanpy/preprocessing/_harmony/__init__.py index 4a01df6b5a..62f8885a80 100644 --- a/src/scanpy/preprocessing/_harmony/__init__.py +++ b/src/scanpy/preprocessing/_harmony/__init__.py @@ -12,6 +12,8 @@ from ..._settings.verbosity import Verbosity if TYPE_CHECKING: + from typing import Literal + import pandas as pd @@ -20,18 +22,18 @@ def harmonize( # noqa: PLR0913 batch_df: pd.DataFrame, batch_key: str | list[str], *, - theta: float | list[float] | None = None, - sigma: float = 0.1, - n_clusters: int | None = None, - max_iter_harmony: int = 10, - max_iter_clustering: int = 200, - tol_harmony: float = 1e-4, - tol_clustering: float = 1e-5, - ridge_lambda: float = 1.0, - correction_method: str = "original", - block_proportion: float = 0.05, - random_state: int | None = 0, - sparse: bool = False, + theta: float | list[float] | None, + sigma: float, + n_clusters: int | None, + max_iter_harmony: int, + max_iter_clustering: int, + tol_harmony: float, + tol_clustering: float, + ridge_lambda: float, + correction_method: Literal["fast", "original"], + block_proportion: float, + random_state: int | None, + sparse: bool, ) -> np.ndarray: """ Run Harmony batch correction algorithm. @@ -44,30 +46,6 @@ def harmonize( # noqa: PLR0913 DataFrame containing batch information. batch_key Column name(s) in batch_df containing batch labels. - theta - Diversity penalty weight(s). Default is 2 for each batch variable. - sigma - Width of soft clustering kernel. Default 0.1. - n_clusters - Number of clusters. Default is min(100, n_cells/30). - max_iter_harmony - Maximum Harmony iterations. 
Default 10. - max_iter_clustering - Maximum clustering iterations per Harmony round. Default 200. - tol_harmony - Convergence tolerance for Harmony. Default 1e-4. - tol_clustering - Convergence tolerance for clustering. Default 1e-5. - ridge_lambda - Ridge regression regularization. Default 1.0. - correction_method - 'original' or 'fast'. Default 'original'. - block_proportion - Fraction of cells processed per clustering iteration. Default 0.05. - random_state - Random seed for reproducibility. - sparse - Use sparse matrices for phi. Reduces memory for large datasets. Returns ------- diff --git a/src/scanpy/preprocessing/_harmony_integrate.py b/src/scanpy/preprocessing/_harmony_integrate.py index 081264ec6e..ca4945195f 100644 --- a/src/scanpy/preprocessing/_harmony_integrate.py +++ b/src/scanpy/preprocessing/_harmony_integrate.py @@ -8,18 +8,28 @@ from typing import Literal from anndata import AnnData + from numpy.typing import DTypeLike -def harmony_integrate( +def harmony_integrate( # noqa: PLR0913 adata: AnnData, key: str | list[str], *, basis: str = "X_pca", adjusted_basis: str = "X_pca_harmony", - dtype: type = np.float64, + dtype: DTypeLike = np.float64, + theta: float | list[float] | None = None, + sigma: float = 0.1, + n_clusters: int | None = None, + max_iter_harmony: int = 10, + max_iter_clustering: int = 200, + tol_harmony: float = 1e-4, + tol_clustering: float = 1e-5, + ridge_lambda: float = 1.0, correction_method: Literal["fast", "original"] = "original", + block_proportion: float = 0.05, + random_state: int | None = 0, sparse: bool = False, - **kwargs, ) -> None: """ Integrate different experiments using the Harmony algorithm. @@ -42,13 +52,31 @@ def harmony_integrate( table will be stored. Defaults to ``X_pca_harmony``. dtype The data type to use for Harmony computation. + theta + Diversity penalty weight(s). Default is 2 for each batch variable. + sigma + Width of soft clustering kernel. Default 0.1. + n_clusters + Number of clusters. 
Default is min(100, n_cells/30). + max_iter_harmony + Maximum Harmony iterations. Default 10. + max_iter_clustering + Maximum clustering iterations per Harmony round. Default 200. + tol_harmony + Convergence tolerance for Harmony. Default 1e-4. + tol_clustering + Convergence tolerance for clustering. Default 1e-5. + ridge_lambda + Ridge regression regularization. Default 1.0. correction_method Choose which method for the correction step: ``original`` for original method, ``fast`` for improved method. + block_proportion + Fraction of cells processed per clustering iteration. Default 0.05. + random_state + Random seed for reproducibility. sparse Use sparse matrices for batch encoding. Reduces memory for large datasets. - **kwargs - Additional arguments passed to ``harmonize()``. Returns ------- @@ -91,9 +119,18 @@ def harmony_integrate( x, adata.obs, key, + theta=theta, + sigma=sigma, + n_clusters=n_clusters, + max_iter_harmony=max_iter_harmony, + max_iter_clustering=max_iter_clustering, + tol_harmony=tol_harmony, + tol_clustering=tol_clustering, + ridge_lambda=ridge_lambda, correction_method=correction_method, + block_proportion=block_proportion, + random_state=random_state, sparse=sparse, - **kwargs, ) # Store result diff --git a/src/testing/scanpy/_pytest/__init__.py b/src/testing/scanpy/_pytest/__init__.py index 04777f6aef..569d47b3f9 100644 --- a/src/testing/scanpy/_pytest/__init__.py +++ b/src/testing/scanpy/_pytest/__init__.py @@ -7,6 +7,7 @@ from types import MappingProxyType from typing import TYPE_CHECKING +import pooch import pytest from packaging.version import Version @@ -26,6 +27,7 @@ def original_settings( request: pytest.FixtureRequest, cache: pytest.Cache, tmp_path_factory: pytest.TempPathFactory, + monkeypatch: pytest.MonkeyPatch, ) -> Generator[Mapping[str, object], None, None]: """Switch to agg backend, reset settings, and close all figures at teardown.""" # make sure seaborn is imported and did its thing @@ -51,6 +53,9 @@ def original_settings( 
cache.mkdir("debug") # reuse data files between test runs (unless overwritten in the test) sc.settings.datasetdir = cache.mkdir("scanpy-data") + pooch.os_cache = pooch.utils.os_cache = pooch.core.os_cache = lambda p: ( + sc.settings.datasetdir / p + ) # create new writedir for each test run sc.settings.writedir = tmp_path_factory.mktemp("scanpy_write") diff --git a/tests/test_harmony.py b/tests/test_harmony.py index fcecacd0c5..b8e2960581 100644 --- a/tests/test_harmony.py +++ b/tests/test_harmony.py @@ -1,44 +1,51 @@ from __future__ import annotations -import anndata as ad +from typing import TYPE_CHECKING + import numpy as np import pandas as pd import pooch import pytest +from anndata import AnnData from scipy.stats import pearsonr import scanpy as sc from scanpy.preprocessing import harmony_integrate +if TYPE_CHECKING: + from typing import Literal -def _get_measure(x, base, norm): - """Compute correlation or L2 distance between arrays.""" - assert norm in ["r", "L2"] + from numpy.typing import DTypeLike + +def _get_measure( + x: np.ndarray, base: np.ndarray, norm: Literal["r", "L2"] +) -> np.ndarray: + """Compute correlation or L2 distance between arrays.""" if norm == "r": # Compute per-column correlation if x.ndim == 1: corr, _ = pearsonr(x, base) return corr - else: - corrs = [] - for i in range(x.shape[1]): - corr, _ = pearsonr(x[:, i], base[:, i]) - corrs.append(corr) - return np.array(corrs) + corrs = [] + for i in range(x.shape[1]): + corr, _ = pearsonr(x[:, i], base[:, i]) + corrs.append(corr) + return np.array(corrs) + + assert norm == "L2" # L2 distance normalized by base norm - elif x.ndim == 1: + if x.ndim == 1: return np.linalg.norm(x - base) / np.linalg.norm(base) - else: - dists = [] - for i in range(x.shape[1]): - dist = np.linalg.norm(x[:, i] - base[:, i]) / np.linalg.norm(base[:, i]) - dists.append(dist) - return np.array(dists) + dists = [] + for i in range(x.shape[1]): + dist = np.linalg.norm(x[:, i] - base[:, i]) / np.linalg.norm(base[:, 
i]) + dists.append(dist) + return np.array(dists) @pytest.fixture -def adata_reference(): +def adata_reference() -> AnnData: """Load reference data from harmonypy repository.""" x_pca_file = pooch.retrieve( "https://github.com/slowkow/harmonypy/raw/refs/heads/master/data/pbmc_3500_pcs.tsv.gz", @@ -57,7 +64,7 @@ def adata_reference(): meta = pd.read_csv(meta_file, delimiter="\t") # Create unique index using row number + cell name meta.index = [f"{i}_{cell}" for i, cell in enumerate(meta["cell"])] - adata = ad.AnnData( + adata = AnnData( X=None, obs=meta, obsm={"X_pca": x_pca.values, "harmony_org": x_pca_harmony.values}, @@ -67,30 +74,44 @@ def adata_reference(): @pytest.mark.parametrize("correction_method", ["fast", "original"]) @pytest.mark.parametrize("dtype", [np.float32, np.float64]) -def test_harmony_integrate(correction_method, dtype): +def test_harmony_integrate( + correction_method: Literal["fast", "original"], dtype: DTypeLike +) -> None: """Test that Harmony integrate works.""" adata = sc.datasets.pbmc68k_reduced() + harmony_integrate( adata, "bulk_labels", correction_method=correction_method, dtype=dtype ) + assert adata.obsm["X_pca_harmony"].shape == adata.obsm["X_pca"].shape @pytest.mark.parametrize("dtype", [np.float32, np.float64]) -def test_harmony_integrate_algos(dtype): +def test_harmony_integrate_algos(subtests: pytest.Subtests, dtype: DTypeLike) -> None: """Test that both correction methods produce similar results.""" adata = sc.datasets.pbmc68k_reduced() + harmony_integrate(adata, "bulk_labels", correction_method="fast", dtype=dtype) fast = adata.obsm["X_pca_harmony"].copy() harmony_integrate(adata, "bulk_labels", correction_method="original", dtype=dtype) slow = adata.obsm["X_pca_harmony"].copy() - assert _get_measure(fast, slow, "r").min() > 0.99 - assert _get_measure(fast, slow, "L2").max() < 0.1 + + with subtests.test("r"): + assert _get_measure(fast, slow, "r").min() > 0.99 + with subtests.test("L2"): + assert _get_measure(fast, slow, 
"L2").max() < 0.1 @pytest.mark.parametrize("dtype", [np.float32, np.float64]) @pytest.mark.parametrize("correction_method", ["fast", "original"]) -def test_harmony_integrate_reference(adata_reference, *, dtype, correction_method): +def test_harmony_integrate_reference( + *, + subtests: pytest.Subtests, + adata_reference: AnnData, + dtype: DTypeLike, + correction_method: Literal["fast", "original"], +) -> None: """Test that Harmony produces results similar to the reference implementation.""" harmony_integrate( adata_reference, @@ -99,35 +120,26 @@ def test_harmony_integrate_reference(adata_reference, *, dtype, correction_metho dtype=dtype, max_iter_harmony=20, ) + x, base = adata_reference.obsm["harmony_org"], adata_reference.obsm["X_pca_harmony"] - assert ( - _get_measure( - adata_reference.obsm["harmony_org"], - adata_reference.obsm["X_pca_harmony"], - "L2", - ).max() - < 0.05 - ) - assert ( - _get_measure( - adata_reference.obsm["harmony_org"], - adata_reference.obsm["X_pca_harmony"], - "r", - ).min() - > 0.95 - ) + with subtests.test("r"): + assert _get_measure(x, base, "r").min() > 0.95 + with subtests.test("L2"): + assert _get_measure(x, base, "L2").max() < 0.05 -def test_harmony_multiple_keys(): +def test_harmony_multiple_keys() -> None: """Test Harmony with multiple batch keys.""" adata = sc.datasets.pbmc68k_reduced() # Create a second batch key adata.obs["batch2"] = np.random.choice(["A", "B", "C"], size=adata.n_obs) + harmony_integrate(adata, ["bulk_labels", "batch2"], correction_method="original") + assert adata.obsm["X_pca_harmony"].shape == adata.obsm["X_pca"].shape -def test_harmony_custom_parameters(): +def test_harmony_custom_parameters() -> None: """Test Harmony with custom parameters.""" adata = sc.datasets.pbmc68k_reduced() harmony_integrate( @@ -142,21 +154,18 @@ def test_harmony_custom_parameters(): assert adata.obsm["X_pca_harmony"].shape == adata.obsm["X_pca"].shape -def test_harmony_no_nan_output(): +def test_harmony_no_nan_output() -> None: 
"""Test that Harmony output contains no NaN values.""" adata = sc.datasets.pbmc68k_reduced() harmony_integrate(adata, "bulk_labels") assert not np.isnan(adata.obsm["X_pca_harmony"]).any() -def test_harmony_input_validation(): +def test_harmony_input_validation(subtests) -> None: """Test that Harmony raises errors for invalid inputs.""" adata = sc.datasets.pbmc68k_reduced() - # Test missing basis - with pytest.raises(ValueError, match="not available"): + with subtests.test("no basis"), pytest.raises(ValueError, match="not available"): harmony_integrate(adata, "bulk_labels", basis="nonexistent") - - # Test missing key - with pytest.raises(KeyError): + with subtests.test("no key"), pytest.raises(KeyError): harmony_integrate(adata, "nonexistent_key") From 4627d62c4eb111df65a5e1c46a7a0120ec941050 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Fri, 20 Feb 2026 14:27:26 +0100 Subject: [PATCH 07/14] faster --- tests/test_harmony.py | 83 ++++++++++++++++++++----------------------- 1 file changed, 39 insertions(+), 44 deletions(-) diff --git a/tests/test_harmony.py b/tests/test_harmony.py index b8e2960581..5d90fa146e 100644 --- a/tests/test_harmony.py +++ b/tests/test_harmony.py @@ -9,8 +9,8 @@ from anndata import AnnData from scipy.stats import pearsonr -import scanpy as sc from scanpy.preprocessing import harmony_integrate +from testing.scanpy._helpers.data import pbmc68k_reduced if TYPE_CHECKING: from typing import Literal @@ -18,56 +18,53 @@ from numpy.typing import DTypeLike +DATA = dict( + pca=("pbmc_3500_pcs.tsv.gz", "md5:27e319b3ddcc0c00d98e70aa8e677b10"), + pca_harmonized=( + "pbmc_3500_pcs_harmonized.tsv.gz", + "md5:a7c4ce4b98c390997c66d63d48e09221", + ), + meta=("pbmc_3500_meta.tsv.gz", "md5:8c7ca20e926513da7cf0def1211baecb"), +) + + def _get_measure( x: np.ndarray, base: np.ndarray, norm: Literal["r", "L2"] ) -> np.ndarray: """Compute correlation or L2 distance between arrays.""" - if norm == "r": - # Compute per-column correlation + if norm == "r": # Compute 
per-column correlation if x.ndim == 1: corr, _ = pearsonr(x, base) return corr - corrs = [] - for i in range(x.shape[1]): - corr, _ = pearsonr(x[:, i], base[:, i]) - corrs.append(corr) - return np.array(corrs) - - assert norm == "L2" - # L2 distance normalized by base norm - if x.ndim == 1: - return np.linalg.norm(x - base) / np.linalg.norm(base) - dists = [] - for i in range(x.shape[1]): - dist = np.linalg.norm(x[:, i] - base[:, i]) / np.linalg.norm(base[:, i]) - dists.append(dist) - return np.array(dists) + return np.array([pearsonr(x[:, i], base[:, i])[0] for i in range(x.shape[1])]) + if norm == "L2": + # L2 distance normalized by base norm + if x.ndim == 1: + return np.linalg.norm(x - base) / np.linalg.norm(base) + return np.array([ + np.linalg.norm(x[:, i] - base[:, i]) / np.linalg.norm(base[:, i]) + for i in range(x.shape[1]) + ]) + pytest.fail(f"Unknown {norm=!r}") @pytest.fixture def adata_reference() -> AnnData: """Load reference data from harmonypy repository.""" - x_pca_file = pooch.retrieve( - "https://github.com/slowkow/harmonypy/raw/refs/heads/master/data/pbmc_3500_pcs.tsv.gz", - known_hash="md5:27e319b3ddcc0c00d98e70aa8e677b10", - ) - x_pca = pd.read_csv(x_pca_file, delimiter="\t") - x_pca_harmony_file = pooch.retrieve( - "https://github.com/slowkow/harmonypy/raw/refs/heads/master/data/pbmc_3500_pcs_harmonized.tsv.gz", - known_hash="md5:a7c4ce4b98c390997c66d63d48e09221", - ) - x_pca_harmony = pd.read_csv(x_pca_harmony_file, delimiter="\t") - meta_file = pooch.retrieve( - "https://github.com/slowkow/harmonypy/raw/refs/heads/master/data/pbmc_3500_meta.tsv.gz", - known_hash="md5:8c7ca20e926513da7cf0def1211baecb", - ) - meta = pd.read_csv(meta_file, delimiter="\t") + paths = { + f: pooch.retrieve( + f"https://github.com/slowkow/harmonypy/raw/refs/heads/master/data/{name}", + known_hash=hash_, + ) + for f, (name, hash_) in DATA.items() + } + dfs = {f: pd.read_csv(path, delimiter="\t") for f, path in paths.items()} # Create unique index using row number + 
cell name - meta.index = [f"{i}_{cell}" for i, cell in enumerate(meta["cell"])] + dfs["meta"].index = [f"{i}_{cell}" for i, cell in enumerate(dfs["meta"]["cell"])] adata = AnnData( X=None, - obs=meta, - obsm={"X_pca": x_pca.values, "harmony_org": x_pca_harmony.values}, + obs=dfs["meta"], + obsm={"X_pca": dfs["pca"].values, "harmony_org": dfs["pca_harmonized"].values}, ) return adata @@ -78,19 +75,17 @@ def test_harmony_integrate( correction_method: Literal["fast", "original"], dtype: DTypeLike ) -> None: """Test that Harmony integrate works.""" - adata = sc.datasets.pbmc68k_reduced() - + adata = pbmc68k_reduced() harmony_integrate( adata, "bulk_labels", correction_method=correction_method, dtype=dtype ) - assert adata.obsm["X_pca_harmony"].shape == adata.obsm["X_pca"].shape @pytest.mark.parametrize("dtype", [np.float32, np.float64]) def test_harmony_integrate_algos(subtests: pytest.Subtests, dtype: DTypeLike) -> None: """Test that both correction methods produce similar results.""" - adata = sc.datasets.pbmc68k_reduced() + adata = pbmc68k_reduced() harmony_integrate(adata, "bulk_labels", correction_method="fast", dtype=dtype) fast = adata.obsm["X_pca_harmony"].copy() @@ -130,7 +125,7 @@ def test_harmony_integrate_reference( def test_harmony_multiple_keys() -> None: """Test Harmony with multiple batch keys.""" - adata = sc.datasets.pbmc68k_reduced() + adata = pbmc68k_reduced() # Create a second batch key adata.obs["batch2"] = np.random.choice(["A", "B", "C"], size=adata.n_obs) @@ -141,7 +136,7 @@ def test_harmony_multiple_keys() -> None: def test_harmony_custom_parameters() -> None: """Test Harmony with custom parameters.""" - adata = sc.datasets.pbmc68k_reduced() + adata = pbmc68k_reduced() harmony_integrate( adata, "bulk_labels", @@ -156,14 +151,14 @@ def test_harmony_custom_parameters() -> None: def test_harmony_no_nan_output() -> None: """Test that Harmony output contains no NaN values.""" - adata = sc.datasets.pbmc68k_reduced() + adata = pbmc68k_reduced() 
harmony_integrate(adata, "bulk_labels") assert not np.isnan(adata.obsm["X_pca_harmony"]).any() def test_harmony_input_validation(subtests) -> None: """Test that Harmony raises errors for invalid inputs.""" - adata = sc.datasets.pbmc68k_reduced() + adata = pbmc68k_reduced() with subtests.test("no basis"), pytest.raises(ValueError, match="not available"): harmony_integrate(adata, "bulk_labels", basis="nonexistent") From 3d53e3796310d133c0e8ce7e8e331c4c4cec4c09 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Fri, 20 Feb 2026 15:45:55 +0100 Subject: [PATCH 08/14] refactor --- src/scanpy/preprocessing/_harmony/__init__.py | 295 ++++++++++-------- .../preprocessing/_harmony_integrate.py | 11 +- 2 files changed, 172 insertions(+), 134 deletions(-) diff --git a/src/scanpy/preprocessing/_harmony/__init__.py b/src/scanpy/preprocessing/_harmony/__init__.py index 62f8885a80..ceb6818a76 100644 --- a/src/scanpy/preprocessing/_harmony/__init__.py +++ b/src/scanpy/preprocessing/_harmony/__init__.py @@ -1,5 +1,6 @@ from __future__ import annotations +from dataclasses import KW_ONLY, InitVar, dataclass, field from typing import TYPE_CHECKING import numpy as np @@ -12,31 +13,17 @@ from ..._settings.verbosity import Verbosity if TYPE_CHECKING: + from collections.abc import Sequence from typing import Literal import pandas as pd + from ..._compat import CSBase -def harmonize( # noqa: PLR0913 - x: np.ndarray, - batch_df: pd.DataFrame, - batch_key: str | list[str], - *, - theta: float | list[float] | None, - sigma: float, - n_clusters: int | None, - max_iter_harmony: int, - max_iter_clustering: int, - tol_harmony: float, - tol_clustering: float, - ridge_lambda: float, - correction_method: Literal["fast", "original"], - block_proportion: float, - random_state: int | None, - sparse: bool, -) -> np.ndarray: - """ - Run Harmony batch correction algorithm. + +@dataclass +class Harmony: + """Harmony batch correction algorithm. 
Parameters ---------- @@ -46,110 +33,173 @@ def harmonize( # noqa: PLR0913 DataFrame containing batch information. batch_key Column name(s) in batch_df containing batch labels. - - Returns - ------- - z_corr - Batch-corrected embedding matrix (n_cells x d). """ - if random_state is not None: - np.random.seed(random_state) - # Ensure input is C-contiguous float array (infer dtype from x) - x = np.ascontiguousarray(x) - dtype = x.dtype - n_cells = x.shape[0] - - # Normalize input for clustering - z_norm = _normalize_rows_l2(x) - - # Process batch keys - batch_codes, n_batches = _get_batch_codes(batch_df, batch_key) + batch_df: InitVar[pd.DataFrame] + batch_key: InitVar[str | Sequence[str]] + _: KW_ONLY + theta: float | Sequence[float] | None + sigma: float + n_clusters: int | None + max_iter_harmony: int + max_iter_clustering: int + tol_harmony: float + tol_clustering: float + ridge_lambda: float + correction_method: Literal["fast", "original"] + block_proportion: float + random_state: int | None + sparse: bool + + batch_codes: np.ndarray = field(init=False) + n_batches: int = field(init=False) + + def __post_init__( + self, batch_df: pd.DataFrame, batch_key: str | Sequence[str] + ) -> None: + if self.max_iter_harmony < 1: + msg = "max_iter_harmony must be >= 1" + raise ValueError(msg) + + # Process batch keys + self.batch_codes, self.n_batches = _get_batch_codes(batch_df, batch_key) + + def fit(self, x: np.ndarray) -> np.ndarray: + """Run Harmony. + + Returns + ------- + z_corr + Batch-corrected embedding matrix (n_cells x d). 
+ """ + if self.random_state is not None: + np.random.seed(self.random_state) + + # Ensure input is C-contiguous float array (infer dtype from x) + x = np.ascontiguousarray(x) + n_cells = x.shape[0] + + # Normalize input for clustering + z_norm = _normalize_rows_l2(x) + + # Build phi matrix (one-hot encoding of batches) + if self.sparse: + phi = _one_hot_encode_sparse(self.batch_codes, self.n_batches, x.dtype) + n_b = np.asarray(phi.sum(axis=0)).ravel() + else: + phi = _one_hot_encode(self.batch_codes, self.n_batches, x.dtype) + n_b = phi.sum(axis=0) + pr_b = (n_b / n_cells).reshape(-1, 1) + + # Set default theta + if self.theta is None: + theta_arr = np.ones(self.n_batches, dtype=x.dtype) * 2.0 + elif isinstance(self.theta, (int, float)): + theta_arr = np.ones(self.n_batches, dtype=x.dtype) * float(self.theta) + else: + theta_arr = np.array(self.theta, dtype=x.dtype) + theta_arr = theta_arr.reshape(1, -1) - # Build phi matrix (one-hot encoding of batches) - if sparse: - phi = _one_hot_encode_sparse(batch_codes, n_batches, dtype) - n_b = np.asarray(phi.sum(axis=0)).ravel() - else: - phi = _one_hot_encode(batch_codes, n_batches, dtype) - n_b = phi.sum(axis=0) - pr_b = (n_b / n_cells).reshape(-1, 1) - - # Set default theta - if theta is None: - theta_arr = np.ones(n_batches, dtype=dtype) * 2.0 - elif isinstance(theta, (int, float)): - theta_arr = np.ones(n_batches, dtype=dtype) * float(theta) - else: - theta_arr = np.array(theta, dtype=dtype) - theta_arr = theta_arr.reshape(1, -1) - - # Set default n_clusters - if n_clusters is None: - n_clusters = int(min(100, n_cells / 30)) - n_clusters = max(n_clusters, 2) - - # Initialize centroids and state arrays - r, e, o, objectives_harmony = _initialize_centroids( - z_norm, - phi, - pr_b, - n_clusters=n_clusters, - sigma=sigma, - theta=theta_arr, - random_state=random_state, - ) + # Set default n_clusters + if self.n_clusters is None: + n_clusters = int(min(100, n_cells / 30)) + n_clusters = max(n_clusters, 2) + else: + 
n_clusters = self.n_clusters - # Main Harmony loop - converged = False - z_hat = x.copy() + # Initialize centroids and state arrays + r, e, o, obj_init = _initialize_centroids( + z_norm, + phi, + pr_b, + n_clusters=n_clusters, + sigma=self.sigma, + theta=theta_arr, + random_state=self.random_state, + ) - for i in tqdm(range(max_iter_harmony), disable=settings.verbosity < Verbosity.info): - # Clustering step - _clustering( + # Main Harmony loop + objectives_harmony = [obj_init] + with tqdm( + range(self.max_iter_harmony), disable=settings.verbosity < Verbosity.info + ) as bar: + for i in bar: + r, e, o, obj = self._cluster( + z_norm, pr_b, r=r, e=e, o=o, theta=theta_arr + ) + if obj is not None: + objectives_harmony.append(obj) + z_hat = self._correct(x, r, o) + z_norm = _normalize_rows_l2(z_hat) + if self._is_convergent(objectives_harmony, self.tol_harmony): + log.info(f"Harmony converged in {i + 1} iterations") + break + else: + log.info( + f"Harmony did not converge after {self.max_iter_harmony} iterations." 
+ ) + + return z_hat + + @staticmethod + def _is_convergent(objectives: list[float], tol: float) -> bool: + """Check Harmony convergence.""" + if len(objectives) < 2: + return False + obj_old = objectives[-2] + obj_new = objectives[-1] + return (obj_old - obj_new) < tol * abs(obj_old) + + def _cluster( + self, + z_norm: np.ndarray, + pr_b: np.ndarray, + *, + r: np.ndarray, + e: np.ndarray, + o: np.ndarray, + theta: np.ndarray, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray, float | None]: + """Perform clustering step.""" + return _clustering( z_norm, - batch_codes, - n_batches, + self.batch_codes, + self.n_batches, pr_b, r=r, e=e, o=o, - theta=theta_arr, - sigma=sigma, - max_iter=max_iter_clustering, - tol=tol_clustering, - block_proportion=block_proportion, - objectives_harmony=objectives_harmony, + theta=theta, + sigma=self.sigma, + max_iter=self.max_iter_clustering, + tol=self.tol_clustering, + block_proportion=self.block_proportion, ) - # Correction step - if correction_method == "fast": - z_hat = _correction_fast( - x, batch_codes, n_batches, r, o, ridge_lambda=ridge_lambda + def _correct(self, x: np.ndarray, r: np.ndarray, o: np.ndarray) -> np.ndarray: + """Perform correction step.""" + if self.correction_method == "fast": + return _correction_fast( + x, + self.batch_codes, + self.n_batches, + r, + o, + ridge_lambda=self.ridge_lambda, ) else: - z_hat = _correction_original( - x, batch_codes, n_batches, r, ridge_lambda=ridge_lambda + return _correction_original( + x, + self.batch_codes, + self.n_batches, + r, + ridge_lambda=self.ridge_lambda, ) - # Normalize corrected data for next iteration - z_norm = _normalize_rows_l2(z_hat) - - # Check convergence - if _is_convergent_harmony(objectives_harmony, tol_harmony): - converged = True - log.info(f"Harmony converged in {i + 1} iterations") - break - - if not converged: - log.info(f"Harmony did not converge after {max_iter_harmony} iterations.") - - return z_hat - def _get_batch_codes( batch_df: pd.DataFrame, - 
batch_key: str | list[str], + batch_key: str | Sequence[str], ) -> tuple[np.ndarray, int]: """Get batch codes from DataFrame.""" if isinstance(batch_key, str): @@ -157,11 +207,11 @@ def _get_batch_codes( elif len(batch_key) == 1: batch_vec = batch_df[batch_key[0]] else: - df = batch_df[batch_key].astype("str") + df = batch_df[list(batch_key)].astype("str") batch_vec = df.apply(",".join, axis=1) batch_cat = batch_vec.astype("category") - codes = batch_cat.cat.codes.values.copy() + codes = batch_cat.cat.codes.to_numpy(copy=True) n_batches = len(batch_cat.cat.categories) return codes.astype(np.int32), n_batches @@ -208,14 +258,14 @@ def _normalize_rows_l1(r: np.ndarray) -> None: def _initialize_centroids( z_norm: np.ndarray, - phi: np.ndarray, + phi: np.ndarray | CSBase, pr_b: np.ndarray, *, n_clusters: int, sigma: float, theta: np.ndarray, random_state: int | None, -) -> tuple[np.ndarray, np.ndarray, np.ndarray, list]: +) -> tuple[np.ndarray, np.ndarray, np.ndarray, float]: """Initialize cluster centroids using K-means.""" kmeans = KMeans( n_clusters=n_clusters, random_state=random_state, n_init=10, max_iter=25 @@ -237,11 +287,9 @@ def _initialize_centroids( o = phi.T @ r # Compute initial objective - objectives_harmony: list = [] obj = _compute_objective(y_norm, z_norm, r, theta=theta, sigma=sigma, o=o, e=e) - objectives_harmony.append(obj) - return r, e, o, objectives_harmony + return r, e, o, obj def _compute_r( @@ -268,8 +316,7 @@ def _clustering( # noqa: PLR0913 max_iter: int, tol: float, block_proportion: float, - objectives_harmony: list, -) -> None: +) -> tuple[np.ndarray, np.ndarray, np.ndarray, float | None]: """Run clustering iterations (modifies r, e, o in-place).""" n_cells = z_norm.shape[0] k = r.shape[1] @@ -340,8 +387,12 @@ def _clustering( # noqa: PLR0913 # Check convergence if _is_convergent_clustering(objectives_clustering, tol): - objectives_harmony.append(objectives_clustering[-1]) + obj = objectives_clustering[-1] break + else: + obj = None + + 
return r, e, o, obj def _correction_original( @@ -469,20 +520,6 @@ def _compute_objective( return kmeans_error + entropy + diversity_penalty -def _is_convergent_harmony( - objectives: list, - tol: float, -) -> bool: - """Check Harmony convergence.""" - if len(objectives) < 2: - return False - - obj_old = objectives[-2] - obj_new = objectives[-1] - - return (obj_old - obj_new) < tol * abs(obj_old) - - def _is_convergent_clustering( objectives: list, tol: float, diff --git a/src/scanpy/preprocessing/_harmony_integrate.py b/src/scanpy/preprocessing/_harmony_integrate.py index ca4945195f..1d29c7b272 100644 --- a/src/scanpy/preprocessing/_harmony_integrate.py +++ b/src/scanpy/preprocessing/_harmony_integrate.py @@ -5,6 +5,7 @@ import numpy as np if TYPE_CHECKING: + from collections.abc import Sequence from typing import Literal from anndata import AnnData @@ -13,12 +14,12 @@ def harmony_integrate( # noqa: PLR0913 adata: AnnData, - key: str | list[str], + key: str | Sequence[str], *, basis: str = "X_pca", adjusted_basis: str = "X_pca_harmony", dtype: DTypeLike = np.float64, - theta: float | list[float] | None = None, + theta: float | Sequence[float] | None = None, sigma: float = 0.1, n_clusters: int | None = None, max_iter_harmony: int = 10, @@ -83,7 +84,7 @@ def harmony_integrate( # noqa: PLR0913 Updates adata with the field ``adata.obsm[adjusted_basis]``, containing principal components adjusted by Harmony. 
""" - from ._harmony import harmonize + from ._harmony import Harmony # Ensure the basis exists in adata.obsm if basis not in adata.obsm: @@ -115,8 +116,7 @@ def harmony_integrate( # noqa: PLR0913 raise ValueError(msg) # Run Harmony - harmony_out = harmonize( - x, + harmony = Harmony( adata.obs, key, theta=theta, @@ -132,6 +132,7 @@ def harmony_integrate( # noqa: PLR0913 random_state=random_state, sparse=sparse, ) + harmony_out = harmony.fit(x) # Store result adata.obsm[adjusted_basis] = harmony_out From 96fca7ab77439f28dda4592bf60e561daccea7da Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Fri, 20 Feb 2026 16:08:02 +0100 Subject: [PATCH 09/14] more itertools --- src/scanpy/preprocessing/_harmony/__init__.py | 90 +++++++++---------- 1 file changed, 40 insertions(+), 50 deletions(-) diff --git a/src/scanpy/preprocessing/_harmony/__init__.py b/src/scanpy/preprocessing/_harmony/__init__.py index ceb6818a76..5a5c0c6b3d 100644 --- a/src/scanpy/preprocessing/_harmony/__init__.py +++ b/src/scanpy/preprocessing/_harmony/__init__.py @@ -1,6 +1,7 @@ from __future__ import annotations from dataclasses import KW_ONLY, InitVar, dataclass, field +from itertools import product from typing import TYPE_CHECKING import numpy as np @@ -160,7 +161,7 @@ def _cluster( o: np.ndarray, theta: np.ndarray, ) -> tuple[np.ndarray, np.ndarray, np.ndarray, float | None]: - """Perform clustering step.""" + """Perform clustering step. 
Modifies r, e, o in-place.""" return _clustering( z_norm, self.batch_codes, @@ -320,7 +321,7 @@ def _clustering( # noqa: PLR0913 """Run clustering iterations (modifies r, e, o in-place).""" n_cells = z_norm.shape[0] k = r.shape[1] - block_size = max(1, int(n_cells * block_proportion)) + n_blocks = min(n_cells, 1 // block_proportion) term = -2.0 / sigma objectives_clustering = [] @@ -340,46 +341,41 @@ def _clustering( # noqa: PLR0913 idx_list = np.random.permutation(n_cells) # Process blocks - pos = 0 - while pos < n_cells: - end_pos = min(pos + block_size, n_cells) - block_idx = idx_list[pos:end_pos] - - for b in range(n_batches): - mask = batch_codes[block_idx] == b - if not np.any(mask): - continue - - cell_idx = block_idx[mask] - - # Remove old r contribution from o and e - r_old = r[cell_idx, :] - r_old_sum = r_old.sum(axis=0) - o[b, :] -= r_old_sum - e -= pr_b * r_old_sum - - # Compute new r values - dots = z_norm[cell_idx, :] @ y_norm.T - r_new = np.exp(term * (1.0 - dots)) - - # Apply penalty - penalty = ((e[b, :] + 1.0) / (o[b, :] + 1.0)) ** theta[0, b] - r_new *= penalty - - # Normalize rows to sum to 1 - row_sums = r_new.sum(axis=1, keepdims=True) - row_sums = np.maximum(row_sums, 1e-12) - r_new /= row_sums - - # Store back - r[cell_idx, :] = r_new - - # Add new r contribution to o and e - r_new_sum = r_new.sum(axis=0) - o[b, :] += r_new_sum - e += pr_b * r_new_sum - - pos = end_pos + for block_idx, b in product( + np.array_split(idx_list, n_blocks), range(n_batches) + ): + mask = batch_codes[block_idx] == b + if not np.any(mask): + continue + + cell_idx = block_idx[mask] + + # Remove old r contribution from o and e + r_old = r[cell_idx, :] + r_old_sum = r_old.sum(axis=0) + o[b, :] -= r_old_sum + e -= pr_b * r_old_sum + + # Compute new r values + dots = z_norm[cell_idx, :] @ y_norm.T + r_new = np.exp(term * (1.0 - dots)) + + # Apply penalty + penalty = ((e[b, :] + 1.0) / (o[b, :] + 1.0)) ** theta[0, b] + r_new *= penalty + + # Normalize rows to sum to 1 + 
row_sums = r_new.sum(axis=1, keepdims=True) + row_sums = np.maximum(row_sums, 1e-12) + r_new /= row_sums + + # Store back + r[cell_idx, :] = r_new + + # Add new r contribution to o and e + r_new_sum = r_new.sum(axis=0) + o[b, :] += r_new_sum + e += pr_b * r_new_sum # Compute objective obj = _compute_objective(y_norm, z_norm, r, theta=theta, sigma=sigma, o=o, e=e) @@ -405,7 +401,6 @@ def _correction_original( ) -> np.ndarray: """Original correction method - per-cluster ridge regression.""" _, d = x.shape - k = r.shape[1] # Ridge regularization matrix (don't penalize intercept) id_mat = np.eye(n_batches + 1) @@ -414,9 +409,7 @@ def _correction_original( z = x.copy() - for k_idx in range(k): - r_k = r[:, k_idx] - + for r_k in r.T: r_sum_total = r_k.sum() r_sum_per_batch = np.zeros(n_batches, dtype=x.dtype) for b in range(n_batches): @@ -458,13 +451,11 @@ def _correction_fast( ) -> np.ndarray: """Fast correction method using precomputed factors.""" _, d = x.shape - k = r.shape[1] z = x.copy() p = np.eye(n_batches + 1) - for k_idx in range(k): - o_k = o[:, k_idx] + for o_k, r_k in zip(o.T, r.T, strict=True): n_k = np.sum(o_k) factor = 1.0 / (o_k + ridge_lambda) @@ -480,7 +471,6 @@ def _correction_fast( inv_mat = p_t_b_inv @ p - r_k = r[:, k_idx] phi_t_x = np.zeros((n_batches + 1, d), dtype=x.dtype) phi_t_x[0, :] = r_k @ x for b in range(n_batches): From 4a542a3be70e51cecfd2e6c924b9b1109c48ec35 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Mon, 2 Mar 2026 10:44:34 +0100 Subject: [PATCH 10/14] fix test and remove phi --- src/scanpy/preprocessing/_harmony/__init__.py | 58 +++++-------------- tests/test_harmony.py | 2 +- 2 files changed, 16 insertions(+), 44 deletions(-) diff --git a/src/scanpy/preprocessing/_harmony/__init__.py b/src/scanpy/preprocessing/_harmony/__init__.py index 5a5c0c6b3d..bfc61f80b1 100644 --- a/src/scanpy/preprocessing/_harmony/__init__.py +++ b/src/scanpy/preprocessing/_harmony/__init__.py @@ -5,7 +5,6 @@ from typing import TYPE_CHECKING import numpy 
as np -from scipy.sparse import csr_matrix # noqa: TID251 from sklearn.cluster import KMeans from tqdm.auto import tqdm @@ -19,8 +18,6 @@ import pandas as pd - from ..._compat import CSBase - @dataclass class Harmony: @@ -83,13 +80,8 @@ def fit(self, x: np.ndarray) -> np.ndarray: # Normalize input for clustering z_norm = _normalize_rows_l2(x) - # Build phi matrix (one-hot encoding of batches) - if self.sparse: - phi = _one_hot_encode_sparse(self.batch_codes, self.n_batches, x.dtype) - n_b = np.asarray(phi.sum(axis=0)).ravel() - else: - phi = _one_hot_encode(self.batch_codes, self.n_batches, x.dtype) - n_b = phi.sum(axis=0) + # Compute batch proportions + n_b = np.bincount(self.batch_codes, minlength=self.n_batches).astype(x.dtype) pr_b = (n_b / n_cells).reshape(-1, 1) # Set default theta @@ -111,7 +103,8 @@ def fit(self, x: np.ndarray) -> np.ndarray: # Initialize centroids and state arrays r, e, o, obj_init = _initialize_centroids( z_norm, - phi, + self.batch_codes, + self.n_batches, pr_b, n_clusters=n_clusters, sigma=self.sigma, @@ -218,31 +211,6 @@ def _get_batch_codes( return codes.astype(np.int32), n_batches -def _one_hot_encode( - codes: np.ndarray, - n_categories: int, - dtype: np.dtype, -) -> np.ndarray: - """One-hot encode category codes.""" - n = len(codes) - phi = np.zeros((n, n_categories), dtype=dtype) - phi[np.arange(n), codes] = 1.0 - return phi - - -def _one_hot_encode_sparse( - codes: np.ndarray, - n_categories: int, - dtype: np.dtype, -): - """One-hot encode category codes as sparse CSR matrix.""" - n = len(codes) - data = np.ones(n, dtype=dtype) - indices = codes.astype(np.int32) - indptr = np.arange(n + 1, dtype=np.int32) - return csr_matrix((data, indices, indptr), shape=(n, n_categories)) - - def _normalize_rows_l2(x: np.ndarray) -> np.ndarray: """L2 normalize each row of x.""" norms = np.linalg.norm(x, axis=1, keepdims=True) @@ -259,7 +227,8 @@ def _normalize_rows_l1(r: np.ndarray) -> None: def _initialize_centroids( z_norm: np.ndarray, - phi: 
np.ndarray | CSBase, + batch_codes: np.ndarray, + n_batches: int, pr_b: np.ndarray, *, n_clusters: int, @@ -285,7 +254,9 @@ def _initialize_centroids( # Initialize e (expected) and o (observed) r_sum = r.sum(axis=0) e = pr_b @ r_sum.reshape(1, -1) - o = phi.T @ r + # o[b, k] = sum of r[i, k] for cells i in batch b + o = np.zeros((n_batches, n_clusters), dtype=z_norm.dtype) + np.add.at(o, batch_codes, r) # Compute initial objective obj = _compute_objective(y_norm, z_norm, r, theta=theta, sigma=sigma, o=o, e=e) @@ -403,7 +374,7 @@ def _correction_original( _, d = x.shape # Ridge regularization matrix (don't penalize intercept) - id_mat = np.eye(n_batches + 1) + id_mat = np.eye(n_batches + 1, dtype=x.dtype) id_mat[0, 0] = 0 lambda_mat = ridge_lambda * id_mat @@ -453,18 +424,19 @@ def _correction_fast( _, d = x.shape z = x.copy() - p = np.eye(n_batches + 1) + dtype = x.dtype + p = np.eye(n_batches + 1, dtype=dtype) for o_k, r_k in zip(o.T, r.T, strict=True): n_k = np.sum(o_k) - factor = 1.0 / (o_k + ridge_lambda) + factor = (1.0 / (o_k + ridge_lambda)).astype(dtype) c = n_k + np.sum(-factor * o_k**2) - c_inv = 1.0 / c + c_inv = dtype.type(1.0 / c) p[0, 1:] = -factor * o_k - p_t_b_inv = np.zeros((n_batches + 1, n_batches + 1)) + p_t_b_inv = np.zeros((n_batches + 1, n_batches + 1), dtype=dtype) p_t_b_inv[0, 0] = c_inv p_t_b_inv[1:, 1:] = np.diag(factor) p_t_b_inv[1:, 0] = p[0, 1:] * c_inv diff --git a/tests/test_harmony.py b/tests/test_harmony.py index 5d90fa146e..040550a131 100644 --- a/tests/test_harmony.py +++ b/tests/test_harmony.py @@ -120,7 +120,7 @@ def test_harmony_integrate_reference( with subtests.test("r"): assert _get_measure(x, base, "r").min() > 0.95 with subtests.test("L2"): - assert _get_measure(x, base, "L2").max() < 0.05 + assert _get_measure(x, base, "L2").max() < 0.1 def test_harmony_multiple_keys() -> None: From 88dca0955674c528c7605c35b2344a70b77cd728 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Mon, 2 Mar 2026 11:06:12 +0100 Subject: [PATCH 11/14] 
add tau remove sparse --- src/scanpy/preprocessing/_harmony/__init__.py | 8 ++++++-- src/scanpy/preprocessing/_harmony_integrate.py | 10 ++++++---- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/scanpy/preprocessing/_harmony/__init__.py b/src/scanpy/preprocessing/_harmony/__init__.py index bfc61f80b1..78d8829fe2 100644 --- a/src/scanpy/preprocessing/_harmony/__init__.py +++ b/src/scanpy/preprocessing/_harmony/__init__.py @@ -46,8 +46,8 @@ class Harmony: ridge_lambda: float correction_method: Literal["fast", "original"] block_proportion: float + tau: int random_state: int | None - sparse: bool batch_codes: np.ndarray = field(init=False) n_batches: int = field(init=False) @@ -93,13 +93,17 @@ def fit(self, x: np.ndarray) -> np.ndarray: theta_arr = np.array(self.theta, dtype=x.dtype) theta_arr = theta_arr.reshape(1, -1) - # Set default n_clusters + # Set default n_clusters (needed before tau discounting) if self.n_clusters is None: n_clusters = int(min(100, n_cells / 30)) n_clusters = max(n_clusters, 2) else: n_clusters = self.n_clusters + # Apply tau discounting to theta + if self.tau > 0: + theta_arr = theta_arr * (1 - np.exp(-n_b / (n_clusters * self.tau)) ** 2) + # Initialize centroids and state arrays r, e, o, obj_init = _initialize_centroids( z_norm, diff --git a/src/scanpy/preprocessing/_harmony_integrate.py b/src/scanpy/preprocessing/_harmony_integrate.py index 1d29c7b272..a312335ab6 100644 --- a/src/scanpy/preprocessing/_harmony_integrate.py +++ b/src/scanpy/preprocessing/_harmony_integrate.py @@ -29,8 +29,8 @@ def harmony_integrate( # noqa: PLR0913 ridge_lambda: float = 1.0, correction_method: Literal["fast", "original"] = "original", block_proportion: float = 0.05, + tau: int = 0, random_state: int | None = 0, - sparse: bool = False, ) -> None: """ Integrate different experiments using the Harmony algorithm. @@ -74,10 +74,12 @@ def harmony_integrate( # noqa: PLR0913 original method, ``fast`` for improved method. 
block_proportion Fraction of cells processed per clustering iteration. Default 0.05. + tau + Discounting factor on ``theta``. Larger batches are penalized less + when ``tau > 0``: ``theta *= 1 - exp(-N_b / (n_clusters * tau))^2``. + By default (0), no discounting is applied. random_state Random seed for reproducibility. - sparse - Use sparse matrices for batch encoding. Reduces memory for large datasets. Returns ------- @@ -129,8 +131,8 @@ def harmony_integrate( # noqa: PLR0913 ridge_lambda=ridge_lambda, correction_method=correction_method, block_proportion=block_proportion, + tau=tau, random_state=random_state, - sparse=sparse, ) harmony_out = harmony.fit(x) From fa8e3cd276a1b8ca8188db05c8b7227c6e5f3857 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Wed, 8 Apr 2026 15:41:09 +0200 Subject: [PATCH 12/14] add harmony 2 and tests --- docs/release-notes/3953.feat.md | 1 + src/scanpy/preprocessing/_harmony/__init__.py | 143 ++++++++++-- .../preprocessing/_harmony_integrate.py | 144 +++++++++--- tests/test_harmony.py | 218 +++++++++++++++++- 4 files changed, 442 insertions(+), 64 deletions(-) create mode 100644 docs/release-notes/3953.feat.md diff --git a/docs/release-notes/3953.feat.md b/docs/release-notes/3953.feat.md new file mode 100644 index 0000000000..9af0ca3b7b --- /dev/null +++ b/docs/release-notes/3953.feat.md @@ -0,0 +1 @@ +Add {func}`scanpy.pp.harmony_integrate` with Harmony1 and Harmony2 support for batch correction {smaller}`S Dicks, P Angerer` diff --git a/src/scanpy/preprocessing/_harmony/__init__.py b/src/scanpy/preprocessing/_harmony/__init__.py index 78d8829fe2..a24ebdd0b8 100644 --- a/src/scanpy/preprocessing/_harmony/__init__.py +++ b/src/scanpy/preprocessing/_harmony/__init__.py @@ -18,6 +18,8 @@ import pandas as pd +_SUPPRESS_PENALTY = 1e30 + @dataclass class Harmony: @@ -36,7 +38,7 @@ class Harmony: batch_df: InitVar[pd.DataFrame] batch_key: InitVar[str | Sequence[str]] _: KW_ONLY - theta: float | Sequence[float] | None + theta: float | 
Sequence[float] sigma: float n_clusters: int | None max_iter_harmony: int @@ -48,6 +50,10 @@ class Harmony: block_proportion: float tau: int random_state: int | None + stabilized_penalty: bool = True + dynamic_lambda: bool = True + alpha: float = 0.2 + batch_prune_threshold: float | None = 1e-5 batch_codes: np.ndarray = field(init=False) n_batches: int = field(init=False) @@ -59,6 +65,16 @@ def __post_init__( msg = "max_iter_harmony must be >= 1" raise ValueError(msg) + if self.dynamic_lambda: + if not np.isfinite(self.alpha) or self.alpha <= 0: + msg = f"alpha must be a finite positive number when dynamic_lambda=True, got {self.alpha}." + raise ValueError(msg) + if self.batch_prune_threshold is not None and not ( + 0 <= self.batch_prune_threshold <= 1 + ): + msg = f"batch_prune_threshold must be in [0, 1] or None, got {self.batch_prune_threshold}." + raise ValueError(msg) + # Process batch keys self.batch_codes, self.n_batches = _get_batch_codes(batch_df, batch_key) @@ -85,9 +101,7 @@ def fit(self, x: np.ndarray) -> np.ndarray: pr_b = (n_b / n_cells).reshape(-1, 1) # Set default theta - if self.theta is None: - theta_arr = np.ones(self.n_batches, dtype=x.dtype) * 2.0 - elif isinstance(self.theta, (int, float)): + if isinstance(self.theta, (int, float)): theta_arr = np.ones(self.n_batches, dtype=x.dtype) * float(self.theta) else: theta_arr = np.array(self.theta, dtype=x.dtype) @@ -114,6 +128,7 @@ def fit(self, x: np.ndarray) -> np.ndarray: sigma=self.sigma, theta=theta_arr, random_state=self.random_state, + stabilized_penalty=self.stabilized_penalty, ) # Main Harmony loop @@ -127,7 +142,19 @@ def fit(self, x: np.ndarray) -> np.ndarray: ) if obj is not None: objectives_harmony.append(obj) - z_hat = self._correct(x, r, o) + + # Compute per-(k,b) ridge regularization + lambda_kb = _compute_lambda_kb( + e, + o=o, + n_b=n_b, + alpha=self.alpha, + threshold=self.batch_prune_threshold, + ridge_lambda=self.ridge_lambda, + dynamic_lambda=self.dynamic_lambda, + ) + + z_hat = 
self._correct(x, r, o, lambda_kb=lambda_kb) z_norm = _normalize_rows_l2(z_hat) if self._is_convergent(objectives_harmony, self.tol_harmony): log.info(f"Harmony converged in {i + 1} iterations") @@ -172,9 +199,17 @@ def _cluster( max_iter=self.max_iter_clustering, tol=self.tol_clustering, block_proportion=self.block_proportion, + stabilized_penalty=self.stabilized_penalty, ) - def _correct(self, x: np.ndarray, r: np.ndarray, o: np.ndarray) -> np.ndarray: + def _correct( + self, + x: np.ndarray, + r: np.ndarray, + o: np.ndarray, + *, + lambda_kb: np.ndarray, + ) -> np.ndarray: """Perform correction step.""" if self.correction_method == "fast": return _correction_fast( @@ -183,7 +218,7 @@ def _correct(self, x: np.ndarray, r: np.ndarray, o: np.ndarray) -> np.ndarray: self.n_batches, r, o, - ridge_lambda=self.ridge_lambda, + lambda_kb=lambda_kb, ) else: return _correction_original( @@ -191,7 +226,7 @@ def _correct(self, x: np.ndarray, r: np.ndarray, o: np.ndarray) -> np.ndarray: self.batch_codes, self.n_batches, r, - ridge_lambda=self.ridge_lambda, + lambda_kb=lambda_kb, ) @@ -239,6 +274,7 @@ def _initialize_centroids( sigma: float, theta: np.ndarray, random_state: int | None, + stabilized_penalty: bool = True, ) -> tuple[np.ndarray, np.ndarray, np.ndarray, float]: """Initialize cluster centroids using K-means.""" kmeans = KMeans( @@ -263,7 +299,16 @@ def _initialize_centroids( np.add.at(o, batch_codes, r) # Compute initial objective - obj = _compute_objective(y_norm, z_norm, r, theta=theta, sigma=sigma, o=o, e=e) + obj = _compute_objective( + y_norm, + z_norm, + r, + theta=theta, + sigma=sigma, + o=o, + e=e, + stabilized_penalty=stabilized_penalty, + ) return r, e, o, obj @@ -292,6 +337,7 @@ def _clustering( # noqa: PLR0913 max_iter: int, tol: float, block_proportion: float, + stabilized_penalty: bool = True, ) -> tuple[np.ndarray, np.ndarray, np.ndarray, float | None]: """Run clustering iterations (modifies r, e, o in-place).""" n_cells = z_norm.shape[0] @@ -335,8 
+381,13 @@ def _clustering( # noqa: PLR0913 dots = z_norm[cell_idx, :] @ y_norm.T r_new = np.exp(term * (1.0 - dots)) - # Apply penalty - penalty = ((e[b, :] + 1.0) / (o[b, :] + 1.0)) ** theta[0, b] + # Apply penalty (Harmony1 vs Harmony2) + if stabilized_penalty: + # Harmony2: denominator is (O + E + 1) + penalty = ((e[b, :] + 1.0) / (o[b, :] + e[b, :] + 1.0)) ** theta[0, b] + else: + # Harmony1: denominator is (O + 1) + penalty = ((e[b, :] + 1.0) / (o[b, :] + 1.0)) ** theta[0, b] r_new *= penalty # Normalize rows to sum to 1 @@ -353,7 +404,16 @@ def _clustering( # noqa: PLR0913 e += pr_b * r_new_sum # Compute objective - obj = _compute_objective(y_norm, z_norm, r, theta=theta, sigma=sigma, o=o, e=e) + obj = _compute_objective( + y_norm, + z_norm, + r, + theta=theta, + sigma=sigma, + o=o, + e=e, + stabilized_penalty=stabilized_penalty, + ) objectives_clustering.append(obj) # Check convergence @@ -366,25 +426,52 @@ def _clustering( # noqa: PLR0913 return r, e, o, obj +def _compute_lambda_kb( + e: np.ndarray, + *, + o: np.ndarray, + n_b: np.ndarray, + alpha: float, + threshold: float | None, + ridge_lambda: float, + dynamic_lambda: bool, +) -> np.ndarray: + """Compute per-(k,b) ridge regularization array.""" + sentinel = e.dtype.type(_SUPPRESS_PENALTY) + if not dynamic_lambda: + lambda_kb = np.full_like(e, ridge_lambda) + else: + lambda_kb = (alpha * e).astype(e.dtype) + if threshold is not None: + safe_n_b = np.where(n_b > 0, n_b, np.ones_like(n_b)) + prune_mask = (o / safe_n_b[:, None]) < threshold + prune_mask |= n_b[:, None] == 0 + lambda_kb[prune_mask] = sentinel + # Where both O and lambda_kb are zero, the kernel computes 1/(O+lambda) + # which would divide by zero. 
+ lambda_kb[(o + lambda_kb) == 0] = sentinel + return lambda_kb + + def _correction_original( x: np.ndarray, batch_codes: np.ndarray, n_batches: int, r: np.ndarray, *, - ridge_lambda: float, + lambda_kb: np.ndarray, ) -> np.ndarray: """Original correction method - per-cluster ridge regression.""" _, d = x.shape - # Ridge regularization matrix (don't penalize intercept) - id_mat = np.eye(n_batches + 1, dtype=x.dtype) - id_mat[0, 0] = 0 - lambda_mat = ridge_lambda * id_mat - z = x.copy() - for r_k in r.T: + for k_idx, r_k in enumerate(r.T): + # Build per-cluster lambda diagonal + lambda_diag = np.zeros(n_batches + 1, dtype=x.dtype) + lambda_diag[1:] = lambda_kb[:, k_idx] + lambda_mat = np.diag(lambda_diag) + r_sum_total = r_k.sum() r_sum_per_batch = np.zeros(n_batches, dtype=x.dtype) for b in range(n_batches): @@ -422,7 +509,7 @@ def _correction_fast( r: np.ndarray, o: np.ndarray, *, - ridge_lambda: float, + lambda_kb: np.ndarray, ) -> np.ndarray: """Fast correction method using precomputed factors.""" _, d = x.shape @@ -431,11 +518,11 @@ def _correction_fast( dtype = x.dtype p = np.eye(n_batches + 1, dtype=dtype) - for o_k, r_k in zip(o.T, r.T, strict=True): - n_k = np.sum(o_k) + for k_idx, (o_k, r_k) in enumerate(zip(o.T, r.T, strict=True)): + lam_k = lambda_kb[:, k_idx] - factor = (1.0 / (o_k + ridge_lambda)).astype(dtype) - c = n_k + np.sum(-factor * o_k**2) + factor = (1.0 / (o_k + lam_k)).astype(dtype) + c = np.sum(o_k) + np.sum(-factor * o_k**2) c_inv = dtype.type(1.0 / c) p[0, 1:] = -factor * o_k @@ -471,6 +558,7 @@ def _compute_objective( sigma: float, o: np.ndarray, e: np.ndarray, + stabilized_penalty: bool = True, ) -> float: """Compute Harmony objective function.""" zy = z_norm @ y_norm.T @@ -480,7 +568,12 @@ def _compute_objective( r_normalized = r / np.clip(r_row_sums, 1e-12, None) entropy = sigma * np.sum(r_normalized * np.log(r_normalized + 1e-12)) - log_ratio = np.log((o + 1) / (e + 1)) + if stabilized_penalty: + # Harmony2: numerator is (O + E + 1) 
+ log_ratio = np.log((o + e + 1) / (e + 1)) + else: + # Harmony1: numerator is (O + 1) + log_ratio = np.log((o + 1) / (e + 1)) diversity_penalty = sigma * np.sum(theta @ (o * log_ratio)) return kmeans_error + entropy + diversity_penalty diff --git a/src/scanpy/preprocessing/_harmony_integrate.py b/src/scanpy/preprocessing/_harmony_integrate.py index a312335ab6..474ee47909 100644 --- a/src/scanpy/preprocessing/_harmony_integrate.py +++ b/src/scanpy/preprocessing/_harmony_integrate.py @@ -4,6 +4,8 @@ import numpy as np +from .._compat import warn + if TYPE_CHECKING: from collections.abc import Sequence from typing import Literal @@ -19,24 +21,34 @@ def harmony_integrate( # noqa: PLR0913 basis: str = "X_pca", adjusted_basis: str = "X_pca_harmony", dtype: DTypeLike = np.float64, - theta: float | Sequence[float] | None = None, - sigma: float = 0.1, + flavor: Literal["harmony2", "harmony1"] = "harmony2", n_clusters: int | None = None, max_iter_harmony: int = 10, max_iter_clustering: int = 200, tol_harmony: float = 1e-4, tol_clustering: float = 1e-5, + sigma: float = 0.1, + theta: float | Sequence[float] = 2.0, + tau: int = 0, ridge_lambda: float = 1.0, + alpha: float = 0.2, + batch_prune_threshold: float | None = 1e-5, correction_method: Literal["fast", "original"] = "original", block_proportion: float = 0.05, - tau: int = 0, random_state: int | None = 0, ) -> None: - """ - Integrate different experiments using the Harmony algorithm. + """Integrate different experiments using the Harmony algorithm :cite:p:`Korsunsky2019,Patikas2026`. + + This CPU implementation is based on the harmony-pytorch & rapids_singlecell + version, using NumPy for efficient computation. As Harmony works by adjusting + the principal components, this function should be run after performing PCA but + before computing the neighbor graph. - This CPU implementation is based on the harmony-pytorch & rapids_singlecell version, - using NumPy for efficient computation. 
+ By default, the Harmony2 algorithm is used, which includes a stabilized + diversity penalty, dynamic per-cluster-per-batch ridge regularization, + and automatic batch pruning. To revert to the original Harmony behavior:: + + sc.pp.harmony_integrate(adata, key, flavor="harmony1") Parameters ---------- @@ -44,7 +56,8 @@ def harmony_integrate( # noqa: PLR0913 The annotated data matrix. key The key(s) of the column(s) in ``adata.obs`` that differentiates - among experiments/batches. + among experiments/batches. When multiple keys are provided, a + combined batch variable is created from all columns. basis The name of the field in ``adata.obsm`` where the PCA table is stored. Defaults to ``'X_pca'``. @@ -52,42 +65,113 @@ def harmony_integrate( # noqa: PLR0913 The name of the field in ``adata.obsm`` where the adjusted PCA table will be stored. Defaults to ``X_pca_harmony``. dtype - The data type to use for Harmony computation. - theta - Diversity penalty weight(s). Default is 2 for each batch variable. - sigma - Width of soft clustering kernel. Default 0.1. + The data type to use for Harmony computation. If you use 32-bit + you may experience numerical instability. + flavor + Which version of the Harmony algorithm to use. + ``"harmony2"`` (default) enables the stabilized diversity penalty, + dynamic per-cluster-per-batch ridge regularization, and automatic + batch pruning from :cite:p:`Patikas2026`. + ``"harmony1"`` uses the original algorithm from + :cite:p:`Korsunsky2019`. n_clusters - Number of clusters. Default is min(100, n_cells/30). + Number of clusters used for soft k-means in the Harmony algorithm. + If ``None``, uses ``min(100, N / 30)``. More clusters capture + finer-grained structure but increase computation time. max_iter_harmony - Maximum Harmony iterations. Default 10. + Maximum number of outer Harmony iterations (each consisting of + a clustering step followed by a correction step). max_iter_clustering - Maximum clustering iterations per Harmony round. 
Default 200. + Maximum iterations for the clustering step within each Harmony + iteration. tol_harmony - Convergence tolerance for Harmony. Default 1e-4. + Convergence tolerance for the Harmony objective function. + The algorithm stops when the relative change in objective falls + below this value. tol_clustering - Convergence tolerance for clustering. Default 1e-5. + Convergence tolerance for the clustering step within each + Harmony iteration. + sigma + Width of the soft-clustering kernel. Controls the entropy of + cluster assignments: smaller values produce harder assignments + (cells assigned to fewer clusters), while larger values produce + softer assignments (cells spread across more clusters). + theta + Diversity penalty weight per batch variable. Controls how + strongly Harmony encourages each cluster to contain a balanced + representation of all batches. Higher values (e.g. ``4``) + produce more aggressive mixing; lower values (e.g. ``0.5``) + allow more batch-specific clusters. Set to ``0`` to disable + batch correction entirely. A list can be provided to set + different weights per batch variable. + tau + Discounting factor on ``theta``. When ``tau > 0``, the + diversity penalty is down-weighted for batches with fewer cells, + preventing over-correction of small batches. By default (``0``), + there is no discounting. ridge_lambda - Ridge regression regularization. Default 1.0. + Ridge regression regularization for the correction step. + Larger values produce more conservative (smaller) corrections, + preventing over-fitting. Only used with ``flavor="harmony1"``. + alpha + Scaling factor for the dynamic per-cluster-per-batch ridge + regularization. The effective regularization for each + cluster-batch pair is ``alpha * E_kb`` where ``E_kb`` is the + expected number of cells. Larger values produce more + conservative corrections. Only used with ``flavor="harmony2"``. 
+ batch_prune_threshold + Fraction threshold below which a batch-cluster pair is pruned + (correction suppressed). When the fraction of a batch's cells + assigned to a cluster (``O_kb / N_b``) falls below this + threshold, that batch-cluster pair receives no correction, + preventing spurious adjustments. Only used with + ``flavor="harmony2"``. Set to ``None`` to disable pruning. correction_method - Choose which method for the correction step: ``original`` for - original method, ``fast`` for improved method. + Method for the correction step. ``"original"`` uses per-cluster + ridge regression with explicit matrix inversion. ``"fast"`` + uses a precomputed factorization that avoids the full inversion, + which can be faster for datasets with many batches. block_proportion - Fraction of cells processed per clustering iteration. Default 0.05. - tau - Discounting factor on ``theta``. Larger batches are penalized less - when ``tau > 0``: ``theta *= 1 - exp(-N_b / (n_clusters * tau))^2``. - By default (0), no discounting is applied. + Proportion of cells updated per clustering sub-iteration. + Smaller values produce more stochastic updates. Larger values + are faster but may converge to different solutions. random_state Random seed for reproducibility. Returns ------- - Updates adata with the field ``adata.obsm[adjusted_basis]``, - containing principal components adjusted by Harmony. + Updates adata with the field ``adata.obsm[adjusted_basis]``, \ + containing principal components adjusted by Harmony such that \ + different experiments are integrated. """ from ._harmony import Harmony + # Resolve flavor into internal flags + if flavor not in {"harmony1", "harmony2"}: + msg = f"flavor must be 'harmony1' or 'harmony2', got {flavor!r}." 
+ raise ValueError(msg) + stabilized_penalty = flavor == "harmony2" + dynamic_lambda = flavor == "harmony2" + + # Warn when flavor-incompatible parameters are explicitly set + if flavor == "harmony2" and ridge_lambda != 1.0: + warn( + "ridge_lambda is ignored when flavor='harmony2'; " + "use alpha to control regularization strength.", + UserWarning, + ) + if flavor == "harmony1": + if alpha != 0.2: + warn( + "alpha is ignored when flavor='harmony1'; use ridge_lambda instead.", + UserWarning, + ) + if batch_prune_threshold != 1e-5: + warn( + "batch_prune_threshold is ignored when flavor='harmony1'.", + UserWarning, + ) + # Ensure the basis exists in adata.obsm if basis not in adata.obsm: msg = ( @@ -133,6 +217,10 @@ def harmony_integrate( # noqa: PLR0913 block_proportion=block_proportion, tau=tau, random_state=random_state, + stabilized_penalty=stabilized_penalty, + dynamic_lambda=dynamic_lambda, + alpha=alpha, + batch_prune_threshold=batch_prune_threshold, ) harmony_out = harmony.fit(x) diff --git a/tests/test_harmony.py b/tests/test_harmony.py index 040550a131..857cc3f155 100644 --- a/tests/test_harmony.py +++ b/tests/test_harmony.py @@ -10,6 +10,7 @@ from scipy.stats import pearsonr from scanpy.preprocessing import harmony_integrate +from scanpy.preprocessing._harmony import _SUPPRESS_PENALTY, _compute_lambda_kb from testing.scanpy._helpers.data import pbmc68k_reduced if TYPE_CHECKING: @@ -18,6 +19,8 @@ from numpy.typing import DTypeLike +_HARMONY_DATA_BASE = "https://exampledata.scverse.org/rapids-singlecell/harmony_data" + DATA = dict( pca=("pbmc_3500_pcs.tsv.gz", "md5:27e319b3ddcc0c00d98e70aa8e677b10"), pca_harmonized=( @@ -48,12 +51,12 @@ def _get_measure( pytest.fail(f"Unknown {norm=!r}") -@pytest.fixture -def adata_reference() -> AnnData: - """Load reference data from harmonypy repository.""" +@pytest.fixture(scope="module") +def _adata_reference() -> AnnData: + """Load reference data once per module (avoids re-reading CSV).""" paths = { f: pooch.retrieve( 
- f"https://github.com/slowkow/harmonypy/raw/refs/heads/master/data/{name}", + f"{_HARMONY_DATA_BASE}/{name}", known_hash=hash_, ) for f, (name, hash_) in DATA.items() @@ -61,23 +64,35 @@ def adata_reference() -> AnnData: dfs = {f: pd.read_csv(path, delimiter="\t") for f, path in paths.items()} # Create unique index using row number + cell name dfs["meta"].index = [f"{i}_{cell}" for i, cell in enumerate(dfs["meta"]["cell"])] - adata = AnnData( + return AnnData( X=None, obs=dfs["meta"], obsm={"X_pca": dfs["pca"].values, "harmony_org": dfs["pca_harmonized"].values}, ) - return adata + + +@pytest.fixture +def adata_reference(_adata_reference: AnnData) -> AnnData: + """Return a fresh copy per test so tests don't mutate shared state.""" + return _adata_reference.copy() @pytest.mark.parametrize("correction_method", ["fast", "original"]) @pytest.mark.parametrize("dtype", [np.float32, np.float64]) +@pytest.mark.parametrize("flavor", ["harmony1", "harmony2"]) def test_harmony_integrate( - correction_method: Literal["fast", "original"], dtype: DTypeLike + correction_method: Literal["fast", "original"], + dtype: DTypeLike, + flavor: Literal["harmony1", "harmony2"], ) -> None: """Test that Harmony integrate works.""" adata = pbmc68k_reduced() harmony_integrate( - adata, "bulk_labels", correction_method=correction_method, dtype=dtype + adata, + "bulk_labels", + correction_method=correction_method, + dtype=dtype, + flavor=flavor, ) assert adata.obsm["X_pca_harmony"].shape == adata.obsm["X_pca"].shape @@ -87,9 +102,17 @@ def test_harmony_integrate_algos(subtests: pytest.Subtests, dtype: DTypeLike) -> """Test that both correction methods produce similar results.""" adata = pbmc68k_reduced() - harmony_integrate(adata, "bulk_labels", correction_method="fast", dtype=dtype) + harmony_integrate( + adata, "bulk_labels", correction_method="fast", dtype=dtype, flavor="harmony1" + ) fast = adata.obsm["X_pca_harmony"].copy() - harmony_integrate(adata, "bulk_labels", 
correction_method="original", dtype=dtype) + harmony_integrate( + adata, + "bulk_labels", + correction_method="original", + dtype=dtype, + flavor="harmony1", + ) slow = adata.obsm["X_pca_harmony"].copy() with subtests.test("r"): @@ -107,13 +130,14 @@ def test_harmony_integrate_reference( dtype: DTypeLike, correction_method: Literal["fast", "original"], ) -> None: - """Test that Harmony produces results similar to the reference implementation.""" + """Test that Harmony1 produces results similar to the reference implementation.""" harmony_integrate( adata_reference, "donor", correction_method=correction_method, dtype=dtype, max_iter_harmony=20, + flavor="harmony1", ) x, base = adata_reference.obsm["harmony_org"], adata_reference.obsm["X_pca_harmony"] @@ -123,6 +147,42 @@ def test_harmony_integrate_reference( assert _get_measure(x, base, "L2").max() < 0.1 +@pytest.mark.parametrize("correction_method", ["fast", "original"]) +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_harmony2_correction_methods_agree( + subtests: pytest.Subtests, + adata_reference: AnnData, + correction_method: Literal["fast", "original"], + dtype: DTypeLike, +) -> None: + """Harmony2 default path: correction methods produce consistent results.""" + harmony_integrate( + adata_reference, + "donor", + correction_method=correction_method, + dtype=dtype, + max_iter_harmony=20, + ) + h2 = adata_reference.obsm["X_pca_harmony"] + + # Run the other method for comparison + other = "original" if correction_method == "fast" else "fast" + adata_ref2 = adata_reference.copy() + harmony_integrate( + adata_ref2, + "donor", + correction_method=other, + dtype=dtype, + max_iter_harmony=20, + ) + h2_ref = adata_ref2.obsm["X_pca_harmony"] + + with subtests.test("r"): + assert _get_measure(h2, h2_ref, "r").min() > 0.99 + with subtests.test("L2"): + assert _get_measure(h2, h2_ref, "L2").max() < 0.05 + + def test_harmony_multiple_keys() -> None: """Test Harmony with multiple batch keys.""" adata = 
pbmc68k_reduced() @@ -145,6 +205,7 @@ def test_harmony_custom_parameters() -> None: n_clusters=50, max_iter_harmony=5, ridge_lambda=0.5, + flavor="harmony1", ) assert adata.obsm["X_pca_harmony"].shape == adata.obsm["X_pca"].shape @@ -164,3 +225,138 @@ def test_harmony_input_validation(subtests) -> None: harmony_integrate(adata, "bulk_labels", basis="nonexistent") with subtests.test("no key"), pytest.raises(KeyError): harmony_integrate(adata, "nonexistent_key") + + +def test_harmony_invalid_flavor() -> None: + """Test that invalid flavor raises ValueError.""" + adata = pbmc68k_reduced() + with pytest.raises(ValueError, match="flavor must be"): + harmony_integrate(adata, "bulk_labels", flavor="harmony3") + + +@pytest.mark.parametrize("bad_alpha", [-0.1, 0.0, float("inf"), float("nan")]) +def test_harmony_integrate_bad_alpha(bad_alpha: float) -> None: + """Non-positive or non-finite alpha with flavor='harmony2' raises ValueError.""" + adata = pbmc68k_reduced() + with pytest.raises(ValueError, match="alpha must be a finite positive"): + harmony_integrate(adata, "bulk_labels", alpha=bad_alpha) + + +@pytest.mark.parametrize("bad_threshold", [-0.1, 1.5, 2.0]) +def test_harmony_integrate_bad_prune_threshold(bad_threshold: float) -> None: + """batch_prune_threshold outside [0, 1] raises ValueError.""" + adata = pbmc68k_reduced() + with pytest.raises(ValueError, match="batch_prune_threshold must be in"): + harmony_integrate(adata, "bulk_labels", batch_prune_threshold=bad_threshold) + + +def test_harmony_flavor_warnings() -> None: + """Test that flavor-incompatible parameter warnings are raised.""" + adata = pbmc68k_reduced() + + # harmony2 with ridge_lambda should warn + with pytest.warns(UserWarning, match="ridge_lambda is ignored"): + harmony_integrate( + adata, + "bulk_labels", + flavor="harmony2", + ridge_lambda=2.0, + max_iter_harmony=1, + ) + + # harmony1 with alpha should warn + with pytest.warns(UserWarning, match="alpha is ignored"): + harmony_integrate( + adata, 
"bulk_labels", flavor="harmony1", alpha=0.5, max_iter_harmony=1 + ) + + # harmony1 with batch_prune_threshold should warn + with pytest.warns(UserWarning, match="batch_prune_threshold is ignored"): + harmony_integrate( + adata, + "bulk_labels", + flavor="harmony1", + batch_prune_threshold=0.01, + max_iter_harmony=1, + ) + + +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_compute_lambda_kb_pruning(dtype: DTypeLike) -> None: + """_compute_lambda_kb suppresses correction for N_b==0 and below-threshold pairs.""" + n_batches, n_clusters = 4, 3 + alpha = 0.2 + threshold = 1e-5 + sentinel = dtype(_SUPPRESS_PENALTY) + + n_b = np.array([0, 100, 1, 50], dtype=dtype) + o = np.array( + [[0, 0, 0], [30, 40, 30], [0, 0, 1], [20, 15, 15]], + dtype=dtype, + ) + e = np.ones((n_batches, n_clusters), dtype=dtype) * 10 + + result = _compute_lambda_kb( + e, + o=o, + n_b=n_b, + alpha=alpha, + threshold=threshold, + ridge_lambda=1.0, + dynamic_lambda=True, + ) + + # batch 0 (N_b==0): all clusters must be sentinel + assert np.all(result[0] == sentinel) + # batch 1 (well-represented): should be alpha * E = 2.0 + np.testing.assert_allclose(result[1], np.full(n_clusters, alpha * 10, dtype=dtype)) + # batch 2, clusters 0,1 (O/N_b = 0/1 < threshold): sentinel + assert result[2, 0] == sentinel + assert result[2, 1] == sentinel + # batch 2, cluster 2 (O/N_b = 1/1 = 1.0 >= threshold): alpha * E + np.testing.assert_allclose(result[2, 2], dtype(alpha * 10)) + # batch 3: all alpha * E + np.testing.assert_allclose(result[3], np.full(n_clusters, alpha * 10, dtype=dtype)) + + +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_compute_lambda_kb_dynamic_false(dtype: DTypeLike) -> None: + """_compute_lambda_kb returns uniform ridge_lambda when dynamic_lambda=False.""" + n_batches, n_clusters = 3, 5 + e = np.ones((n_batches, n_clusters), dtype=dtype) + o = np.ones((n_batches, n_clusters), dtype=dtype) + n_b = np.ones(n_batches, dtype=dtype) + + result = 
_compute_lambda_kb( + e, + o=o, + n_b=n_b, + alpha=0.5, + threshold=1e-5, + ridge_lambda=1.0, + dynamic_lambda=False, + ) + np.testing.assert_array_equal(result, np.full_like(e, 1.0)) + + +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_compute_lambda_kb_zero_denom(dtype: DTypeLike) -> None: + """_compute_lambda_kb guards against O==0 and E==0 (zero-denominator).""" + sentinel = dtype(_SUPPRESS_PENALTY) + e = np.array([[0.0, 5.0]], dtype=dtype) + o = np.array([[0.0, 10.0]], dtype=dtype) + n_b = np.array([100.0], dtype=dtype) + + result = _compute_lambda_kb( + e, + o=o, + n_b=n_b, + alpha=0.2, + threshold=None, + ridge_lambda=1.0, + dynamic_lambda=True, + ) + # (0,0): O+lambda_kb = 0+0 = 0 -> sentinel + assert result[0, 0] == sentinel + # (0,1): normal -> alpha * E = 1.0 + np.testing.assert_allclose(result[0, 1], dtype(1.0)) From fad0fd33140deb9d555f13aa6c6eb5ff3e1442fd Mon Sep 17 00:00:00 2001 From: Phil Schaf Date: Mon, 20 Apr 2026 15:59:43 +0200 Subject: [PATCH 13/14] restructure --- src/scanpy/preprocessing/__init__.py | 2 +- src/scanpy/preprocessing/_harmony/__init__.py | 766 +++++------------- src/scanpy/preprocessing/_harmony/core.py | 598 ++++++++++++++ .../preprocessing/_harmony_integrate.py | 228 ------ tests/test_harmony.py | 8 +- 5 files changed, 803 insertions(+), 799 deletions(-) create mode 100644 src/scanpy/preprocessing/_harmony/core.py delete mode 100644 src/scanpy/preprocessing/_harmony_integrate.py diff --git a/src/scanpy/preprocessing/__init__.py b/src/scanpy/preprocessing/__init__.py index 3be1da6ec6..cb1aedb90c 100644 --- a/src/scanpy/preprocessing/__init__.py +++ b/src/scanpy/preprocessing/__init__.py @@ -5,7 +5,7 @@ from ..neighbors import neighbors from ._combat import combat from ._deprecated.sampling import subsample -from ._harmony_integrate import harmony_integrate +from ._harmony import harmony_integrate from ._highly_variable_genes import highly_variable_genes from ._normalization import normalize_total from 
._pca import pca diff --git a/src/scanpy/preprocessing/_harmony/__init__.py b/src/scanpy/preprocessing/_harmony/__init__.py index a24ebdd0b8..091c84ba51 100644 --- a/src/scanpy/preprocessing/_harmony/__init__.py +++ b/src/scanpy/preprocessing/_harmony/__init__.py @@ -1,594 +1,228 @@ from __future__ import annotations -from dataclasses import KW_ONLY, InitVar, dataclass, field -from itertools import product from typing import TYPE_CHECKING import numpy as np -from sklearn.cluster import KMeans -from tqdm.auto import tqdm -from ... import logging as log -from ..._settings import settings -from ..._settings.verbosity import Verbosity +from ..._compat import warn if TYPE_CHECKING: from collections.abc import Sequence from typing import Literal - import pandas as pd + from anndata import AnnData + from numpy.typing import DTypeLike -_SUPPRESS_PENALTY = 1e30 - -@dataclass -class Harmony: - """Harmony batch correction algorithm. +def harmony_integrate( # noqa: PLR0913 + adata: AnnData, + key: str | Sequence[str], + *, + basis: str = "X_pca", + adjusted_basis: str = "X_pca_harmony", + dtype: DTypeLike = np.float64, + flavor: Literal["harmony2", "harmony1"] = "harmony2", + n_clusters: int | None = None, + max_iter_harmony: int = 10, + max_iter_clustering: int = 200, + tol_harmony: float = 1e-4, + tol_clustering: float = 1e-5, + sigma: float = 0.1, + theta: float | Sequence[float] = 2.0, + tau: int = 0, + ridge_lambda: float = 1.0, + alpha: float = 0.2, + batch_prune_threshold: float | None = 1e-5, + correction_method: Literal["fast", "original"] = "original", + block_proportion: float = 0.05, + random_state: int | None = 0, +) -> None: + """Integrate different experiments using the Harmony algorithm :cite:p:`Korsunsky2019,Patikas2026`. + + This CPU implementation is based on the harmony-pytorch & rapids_singlecell + version, using NumPy for efficient computation. 
As Harmony works by adjusting + the principal components, this function should be run after performing PCA but + before computing the neighbor graph. + + By default, the Harmony2 algorithm is used, which includes a stabilized + diversity penalty, dynamic per-cluster-per-batch ridge regularization, + and automatic batch pruning. To revert to the original Harmony behavior:: + + sc.pp.harmony_integrate(adata, key, flavor="harmony1") Parameters ---------- - x - Data matrix (n_cells x d) - typically PCA embeddings. - batch_df - DataFrame containing batch information. - batch_key - Column name(s) in batch_df containing batch labels. + adata + The annotated data matrix. + key + The key(s) of the column(s) in ``adata.obs`` that differentiates + among experiments/batches. When multiple keys are provided, a + combined batch variable is created from all columns. + basis + The name of the field in ``adata.obsm`` where the PCA table is + stored. Defaults to ``'X_pca'``. + adjusted_basis + The name of the field in ``adata.obsm`` where the adjusted PCA + table will be stored. Defaults to ``X_pca_harmony``. + dtype + The data type to use for Harmony computation. If you use 32-bit + you may experience numerical instability. + flavor + Which version of the Harmony algorithm to use. + ``"harmony2"`` (default) enables the stabilized diversity penalty, + dynamic per-cluster-per-batch ridge regularization, and automatic + batch pruning from :cite:p:`Patikas2026`. + ``"harmony1"`` uses the original algorithm from + :cite:p:`Korsunsky2019`. + n_clusters + Number of clusters used for soft k-means in the Harmony algorithm. + If ``None``, uses ``min(100, N / 30)``. More clusters capture + finer-grained structure but increase computation time. + max_iter_harmony + Maximum number of outer Harmony iterations (each consisting of + a clustering step followed by a correction step). + max_iter_clustering + Maximum iterations for the clustering step within each Harmony + iteration. 
+ tol_harmony + Convergence tolerance for the Harmony objective function. + The algorithm stops when the relative change in objective falls + below this value. + tol_clustering + Convergence tolerance for the clustering step within each + Harmony iteration. + sigma + Width of the soft-clustering kernel. Controls the entropy of + cluster assignments: smaller values produce harder assignments + (cells assigned to fewer clusters), while larger values produce + softer assignments (cells spread across more clusters). + theta + Diversity penalty weight per batch variable. Controls how + strongly Harmony encourages each cluster to contain a balanced + representation of all batches. Higher values (e.g. ``4``) + produce more aggressive mixing; lower values (e.g. ``0.5``) + allow more batch-specific clusters. Set to ``0`` to disable + batch correction entirely. A list can be provided to set + different weights per batch variable. + tau + Discounting factor on ``theta``. When ``tau > 0``, the + diversity penalty is down-weighted for batches with fewer cells, + preventing over-correction of small batches. By default (``0``), + there is no discounting. + ridge_lambda + Ridge regression regularization for the correction step. + Larger values produce more conservative (smaller) corrections, + preventing over-fitting. Only used with ``flavor="harmony1"``. + alpha + Scaling factor for the dynamic per-cluster-per-batch ridge + regularization. The effective regularization for each + cluster-batch pair is ``alpha * E_kb`` where ``E_kb`` is the + expected number of cells. Larger values produce more + conservative corrections. Only used with ``flavor="harmony2"``. + batch_prune_threshold + Fraction threshold below which a batch-cluster pair is pruned + (correction suppressed). When the fraction of a batch's cells + assigned to a cluster (``O_kb / N_b``) falls below this + threshold, that batch-cluster pair receives no correction, + preventing spurious adjustments. 
Only used with + ``flavor="harmony2"``. Set to ``None`` to disable pruning. + correction_method + Method for the correction step. ``"original"`` uses per-cluster + ridge regression with explicit matrix inversion. ``"fast"`` + uses a precomputed factorization that avoids the full inversion, + which can be faster for datasets with many batches. + block_proportion + Proportion of cells updated per clustering sub-iteration. + Smaller values produce more stochastic updates. Larger values + are faster but may converge to different solutions. + random_state + Random seed for reproducibility. + + Returns + ------- + Updates adata with the field ``adata.obsm[adjusted_basis]``, \ + containing principal components adjusted by Harmony such that \ + different experiments are integrated. """ - - batch_df: InitVar[pd.DataFrame] - batch_key: InitVar[str | Sequence[str]] - _: KW_ONLY - theta: float | Sequence[float] - sigma: float - n_clusters: int | None - max_iter_harmony: int - max_iter_clustering: int - tol_harmony: float - tol_clustering: float - ridge_lambda: float - correction_method: Literal["fast", "original"] - block_proportion: float - tau: int - random_state: int | None - stabilized_penalty: bool = True - dynamic_lambda: bool = True - alpha: float = 0.2 - batch_prune_threshold: float | None = 1e-5 - - batch_codes: np.ndarray = field(init=False) - n_batches: int = field(init=False) - - def __post_init__( - self, batch_df: pd.DataFrame, batch_key: str | Sequence[str] - ) -> None: - if self.max_iter_harmony < 1: - msg = "max_iter_harmony must be >= 1" - raise ValueError(msg) - - if self.dynamic_lambda: - if not np.isfinite(self.alpha) or self.alpha <= 0: - msg = f"alpha must be a finite positive number when dynamic_lambda=True, got {self.alpha}." - raise ValueError(msg) - if self.batch_prune_threshold is not None and not ( - 0 <= self.batch_prune_threshold <= 1 - ): - msg = f"batch_prune_threshold must be in [0, 1] or None, got {self.batch_prune_threshold}." 
- raise ValueError(msg) - - # Process batch keys - self.batch_codes, self.n_batches = _get_batch_codes(batch_df, batch_key) - - def fit(self, x: np.ndarray) -> np.ndarray: - """Run Harmony. - - Returns - ------- - z_corr - Batch-corrected embedding matrix (n_cells x d). - """ - if self.random_state is not None: - np.random.seed(self.random_state) - - # Ensure input is C-contiguous float array (infer dtype from x) - x = np.ascontiguousarray(x) - n_cells = x.shape[0] - - # Normalize input for clustering - z_norm = _normalize_rows_l2(x) - - # Compute batch proportions - n_b = np.bincount(self.batch_codes, minlength=self.n_batches).astype(x.dtype) - pr_b = (n_b / n_cells).reshape(-1, 1) - - # Set default theta - if isinstance(self.theta, (int, float)): - theta_arr = np.ones(self.n_batches, dtype=x.dtype) * float(self.theta) - else: - theta_arr = np.array(self.theta, dtype=x.dtype) - theta_arr = theta_arr.reshape(1, -1) - - # Set default n_clusters (needed before tau discounting) - if self.n_clusters is None: - n_clusters = int(min(100, n_cells / 30)) - n_clusters = max(n_clusters, 2) - else: - n_clusters = self.n_clusters - - # Apply tau discounting to theta - if self.tau > 0: - theta_arr = theta_arr * (1 - np.exp(-n_b / (n_clusters * self.tau)) ** 2) - - # Initialize centroids and state arrays - r, e, o, obj_init = _initialize_centroids( - z_norm, - self.batch_codes, - self.n_batches, - pr_b, - n_clusters=n_clusters, - sigma=self.sigma, - theta=theta_arr, - random_state=self.random_state, - stabilized_penalty=self.stabilized_penalty, - ) - - # Main Harmony loop - objectives_harmony = [obj_init] - with tqdm( - range(self.max_iter_harmony), disable=settings.verbosity < Verbosity.info - ) as bar: - for i in bar: - r, e, o, obj = self._cluster( - z_norm, pr_b, r=r, e=e, o=o, theta=theta_arr - ) - if obj is not None: - objectives_harmony.append(obj) - - # Compute per-(k,b) ridge regularization - lambda_kb = _compute_lambda_kb( - e, - o=o, - n_b=n_b, - alpha=self.alpha, - 
threshold=self.batch_prune_threshold, - ridge_lambda=self.ridge_lambda, - dynamic_lambda=self.dynamic_lambda, - ) - - z_hat = self._correct(x, r, o, lambda_kb=lambda_kb) - z_norm = _normalize_rows_l2(z_hat) - if self._is_convergent(objectives_harmony, self.tol_harmony): - log.info(f"Harmony converged in {i + 1} iterations") - break - else: - log.info( - f"Harmony did not converge after {self.max_iter_harmony} iterations." - ) - - return z_hat - - @staticmethod - def _is_convergent(objectives: list[float], tol: float) -> bool: - """Check Harmony convergence.""" - if len(objectives) < 2: - return False - obj_old = objectives[-2] - obj_new = objectives[-1] - return (obj_old - obj_new) < tol * abs(obj_old) - - def _cluster( - self, - z_norm: np.ndarray, - pr_b: np.ndarray, - *, - r: np.ndarray, - e: np.ndarray, - o: np.ndarray, - theta: np.ndarray, - ) -> tuple[np.ndarray, np.ndarray, np.ndarray, float | None]: - """Perform clustering step. Modifies r, e, o in-place.""" - return _clustering( - z_norm, - self.batch_codes, - self.n_batches, - pr_b, - r=r, - e=e, - o=o, - theta=theta, - sigma=self.sigma, - max_iter=self.max_iter_clustering, - tol=self.tol_clustering, - block_proportion=self.block_proportion, - stabilized_penalty=self.stabilized_penalty, + from .core import Harmony + + # Resolve flavor into internal flags + if flavor not in {"harmony1", "harmony2"}: + msg = f"flavor must be 'harmony1' or 'harmony2', got {flavor!r}." 
+ raise ValueError(msg) + stabilized_penalty = flavor == "harmony2" + dynamic_lambda = flavor == "harmony2" + + # Warn when flavor-incompatible parameters are explicitly set + if flavor == "harmony2" and ridge_lambda != 1.0: + warn( + "ridge_lambda is ignored when flavor='harmony2'; " + "use alpha to control regularization strength.", + UserWarning, ) - - def _correct( - self, - x: np.ndarray, - r: np.ndarray, - o: np.ndarray, - *, - lambda_kb: np.ndarray, - ) -> np.ndarray: - """Perform correction step.""" - if self.correction_method == "fast": - return _correction_fast( - x, - self.batch_codes, - self.n_batches, - r, - o, - lambda_kb=lambda_kb, + if flavor == "harmony1": + if alpha != 0.2: + warn( + "alpha is ignored when flavor='harmony1'; use ridge_lambda instead.", + UserWarning, ) - else: - return _correction_original( - x, - self.batch_codes, - self.n_batches, - r, - lambda_kb=lambda_kb, + if batch_prune_threshold != 1e-5: + warn( + "batch_prune_threshold is ignored when flavor='harmony1'.", + UserWarning, ) + # Ensure the basis exists in adata.obsm + if basis not in adata.obsm: + msg = ( + f"The specified basis {basis!r} is not available in `adata.obsm`. " + f"Available bases: {list(adata.obsm.keys())}" + ) + raise ValueError(msg) + + # Get the input data + input_data = adata.obsm[basis] + + # Convert to numpy array with specified dtype + try: + x = np.ascontiguousarray(input_data, dtype=dtype) + except Exception as e: + msg = ( + f"Could not convert input of type {type(input_data).__name__} " + "to NumPy array." 
+ ) + raise TypeError(msg) from e -def _get_batch_codes( - batch_df: pd.DataFrame, - batch_key: str | Sequence[str], -) -> tuple[np.ndarray, int]: - """Get batch codes from DataFrame.""" - if isinstance(batch_key, str): - batch_vec = batch_df[batch_key] - elif len(batch_key) == 1: - batch_vec = batch_df[batch_key[0]] - else: - df = batch_df[list(batch_key)].astype("str") - batch_vec = df.apply(",".join, axis=1) - - batch_cat = batch_vec.astype("category") - codes = batch_cat.cat.codes.to_numpy(copy=True) - n_batches = len(batch_cat.cat.categories) - - return codes.astype(np.int32), n_batches - - -def _normalize_rows_l2(x: np.ndarray) -> np.ndarray: - """L2 normalize each row of x.""" - norms = np.linalg.norm(x, axis=1, keepdims=True) - norms = np.maximum(norms, 1e-12) - return x / norms - - -def _normalize_rows_l1(r: np.ndarray) -> None: - """L1 normalize each row of r in-place (rows sum to 1).""" - row_sums = r.sum(axis=1, keepdims=True) - row_sums = np.maximum(row_sums, 1e-12) - r /= row_sums - - -def _initialize_centroids( - z_norm: np.ndarray, - batch_codes: np.ndarray, - n_batches: int, - pr_b: np.ndarray, - *, - n_clusters: int, - sigma: float, - theta: np.ndarray, - random_state: int | None, - stabilized_penalty: bool = True, -) -> tuple[np.ndarray, np.ndarray, np.ndarray, float]: - """Initialize cluster centroids using K-means.""" - kmeans = KMeans( - n_clusters=n_clusters, random_state=random_state, n_init=10, max_iter=25 - ) - kmeans.fit(z_norm) - - # Centroids - y = kmeans.cluster_centers_.copy() - y_norm = _normalize_rows_l2(y) - - # Compute soft cluster assignments r - term = -2.0 / sigma - r = _compute_r(z_norm, y_norm, term) - _normalize_rows_l1(r) - - # Initialize e (expected) and o (observed) - r_sum = r.sum(axis=0) - e = pr_b @ r_sum.reshape(1, -1) - # o[b, k] = sum of r[i, k] for cells i in batch b - o = np.zeros((n_batches, n_clusters), dtype=z_norm.dtype) - np.add.at(o, batch_codes, r) + # Check for NaN values + if np.isnan(x).any(): + msg = ( 
+ "Input data contains NaN values. Please handle these before " + "running harmony_integrate." + ) + raise ValueError(msg) - # Compute initial objective - obj = _compute_objective( - y_norm, - z_norm, - r, + # Run Harmony + harmony = Harmony( + adata.obs, + key, theta=theta, sigma=sigma, - o=o, - e=e, + n_clusters=n_clusters, + max_iter_harmony=max_iter_harmony, + max_iter_clustering=max_iter_clustering, + tol_harmony=tol_harmony, + tol_clustering=tol_clustering, + ridge_lambda=ridge_lambda, + correction_method=correction_method, + block_proportion=block_proportion, + tau=tau, + random_state=random_state, stabilized_penalty=stabilized_penalty, + dynamic_lambda=dynamic_lambda, + alpha=alpha, + batch_prune_threshold=batch_prune_threshold, ) + harmony_out = harmony.fit(x) - return r, e, o, obj - - -def _compute_r( - z: np.ndarray, - y: np.ndarray, - term: float, -) -> np.ndarray: - """Compute soft cluster assignments using NumPy dot.""" - dots = z @ y.T - return np.exp(term * (1.0 - dots)) - - -def _clustering( # noqa: PLR0913 - z_norm: np.ndarray, - batch_codes: np.ndarray, - n_batches: int, - pr_b: np.ndarray, - *, - r: np.ndarray, - e: np.ndarray, - o: np.ndarray, - theta: np.ndarray, - sigma: float, - max_iter: int, - tol: float, - block_proportion: float, - stabilized_penalty: bool = True, -) -> tuple[np.ndarray, np.ndarray, np.ndarray, float | None]: - """Run clustering iterations (modifies r, e, o in-place).""" - n_cells = z_norm.shape[0] - k = r.shape[1] - n_blocks = min(n_cells, 1 // block_proportion) - term = -2.0 / sigma - - objectives_clustering = [] - - # Pre-allocate work arrays - y = np.empty((k, z_norm.shape[1]), dtype=z_norm.dtype) - y_norm = np.empty_like(y) - - for _ in range(max_iter): - # Compute cluster centroids: y = r.T @ z_norm, then normalize - np.dot(r.T, z_norm, out=y) - norms = np.linalg.norm(y, axis=1, keepdims=True) - norms = np.maximum(norms, 1e-12) - np.divide(y, norms, out=y_norm) - - # Randomly shuffle cell indices - idx_list = 
np.random.permutation(n_cells) - - # Process blocks - for block_idx, b in product( - np.array_split(idx_list, n_blocks), range(n_batches) - ): - mask = batch_codes[block_idx] == b - if not np.any(mask): - continue - - cell_idx = block_idx[mask] - - # Remove old r contribution from o and e - r_old = r[cell_idx, :] - r_old_sum = r_old.sum(axis=0) - o[b, :] -= r_old_sum - e -= pr_b * r_old_sum - - # Compute new r values - dots = z_norm[cell_idx, :] @ y_norm.T - r_new = np.exp(term * (1.0 - dots)) - - # Apply penalty (Harmony1 vs Harmony2) - if stabilized_penalty: - # Harmony2: denominator is (O + E + 1) - penalty = ((e[b, :] + 1.0) / (o[b, :] + e[b, :] + 1.0)) ** theta[0, b] - else: - # Harmony1: denominator is (O + 1) - penalty = ((e[b, :] + 1.0) / (o[b, :] + 1.0)) ** theta[0, b] - r_new *= penalty - - # Normalize rows to sum to 1 - row_sums = r_new.sum(axis=1, keepdims=True) - row_sums = np.maximum(row_sums, 1e-12) - r_new /= row_sums - - # Store back - r[cell_idx, :] = r_new - - # Add new r contribution to o and e - r_new_sum = r_new.sum(axis=0) - o[b, :] += r_new_sum - e += pr_b * r_new_sum - - # Compute objective - obj = _compute_objective( - y_norm, - z_norm, - r, - theta=theta, - sigma=sigma, - o=o, - e=e, - stabilized_penalty=stabilized_penalty, - ) - objectives_clustering.append(obj) - - # Check convergence - if _is_convergent_clustering(objectives_clustering, tol): - obj = objectives_clustering[-1] - break - else: - obj = None - - return r, e, o, obj - - -def _compute_lambda_kb( - e: np.ndarray, - *, - o: np.ndarray, - n_b: np.ndarray, - alpha: float, - threshold: float | None, - ridge_lambda: float, - dynamic_lambda: bool, -) -> np.ndarray: - """Compute per-(k,b) ridge regularization array.""" - sentinel = e.dtype.type(_SUPPRESS_PENALTY) - if not dynamic_lambda: - lambda_kb = np.full_like(e, ridge_lambda) - else: - lambda_kb = (alpha * e).astype(e.dtype) - if threshold is not None: - safe_n_b = np.where(n_b > 0, n_b, np.ones_like(n_b)) - prune_mask = (o / 
safe_n_b[:, None]) < threshold - prune_mask |= n_b[:, None] == 0 - lambda_kb[prune_mask] = sentinel - # Where both O and lambda_kb are zero, the kernel computes 1/(O+lambda) - # which would divide by zero. - lambda_kb[(o + lambda_kb) == 0] = sentinel - return lambda_kb - - -def _correction_original( - x: np.ndarray, - batch_codes: np.ndarray, - n_batches: int, - r: np.ndarray, - *, - lambda_kb: np.ndarray, -) -> np.ndarray: - """Original correction method - per-cluster ridge regression.""" - _, d = x.shape - - z = x.copy() - - for k_idx, r_k in enumerate(r.T): - # Build per-cluster lambda diagonal - lambda_diag = np.zeros(n_batches + 1, dtype=x.dtype) - lambda_diag[1:] = lambda_kb[:, k_idx] - lambda_mat = np.diag(lambda_diag) - - r_sum_total = r_k.sum() - r_sum_per_batch = np.zeros(n_batches, dtype=x.dtype) - for b in range(n_batches): - r_sum_per_batch[b] = r_k[batch_codes == b].sum() - - phi_t_phi = np.zeros((n_batches + 1, n_batches + 1), dtype=x.dtype) - phi_t_phi[0, 0] = r_sum_total - phi_t_phi[0, 1:] = r_sum_per_batch - phi_t_phi[1:, 0] = r_sum_per_batch - phi_t_phi[1:, 1:] = np.diag(r_sum_per_batch) - phi_t_phi += lambda_mat - - phi_t_x = np.zeros((n_batches + 1, d), dtype=x.dtype) - phi_t_x[0, :] = r_k @ x - for b in range(n_batches): - mask = batch_codes == b - phi_t_x[b + 1, :] = r_k[mask] @ x[mask] - - try: - w = np.linalg.solve(phi_t_phi, phi_t_x) - except np.linalg.LinAlgError: - w = np.linalg.lstsq(phi_t_phi, phi_t_x, rcond=None)[0] - - w[0, :] = 0 - w_batch = w[batch_codes + 1, :] - z -= r_k[:, np.newaxis] * w_batch - - return z - - -def _correction_fast( - x: np.ndarray, - batch_codes: np.ndarray, - n_batches: int, - r: np.ndarray, - o: np.ndarray, - *, - lambda_kb: np.ndarray, -) -> np.ndarray: - """Fast correction method using precomputed factors.""" - _, d = x.shape - - z = x.copy() - dtype = x.dtype - p = np.eye(n_batches + 1, dtype=dtype) - - for k_idx, (o_k, r_k) in enumerate(zip(o.T, r.T, strict=True)): - lam_k = lambda_kb[:, k_idx] - - 
factor = (1.0 / (o_k + lam_k)).astype(dtype) - c = np.sum(o_k) + np.sum(-factor * o_k**2) - c_inv = dtype.type(1.0 / c) - - p[0, 1:] = -factor * o_k - - p_t_b_inv = np.zeros((n_batches + 1, n_batches + 1), dtype=dtype) - p_t_b_inv[0, 0] = c_inv - p_t_b_inv[1:, 1:] = np.diag(factor) - p_t_b_inv[1:, 0] = p[0, 1:] * c_inv - - inv_mat = p_t_b_inv @ p - - phi_t_x = np.zeros((n_batches + 1, d), dtype=x.dtype) - phi_t_x[0, :] = r_k @ x - for b in range(n_batches): - mask = batch_codes == b - phi_t_x[b + 1, :] = r_k[mask] @ x[mask] - - w = inv_mat @ phi_t_x - w[0, :] = 0 - - w_batch = w[batch_codes + 1, :] - z -= r_k[:, np.newaxis] * w_batch - - return z - - -def _compute_objective( - y_norm: np.ndarray, - z_norm: np.ndarray, - r: np.ndarray, - *, - theta: np.ndarray, - sigma: float, - o: np.ndarray, - e: np.ndarray, - stabilized_penalty: bool = True, -) -> float: - """Compute Harmony objective function.""" - zy = z_norm @ y_norm.T - kmeans_error = np.sum(r * 2.0 * (1.0 - zy)) - - r_row_sums = r.sum(axis=1, keepdims=True) - r_normalized = r / np.clip(r_row_sums, 1e-12, None) - entropy = sigma * np.sum(r_normalized * np.log(r_normalized + 1e-12)) - - if stabilized_penalty: - # Harmony2: numerator is (O + E + 1) - log_ratio = np.log((o + e + 1) / (e + 1)) - else: - # Harmony1: numerator is (O + 1) - log_ratio = np.log((o + 1) / (e + 1)) - diversity_penalty = sigma * np.sum(theta @ (o * log_ratio)) - - return kmeans_error + entropy + diversity_penalty - - -def _is_convergent_clustering( - objectives: list, - tol: float, - window_size: int = 3, -) -> bool: - """Check clustering convergence using window.""" - if len(objectives) < window_size + 1: - return False - - obj_old = sum(objectives[-window_size - 1 : -1]) - obj_new = sum(objectives[-window_size:]) - - return (obj_old - obj_new) < tol * abs(obj_old) + # Store result + adata.obsm[adjusted_basis] = harmony_out diff --git a/src/scanpy/preprocessing/_harmony/core.py b/src/scanpy/preprocessing/_harmony/core.py new file mode 
100644 index 0000000000..cf48ec9369 --- /dev/null +++ b/src/scanpy/preprocessing/_harmony/core.py @@ -0,0 +1,598 @@ +from __future__ import annotations + +from dataclasses import KW_ONLY, InitVar, dataclass, field +from itertools import product +from typing import TYPE_CHECKING + +import numpy as np +from sklearn.cluster import KMeans +from tqdm.auto import tqdm + +from ... import logging as log +from ..._settings import settings +from ..._settings.verbosity import Verbosity + +if TYPE_CHECKING: + from collections.abc import Sequence + from typing import Literal + + import pandas as pd + + +__all__ = ["_SUPPRESS_PENALTY", "Harmony", "_compute_lambda_kb"] + + +_SUPPRESS_PENALTY = 1e30 + + +@dataclass +class Harmony: + """Harmony batch correction algorithm. + + Parameters + ---------- + x + Data matrix (n_cells x d) - typically PCA embeddings. + batch_df + DataFrame containing batch information. + batch_key + Column name(s) in batch_df containing batch labels. + """ + + batch_df: InitVar[pd.DataFrame] + batch_key: InitVar[str | Sequence[str]] + _: KW_ONLY + theta: float | Sequence[float] + sigma: float + n_clusters: int | None + max_iter_harmony: int + max_iter_clustering: int + tol_harmony: float + tol_clustering: float + ridge_lambda: float + correction_method: Literal["fast", "original"] + block_proportion: float + tau: int + random_state: int | None + stabilized_penalty: bool = True + dynamic_lambda: bool = True + alpha: float = 0.2 + batch_prune_threshold: float | None = 1e-5 + + batch_codes: np.ndarray = field(init=False) + n_batches: int = field(init=False) + + def __post_init__( + self, batch_df: pd.DataFrame, batch_key: str | Sequence[str] + ) -> None: + if self.max_iter_harmony < 1: + msg = "max_iter_harmony must be >= 1" + raise ValueError(msg) + + if self.dynamic_lambda: + if not np.isfinite(self.alpha) or self.alpha <= 0: + msg = f"alpha must be a finite positive number when dynamic_lambda=True, got {self.alpha}." 
+ raise ValueError(msg) + if self.batch_prune_threshold is not None and not ( + 0 <= self.batch_prune_threshold <= 1 + ): + msg = f"batch_prune_threshold must be in [0, 1] or None, got {self.batch_prune_threshold}." + raise ValueError(msg) + + # Process batch keys + self.batch_codes, self.n_batches = _get_batch_codes(batch_df, batch_key) + + def fit(self, x: np.ndarray) -> np.ndarray: + """Run Harmony. + + Returns + ------- + z_corr + Batch-corrected embedding matrix (n_cells x d). + """ + if self.random_state is not None: + np.random.seed(self.random_state) + + # Ensure input is C-contiguous float array (infer dtype from x) + x = np.ascontiguousarray(x) + n_cells = x.shape[0] + + # Normalize input for clustering + z_norm = _normalize_rows_l2(x) + + # Compute batch proportions + n_b = np.bincount(self.batch_codes, minlength=self.n_batches).astype(x.dtype) + pr_b = (n_b / n_cells).reshape(-1, 1) + + # Set default theta + if isinstance(self.theta, (int, float)): + theta_arr = np.ones(self.n_batches, dtype=x.dtype) * float(self.theta) + else: + theta_arr = np.array(self.theta, dtype=x.dtype) + theta_arr = theta_arr.reshape(1, -1) + + # Set default n_clusters (needed before tau discounting) + if self.n_clusters is None: + n_clusters = int(min(100, n_cells / 30)) + n_clusters = max(n_clusters, 2) + else: + n_clusters = self.n_clusters + + # Apply tau discounting to theta + if self.tau > 0: + theta_arr = theta_arr * (1 - np.exp(-n_b / (n_clusters * self.tau)) ** 2) + + # Initialize centroids and state arrays + r, e, o, obj_init = _initialize_centroids( + z_norm, + self.batch_codes, + self.n_batches, + pr_b, + n_clusters=n_clusters, + sigma=self.sigma, + theta=theta_arr, + random_state=self.random_state, + stabilized_penalty=self.stabilized_penalty, + ) + + # Main Harmony loop + objectives_harmony = [obj_init] + with tqdm( + range(self.max_iter_harmony), disable=settings.verbosity < Verbosity.info + ) as bar: + for i in bar: + r, e, o, obj = self._cluster( + z_norm, pr_b, 
r=r, e=e, o=o, theta=theta_arr + ) + if obj is not None: + objectives_harmony.append(obj) + + # Compute per-(k,b) ridge regularization + lambda_kb = _compute_lambda_kb( + e, + o=o, + n_b=n_b, + alpha=self.alpha, + threshold=self.batch_prune_threshold, + ridge_lambda=self.ridge_lambda, + dynamic_lambda=self.dynamic_lambda, + ) + + z_hat = self._correct(x, r, o, lambda_kb=lambda_kb) + z_norm = _normalize_rows_l2(z_hat) + if self._is_convergent(objectives_harmony, self.tol_harmony): + log.info(f"Harmony converged in {i + 1} iterations") + break + else: + log.info( + f"Harmony did not converge after {self.max_iter_harmony} iterations." + ) + + return z_hat + + @staticmethod + def _is_convergent(objectives: list[float], tol: float) -> bool: + """Check Harmony convergence.""" + if len(objectives) < 2: + return False + obj_old = objectives[-2] + obj_new = objectives[-1] + return (obj_old - obj_new) < tol * abs(obj_old) + + def _cluster( + self, + z_norm: np.ndarray, + pr_b: np.ndarray, + *, + r: np.ndarray, + e: np.ndarray, + o: np.ndarray, + theta: np.ndarray, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray, float | None]: + """Perform clustering step. 
Modifies r, e, o in-place.""" + return _clustering( + z_norm, + self.batch_codes, + self.n_batches, + pr_b, + r=r, + e=e, + o=o, + theta=theta, + sigma=self.sigma, + max_iter=self.max_iter_clustering, + tol=self.tol_clustering, + block_proportion=self.block_proportion, + stabilized_penalty=self.stabilized_penalty, + ) + + def _correct( + self, + x: np.ndarray, + r: np.ndarray, + o: np.ndarray, + *, + lambda_kb: np.ndarray, + ) -> np.ndarray: + """Perform correction step.""" + if self.correction_method == "fast": + return _correction_fast( + x, + self.batch_codes, + self.n_batches, + r, + o, + lambda_kb=lambda_kb, + ) + else: + return _correction_original( + x, + self.batch_codes, + self.n_batches, + r, + lambda_kb=lambda_kb, + ) + + +def _get_batch_codes( + batch_df: pd.DataFrame, + batch_key: str | Sequence[str], +) -> tuple[np.ndarray, int]: + """Get batch codes from DataFrame.""" + if isinstance(batch_key, str): + batch_vec = batch_df[batch_key] + elif len(batch_key) == 1: + batch_vec = batch_df[batch_key[0]] + else: + df = batch_df[list(batch_key)].astype("str") + batch_vec = df.apply(",".join, axis=1) + + batch_cat = batch_vec.astype("category") + codes = batch_cat.cat.codes.to_numpy(copy=True) + n_batches = len(batch_cat.cat.categories) + + return codes.astype(np.int32), n_batches + + +def _normalize_rows_l2(x: np.ndarray) -> np.ndarray: + """L2 normalize each row of x.""" + norms = np.linalg.norm(x, axis=1, keepdims=True) + norms = np.maximum(norms, 1e-12) + return x / norms + + +def _normalize_rows_l1(r: np.ndarray) -> None: + """L1 normalize each row of r in-place (rows sum to 1).""" + row_sums = r.sum(axis=1, keepdims=True) + row_sums = np.maximum(row_sums, 1e-12) + r /= row_sums + + +def _initialize_centroids( + z_norm: np.ndarray, + batch_codes: np.ndarray, + n_batches: int, + pr_b: np.ndarray, + *, + n_clusters: int, + sigma: float, + theta: np.ndarray, + random_state: int | None, + stabilized_penalty: bool = True, +) -> tuple[np.ndarray, np.ndarray, 
np.ndarray, float]: + """Initialize cluster centroids using K-means.""" + kmeans = KMeans( + n_clusters=n_clusters, random_state=random_state, n_init=10, max_iter=25 + ) + kmeans.fit(z_norm) + + # Centroids + y = kmeans.cluster_centers_.copy() + y_norm = _normalize_rows_l2(y) + + # Compute soft cluster assignments r + term = -2.0 / sigma + r = _compute_r(z_norm, y_norm, term) + _normalize_rows_l1(r) + + # Initialize e (expected) and o (observed) + r_sum = r.sum(axis=0) + e = pr_b @ r_sum.reshape(1, -1) + # o[b, k] = sum of r[i, k] for cells i in batch b + o = np.zeros((n_batches, n_clusters), dtype=z_norm.dtype) + np.add.at(o, batch_codes, r) + + # Compute initial objective + obj = _compute_objective( + y_norm, + z_norm, + r, + theta=theta, + sigma=sigma, + o=o, + e=e, + stabilized_penalty=stabilized_penalty, + ) + + return r, e, o, obj + + +def _compute_r( + z: np.ndarray, + y: np.ndarray, + term: float, +) -> np.ndarray: + """Compute soft cluster assignments using NumPy dot.""" + dots = z @ y.T + return np.exp(term * (1.0 - dots)) + + +def _clustering( # noqa: PLR0913 + z_norm: np.ndarray, + batch_codes: np.ndarray, + n_batches: int, + pr_b: np.ndarray, + *, + r: np.ndarray, + e: np.ndarray, + o: np.ndarray, + theta: np.ndarray, + sigma: float, + max_iter: int, + tol: float, + block_proportion: float, + stabilized_penalty: bool = True, +) -> tuple[np.ndarray, np.ndarray, np.ndarray, float | None]: + """Run clustering iterations (modifies r, e, o in-place).""" + n_cells = z_norm.shape[0] + k = r.shape[1] + n_blocks = min(n_cells, 1 // block_proportion) + term = -2.0 / sigma + + objectives_clustering = [] + + # Pre-allocate work arrays + y = np.empty((k, z_norm.shape[1]), dtype=z_norm.dtype) + y_norm = np.empty_like(y) + + for _ in range(max_iter): + # Compute cluster centroids: y = r.T @ z_norm, then normalize + np.dot(r.T, z_norm, out=y) + norms = np.linalg.norm(y, axis=1, keepdims=True) + norms = np.maximum(norms, 1e-12) + np.divide(y, norms, out=y_norm) + + # 
Randomly shuffle cell indices + idx_list = np.random.permutation(n_cells) + + # Process blocks + for block_idx, b in product( + np.array_split(idx_list, n_blocks), range(n_batches) + ): + mask = batch_codes[block_idx] == b + if not np.any(mask): + continue + + cell_idx = block_idx[mask] + + # Remove old r contribution from o and e + r_old = r[cell_idx, :] + r_old_sum = r_old.sum(axis=0) + o[b, :] -= r_old_sum + e -= pr_b * r_old_sum + + # Compute new r values + dots = z_norm[cell_idx, :] @ y_norm.T + r_new = np.exp(term * (1.0 - dots)) + + # Apply penalty (Harmony1 vs Harmony2) + if stabilized_penalty: + # Harmony2: denominator is (O + E + 1) + penalty = ((e[b, :] + 1.0) / (o[b, :] + e[b, :] + 1.0)) ** theta[0, b] + else: + # Harmony1: denominator is (O + 1) + penalty = ((e[b, :] + 1.0) / (o[b, :] + 1.0)) ** theta[0, b] + r_new *= penalty + + # Normalize rows to sum to 1 + row_sums = r_new.sum(axis=1, keepdims=True) + row_sums = np.maximum(row_sums, 1e-12) + r_new /= row_sums + + # Store back + r[cell_idx, :] = r_new + + # Add new r contribution to o and e + r_new_sum = r_new.sum(axis=0) + o[b, :] += r_new_sum + e += pr_b * r_new_sum + + # Compute objective + obj = _compute_objective( + y_norm, + z_norm, + r, + theta=theta, + sigma=sigma, + o=o, + e=e, + stabilized_penalty=stabilized_penalty, + ) + objectives_clustering.append(obj) + + # Check convergence + if _is_convergent_clustering(objectives_clustering, tol): + obj = objectives_clustering[-1] + break + else: + obj = None + + return r, e, o, obj + + +def _compute_lambda_kb( + e: np.ndarray, + *, + o: np.ndarray, + n_b: np.ndarray, + alpha: float, + threshold: float | None, + ridge_lambda: float, + dynamic_lambda: bool, +) -> np.ndarray: + """Compute per-(k,b) ridge regularization array.""" + sentinel = e.dtype.type(_SUPPRESS_PENALTY) + if not dynamic_lambda: + lambda_kb = np.full_like(e, ridge_lambda) + else: + lambda_kb = (alpha * e).astype(e.dtype) + if threshold is not None: + safe_n_b = np.where(n_b > 0, 
n_b, np.ones_like(n_b)) + prune_mask = (o / safe_n_b[:, None]) < threshold + prune_mask |= n_b[:, None] == 0 + lambda_kb[prune_mask] = sentinel + # Where both O and lambda_kb are zero, the kernel computes 1/(O+lambda) + # which would divide by zero. + lambda_kb[(o + lambda_kb) == 0] = sentinel + return lambda_kb + + +def _correction_original( + x: np.ndarray, + batch_codes: np.ndarray, + n_batches: int, + r: np.ndarray, + *, + lambda_kb: np.ndarray, +) -> np.ndarray: + """Original correction method - per-cluster ridge regression.""" + _, d = x.shape + + z = x.copy() + + for k_idx, r_k in enumerate(r.T): + # Build per-cluster lambda diagonal + lambda_diag = np.zeros(n_batches + 1, dtype=x.dtype) + lambda_diag[1:] = lambda_kb[:, k_idx] + lambda_mat = np.diag(lambda_diag) + + r_sum_total = r_k.sum() + r_sum_per_batch = np.zeros(n_batches, dtype=x.dtype) + for b in range(n_batches): + r_sum_per_batch[b] = r_k[batch_codes == b].sum() + + phi_t_phi = np.zeros((n_batches + 1, n_batches + 1), dtype=x.dtype) + phi_t_phi[0, 0] = r_sum_total + phi_t_phi[0, 1:] = r_sum_per_batch + phi_t_phi[1:, 0] = r_sum_per_batch + phi_t_phi[1:, 1:] = np.diag(r_sum_per_batch) + phi_t_phi += lambda_mat + + phi_t_x = np.zeros((n_batches + 1, d), dtype=x.dtype) + phi_t_x[0, :] = r_k @ x + for b in range(n_batches): + mask = batch_codes == b + phi_t_x[b + 1, :] = r_k[mask] @ x[mask] + + try: + w = np.linalg.solve(phi_t_phi, phi_t_x) + except np.linalg.LinAlgError: + w = np.linalg.lstsq(phi_t_phi, phi_t_x, rcond=None)[0] + + w[0, :] = 0 + w_batch = w[batch_codes + 1, :] + z -= r_k[:, np.newaxis] * w_batch + + return z + + +def _correction_fast( + x: np.ndarray, + batch_codes: np.ndarray, + n_batches: int, + r: np.ndarray, + o: np.ndarray, + *, + lambda_kb: np.ndarray, +) -> np.ndarray: + """Fast correction method using precomputed factors.""" + _, d = x.shape + + z = x.copy() + dtype = x.dtype + p = np.eye(n_batches + 1, dtype=dtype) + + for k_idx, (o_k, r_k) in enumerate(zip(o.T, r.T, 
strict=True)): + lam_k = lambda_kb[:, k_idx] + + factor = (1.0 / (o_k + lam_k)).astype(dtype) + c = np.sum(o_k) + np.sum(-factor * o_k**2) + c_inv = dtype.type(1.0 / c) + + p[0, 1:] = -factor * o_k + + p_t_b_inv = np.zeros((n_batches + 1, n_batches + 1), dtype=dtype) + p_t_b_inv[0, 0] = c_inv + p_t_b_inv[1:, 1:] = np.diag(factor) + p_t_b_inv[1:, 0] = p[0, 1:] * c_inv + + inv_mat = p_t_b_inv @ p + + phi_t_x = np.zeros((n_batches + 1, d), dtype=x.dtype) + phi_t_x[0, :] = r_k @ x + for b in range(n_batches): + mask = batch_codes == b + phi_t_x[b + 1, :] = r_k[mask] @ x[mask] + + w = inv_mat @ phi_t_x + w[0, :] = 0 + + w_batch = w[batch_codes + 1, :] + z -= r_k[:, np.newaxis] * w_batch + + return z + + +def _compute_objective( + y_norm: np.ndarray, + z_norm: np.ndarray, + r: np.ndarray, + *, + theta: np.ndarray, + sigma: float, + o: np.ndarray, + e: np.ndarray, + stabilized_penalty: bool = True, +) -> float: + """Compute Harmony objective function.""" + zy = z_norm @ y_norm.T + kmeans_error = np.sum(r * 2.0 * (1.0 - zy)) + + r_row_sums = r.sum(axis=1, keepdims=True) + r_normalized = r / np.clip(r_row_sums, 1e-12, None) + entropy = sigma * np.sum(r_normalized * np.log(r_normalized + 1e-12)) + + if stabilized_penalty: + # Harmony2: numerator is (O + E + 1) + log_ratio = np.log((o + e + 1) / (e + 1)) + else: + # Harmony1: numerator is (O + 1) + log_ratio = np.log((o + 1) / (e + 1)) + diversity_penalty = sigma * np.sum(theta @ (o * log_ratio)) + + return kmeans_error + entropy + diversity_penalty + + +def _is_convergent_clustering( + objectives: list, + tol: float, + window_size: int = 3, +) -> bool: + """Check clustering convergence using window.""" + if len(objectives) < window_size + 1: + return False + + obj_old = sum(objectives[-window_size - 1 : -1]) + obj_new = sum(objectives[-window_size:]) + + return (obj_old - obj_new) < tol * abs(obj_old) diff --git a/src/scanpy/preprocessing/_harmony_integrate.py b/src/scanpy/preprocessing/_harmony_integrate.py deleted file 
mode 100644 index 474ee47909..0000000000 --- a/src/scanpy/preprocessing/_harmony_integrate.py +++ /dev/null @@ -1,228 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import numpy as np - -from .._compat import warn - -if TYPE_CHECKING: - from collections.abc import Sequence - from typing import Literal - - from anndata import AnnData - from numpy.typing import DTypeLike - - -def harmony_integrate( # noqa: PLR0913 - adata: AnnData, - key: str | Sequence[str], - *, - basis: str = "X_pca", - adjusted_basis: str = "X_pca_harmony", - dtype: DTypeLike = np.float64, - flavor: Literal["harmony2", "harmony1"] = "harmony2", - n_clusters: int | None = None, - max_iter_harmony: int = 10, - max_iter_clustering: int = 200, - tol_harmony: float = 1e-4, - tol_clustering: float = 1e-5, - sigma: float = 0.1, - theta: float | Sequence[float] = 2.0, - tau: int = 0, - ridge_lambda: float = 1.0, - alpha: float = 0.2, - batch_prune_threshold: float | None = 1e-5, - correction_method: Literal["fast", "original"] = "original", - block_proportion: float = 0.05, - random_state: int | None = 0, -) -> None: - """Integrate different experiments using the Harmony algorithm :cite:p:`Korsunsky2019,Patikas2026`. - - This CPU implementation is based on the harmony-pytorch & rapids_singlecell - version, using NumPy for efficient computation. As Harmony works by adjusting - the principal components, this function should be run after performing PCA but - before computing the neighbor graph. - - By default, the Harmony2 algorithm is used, which includes a stabilized - diversity penalty, dynamic per-cluster-per-batch ridge regularization, - and automatic batch pruning. To revert to the original Harmony behavior:: - - sc.pp.harmony_integrate(adata, key, flavor="harmony1") - - Parameters - ---------- - adata - The annotated data matrix. - key - The key(s) of the column(s) in ``adata.obs`` that differentiates - among experiments/batches. 
When multiple keys are provided, a - combined batch variable is created from all columns. - basis - The name of the field in ``adata.obsm`` where the PCA table is - stored. Defaults to ``'X_pca'``. - adjusted_basis - The name of the field in ``adata.obsm`` where the adjusted PCA - table will be stored. Defaults to ``X_pca_harmony``. - dtype - The data type to use for Harmony computation. If you use 32-bit - you may experience numerical instability. - flavor - Which version of the Harmony algorithm to use. - ``"harmony2"`` (default) enables the stabilized diversity penalty, - dynamic per-cluster-per-batch ridge regularization, and automatic - batch pruning from :cite:p:`Patikas2026`. - ``"harmony1"`` uses the original algorithm from - :cite:p:`Korsunsky2019`. - n_clusters - Number of clusters used for soft k-means in the Harmony algorithm. - If ``None``, uses ``min(100, N / 30)``. More clusters capture - finer-grained structure but increase computation time. - max_iter_harmony - Maximum number of outer Harmony iterations (each consisting of - a clustering step followed by a correction step). - max_iter_clustering - Maximum iterations for the clustering step within each Harmony - iteration. - tol_harmony - Convergence tolerance for the Harmony objective function. - The algorithm stops when the relative change in objective falls - below this value. - tol_clustering - Convergence tolerance for the clustering step within each - Harmony iteration. - sigma - Width of the soft-clustering kernel. Controls the entropy of - cluster assignments: smaller values produce harder assignments - (cells assigned to fewer clusters), while larger values produce - softer assignments (cells spread across more clusters). - theta - Diversity penalty weight per batch variable. Controls how - strongly Harmony encourages each cluster to contain a balanced - representation of all batches. Higher values (e.g. ``4``) - produce more aggressive mixing; lower values (e.g. 
``0.5``) - allow more batch-specific clusters. Set to ``0`` to disable - batch correction entirely. A list can be provided to set - different weights per batch variable. - tau - Discounting factor on ``theta``. When ``tau > 0``, the - diversity penalty is down-weighted for batches with fewer cells, - preventing over-correction of small batches. By default (``0``), - there is no discounting. - ridge_lambda - Ridge regression regularization for the correction step. - Larger values produce more conservative (smaller) corrections, - preventing over-fitting. Only used with ``flavor="harmony1"``. - alpha - Scaling factor for the dynamic per-cluster-per-batch ridge - regularization. The effective regularization for each - cluster-batch pair is ``alpha * E_kb`` where ``E_kb`` is the - expected number of cells. Larger values produce more - conservative corrections. Only used with ``flavor="harmony2"``. - batch_prune_threshold - Fraction threshold below which a batch-cluster pair is pruned - (correction suppressed). When the fraction of a batch's cells - assigned to a cluster (``O_kb / N_b``) falls below this - threshold, that batch-cluster pair receives no correction, - preventing spurious adjustments. Only used with - ``flavor="harmony2"``. Set to ``None`` to disable pruning. - correction_method - Method for the correction step. ``"original"`` uses per-cluster - ridge regression with explicit matrix inversion. ``"fast"`` - uses a precomputed factorization that avoids the full inversion, - which can be faster for datasets with many batches. - block_proportion - Proportion of cells updated per clustering sub-iteration. - Smaller values produce more stochastic updates. Larger values - are faster but may converge to different solutions. - random_state - Random seed for reproducibility. - - Returns - ------- - Updates adata with the field ``adata.obsm[adjusted_basis]``, \ - containing principal components adjusted by Harmony such that \ - different experiments are integrated. 
- """ - from ._harmony import Harmony - - # Resolve flavor into internal flags - if flavor not in {"harmony1", "harmony2"}: - msg = f"flavor must be 'harmony1' or 'harmony2', got {flavor!r}." - raise ValueError(msg) - stabilized_penalty = flavor == "harmony2" - dynamic_lambda = flavor == "harmony2" - - # Warn when flavor-incompatible parameters are explicitly set - if flavor == "harmony2" and ridge_lambda != 1.0: - warn( - "ridge_lambda is ignored when flavor='harmony2'; " - "use alpha to control regularization strength.", - UserWarning, - ) - if flavor == "harmony1": - if alpha != 0.2: - warn( - "alpha is ignored when flavor='harmony1'; use ridge_lambda instead.", - UserWarning, - ) - if batch_prune_threshold != 1e-5: - warn( - "batch_prune_threshold is ignored when flavor='harmony1'.", - UserWarning, - ) - - # Ensure the basis exists in adata.obsm - if basis not in adata.obsm: - msg = ( - f"The specified basis {basis!r} is not available in `adata.obsm`. " - f"Available bases: {list(adata.obsm.keys())}" - ) - raise ValueError(msg) - - # Get the input data - input_data = adata.obsm[basis] - - # Convert to numpy array with specified dtype - try: - x = np.ascontiguousarray(input_data, dtype=dtype) - except Exception as e: - msg = ( - f"Could not convert input of type {type(input_data).__name__} " - "to NumPy array." - ) - raise TypeError(msg) from e - - # Check for NaN values - if np.isnan(x).any(): - msg = ( - "Input data contains NaN values. Please handle these before " - "running harmony_integrate." 
- ) - raise ValueError(msg) - - # Run Harmony - harmony = Harmony( - adata.obs, - key, - theta=theta, - sigma=sigma, - n_clusters=n_clusters, - max_iter_harmony=max_iter_harmony, - max_iter_clustering=max_iter_clustering, - tol_harmony=tol_harmony, - tol_clustering=tol_clustering, - ridge_lambda=ridge_lambda, - correction_method=correction_method, - block_proportion=block_proportion, - tau=tau, - random_state=random_state, - stabilized_penalty=stabilized_penalty, - dynamic_lambda=dynamic_lambda, - alpha=alpha, - batch_prune_threshold=batch_prune_threshold, - ) - harmony_out = harmony.fit(x) - - # Store result - adata.obsm[adjusted_basis] = harmony_out diff --git a/tests/test_harmony.py b/tests/test_harmony.py index 857cc3f155..a880a29f5b 100644 --- a/tests/test_harmony.py +++ b/tests/test_harmony.py @@ -10,7 +10,7 @@ from scipy.stats import pearsonr from scanpy.preprocessing import harmony_integrate -from scanpy.preprocessing._harmony import _SUPPRESS_PENALTY, _compute_lambda_kb +from scanpy.preprocessing._harmony.core import _SUPPRESS_PENALTY, _compute_lambda_kb from testing.scanpy._helpers.data import pbmc68k_reduced if TYPE_CHECKING: @@ -282,7 +282,7 @@ def test_harmony_flavor_warnings() -> None: @pytest.mark.parametrize("dtype", [np.float32, np.float64]) -def test_compute_lambda_kb_pruning(dtype: DTypeLike) -> None: +def test_compute_lambda_kb_pruning(dtype: type[np.floating]) -> None: """_compute_lambda_kb suppresses correction for N_b==0 and below-threshold pairs.""" n_batches, n_clusters = 4, 3 alpha = 0.2 @@ -320,7 +320,7 @@ def test_compute_lambda_kb_pruning(dtype: DTypeLike) -> None: @pytest.mark.parametrize("dtype", [np.float32, np.float64]) -def test_compute_lambda_kb_dynamic_false(dtype: DTypeLike) -> None: +def test_compute_lambda_kb_dynamic_false(dtype: type[np.floating]) -> None: """_compute_lambda_kb returns uniform ridge_lambda when dynamic_lambda=False.""" n_batches, n_clusters = 3, 5 e = np.ones((n_batches, n_clusters), dtype=dtype) @@ -340,7 
+340,7 @@ def test_compute_lambda_kb_dynamic_false(dtype: DTypeLike) -> None: @pytest.mark.parametrize("dtype", [np.float32, np.float64]) -def test_compute_lambda_kb_zero_denom(dtype: DTypeLike) -> None: +def test_compute_lambda_kb_zero_denom(dtype: type[np.floating]) -> None: """_compute_lambda_kb guards against O==0 and E==0 (zero-denominator).""" sentinel = dtype(_SUPPRESS_PENALTY) e = np.array([[0.0, 5.0]], dtype=dtype) From 4257d8fe277c9112530cc6f5fef8e815e9d0cd31 Mon Sep 17 00:00:00 2001 From: Phil Schaf Date: Mon, 20 Apr 2026 16:16:44 +0200 Subject: [PATCH 14/14] rng --- src/scanpy/preprocessing/_harmony/__init__.py | 10 ++++--- src/scanpy/preprocessing/_harmony/core.py | 26 +++++++++++++------ tests/test_harmony.py | 6 ++--- 3 files changed, 27 insertions(+), 15 deletions(-) diff --git a/src/scanpy/preprocessing/_harmony/__init__.py b/src/scanpy/preprocessing/_harmony/__init__.py index 091c84ba51..66a1d7ce75 100644 --- a/src/scanpy/preprocessing/_harmony/__init__.py +++ b/src/scanpy/preprocessing/_harmony/__init__.py @@ -13,6 +13,8 @@ from anndata import AnnData from numpy.typing import DTypeLike + from ..._utils.random import RNGLike, SeedLike + def harmony_integrate( # noqa: PLR0913 adata: AnnData, @@ -35,7 +37,7 @@ def harmony_integrate( # noqa: PLR0913 batch_prune_threshold: float | None = 1e-5, correction_method: Literal["fast", "original"] = "original", block_proportion: float = 0.05, - random_state: int | None = 0, + rng: SeedLike | RNGLike | None = None, ) -> None: """Integrate different experiments using the Harmony algorithm :cite:p:`Korsunsky2019,Patikas2026`. @@ -135,8 +137,8 @@ def harmony_integrate( # noqa: PLR0913 Proportion of cells updated per clustering sub-iteration. Smaller values produce more stochastic updates. Larger values are faster but may converge to different solutions. - random_state - Random seed for reproducibility. + rng + Random number generator or seed for deterministic behavior. 
Returns ------- @@ -216,7 +218,7 @@ def harmony_integrate( # noqa: PLR0913 correction_method=correction_method, block_proportion=block_proportion, tau=tau, - random_state=random_state, + rng=rng, stabilized_penalty=stabilized_penalty, dynamic_lambda=dynamic_lambda, alpha=alpha, diff --git a/src/scanpy/preprocessing/_harmony/core.py b/src/scanpy/preprocessing/_harmony/core.py index cf48ec9369..cc127d9118 100644 --- a/src/scanpy/preprocessing/_harmony/core.py +++ b/src/scanpy/preprocessing/_harmony/core.py @@ -8,6 +8,8 @@ from sklearn.cluster import KMeans from tqdm.auto import tqdm +from scanpy._utils.random import _legacy_random_state + from ... import logging as log from ..._settings import settings from ..._settings.verbosity import Verbosity @@ -18,6 +20,8 @@ import pandas as pd + from ..._utils.random import RNGLike, SeedLike + __all__ = ["_SUPPRESS_PENALTY", "Harmony", "_compute_lambda_kb"] @@ -53,7 +57,7 @@ class Harmony: correction_method: Literal["fast", "original"] block_proportion: float tau: int - random_state: int | None + rng: InitVar[SeedLike | RNGLike | None] stabilized_penalty: bool = True dynamic_lambda: bool = True alpha: float = 0.2 @@ -61,10 +65,16 @@ class Harmony: batch_codes: np.ndarray = field(init=False) n_batches: int = field(init=False) + _rng: np.random.Generator = field(init=False) def __post_init__( - self, batch_df: pd.DataFrame, batch_key: str | Sequence[str] + self, + batch_df: pd.DataFrame, + batch_key: str | Sequence[str], + rng: SeedLike | RNGLike | None, ) -> None: + self._rng = np.random.default_rng(rng) + if self.max_iter_harmony < 1: msg = "max_iter_harmony must be >= 1" raise ValueError(msg) @@ -90,9 +100,6 @@ def fit(self, x: np.ndarray) -> np.ndarray: z_corr Batch-corrected embedding matrix (n_cells x d). 
""" - if self.random_state is not None: - np.random.seed(self.random_state) - # Ensure input is C-contiguous float array (infer dtype from x) x = np.ascontiguousarray(x) n_cells = x.shape[0] @@ -131,7 +138,7 @@ def fit(self, x: np.ndarray) -> np.ndarray: n_clusters=n_clusters, sigma=self.sigma, theta=theta_arr, - random_state=self.random_state, + rng=self._rng, stabilized_penalty=self.stabilized_penalty, ) @@ -277,12 +284,15 @@ def _initialize_centroids( n_clusters: int, sigma: float, theta: np.ndarray, - random_state: int | None, + rng: np.random.Generator, stabilized_penalty: bool = True, ) -> tuple[np.ndarray, np.ndarray, np.ndarray, float]: """Initialize cluster centroids using K-means.""" kmeans = KMeans( - n_clusters=n_clusters, random_state=random_state, n_init=10, max_iter=25 + n_clusters=n_clusters, + random_state=_legacy_random_state(rng, always_state=True), + n_init=10, + max_iter=25, ) kmeans.fit(z_norm) diff --git a/tests/test_harmony.py b/tests/test_harmony.py index a880a29f5b..09c661ec23 100644 --- a/tests/test_harmony.py +++ b/tests/test_harmony.py @@ -52,7 +52,7 @@ def _get_measure( @pytest.fixture(scope="module") -def _adata_reference() -> AnnData: +def adata_reference_module() -> AnnData: """Load reference data once per module (avoids re-reading CSV).""" paths = { f: pooch.retrieve( @@ -72,9 +72,9 @@ def _adata_reference() -> AnnData: @pytest.fixture -def adata_reference(_adata_reference: AnnData) -> AnnData: +def adata_reference(adata_reference_module: AnnData) -> AnnData: """Return a fresh copy per test so tests don't mutate shared state.""" - return _adata_reference.copy() + return adata_reference_module.copy() @pytest.mark.parametrize("correction_method", ["fast", "original"])