From 8e858b0029008db8172d4c0db3ac35ed2267ef09 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Thu, 9 Apr 2026 15:47:26 +0200 Subject: [PATCH 01/21] rebase to main and start with illico optimization --- .../tools/_rank_genes_groups/_wilcoxon.py | 585 ++++++++++++++---- 1 file changed, 472 insertions(+), 113 deletions(-) diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py index c14c760d..f24da8f2 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py @@ -4,14 +4,14 @@ from typing import TYPE_CHECKING import cupy as cp +import cupyx.scipy.sparse as cpsp import cupyx.scipy.special as cupyx_special import numpy as np import scipy.sparse as sp from rapids_singlecell._cuda import _wilcoxon_cuda as _wc -from rapids_singlecell._utils._csr_to_csc import _fast_csr_to_csc -from ._utils import _choose_chunk_size, _get_column_block +from ._utils import _choose_chunk_size if TYPE_CHECKING: from numpy.typing import NDArray @@ -20,72 +20,395 @@ MIN_GROUP_SIZE_WARNING = 25 - -def _average_ranks( - matrix: cp.ndarray, *, return_sorted: bool = False -) -> cp.ndarray | tuple[cp.ndarray, cp.ndarray]: +# --------------------------------------------------------------------------- +# CuPy RawKernels for sort-once OVO +# --------------------------------------------------------------------------- + +_RANK_SUMS_KERNEL = cp.RawKernel( + r""" +extern "C" __global__ +void rank_sums_from_sorted( + const double* __restrict__ ref_sorted, // (n_ref, n_cols) F-order + const double* __restrict__ grp_sorted, // (n_grp, n_cols) F-order + double* __restrict__ rank_sums, // (n_cols,) + const int n_ref, + const int n_grp, + const int n_cols +) { + /* One block per gene (column). + Threads cooperatively process group elements. 
+ For each group element, binary-search the sorted reference + and the sorted group to compute the average rank in the + combined (group + reference) set. + */ + int col = blockIdx.x; + if (col >= n_cols) return; + + const double* ref = ref_sorted + (long long)col * n_ref; + const double* grp = grp_sorted + (long long)col * n_grp; + + double local_sum = 0.0; + + for (int i = threadIdx.x; i < n_grp; i += blockDim.x) { + double v = grp[i]; + + // --- count of ref values < v --- + int lo = 0, hi = n_ref; + while (lo < hi) { int m = (lo + hi) >> 1; if (ref[m] < v) lo = m + 1; else hi = m; } + int n_lt_ref = lo; + + // --- count of ref values <= v --- + lo = n_lt_ref; hi = n_ref; + while (lo < hi) { int m = (lo + hi) >> 1; if (ref[m] <= v) lo = m + 1; else hi = m; } + int n_eq_ref = lo - n_lt_ref; + + // --- count of grp values < v --- + lo = 0; hi = n_grp; + while (lo < hi) { int m = (lo + hi) >> 1; if (grp[m] < v) lo = m + 1; else hi = m; } + int n_lt_grp = lo; + + // --- count of grp values <= v --- + lo = n_lt_grp; hi = n_grp; + while (lo < hi) { int m = (lo + hi) >> 1; if (grp[m] <= v) lo = m + 1; else hi = m; } + int n_eq_grp = lo - n_lt_grp; + + int n_lt = n_lt_ref + n_lt_grp; + int n_eq = n_eq_ref + n_eq_grp; + double avg_rank = (double)n_lt + ((double)n_eq + 1.0) / 2.0; + local_sum += avg_rank; + } + + // --- warp-level reduction --- + #pragma unroll + for (int off = 16; off > 0; off >>= 1) + local_sum += __shfl_down_sync(0xffffffff, local_sum, off); + + __shared__ double warp_sums[32]; + int lane = threadIdx.x & 31; + int wid = threadIdx.x >> 5; + if (lane == 0) warp_sums[wid] = local_sum; + __syncthreads(); + + if (threadIdx.x < 32) { + double val = (threadIdx.x < ((blockDim.x + 31) >> 5)) + ? 
warp_sums[threadIdx.x] : 0.0; + #pragma unroll + for (int off = 16; off > 0; off >>= 1) + val += __shfl_down_sync(0xffffffff, val, off); + if (threadIdx.x == 0) rank_sums[col] = val; + } +} +""", + "rank_sums_from_sorted", + options=("--use_fast_math",), +) + + +_TIE_CORR_MERGE_KERNEL = cp.RawKernel( + r""" +extern "C" __global__ +void tie_correction_merged( + const double* __restrict__ ref_sorted, + const double* __restrict__ grp_sorted, + double* __restrict__ correction, + const int n_ref, + const int n_grp, + const int n_cols +) { + /* One block per gene column. Thread 0 merges the two sorted + arrays and accumulates the tie-correction term + sum(t^3 - t) over all tie groups of size t. + */ + int col = blockIdx.x; + if (col >= n_cols || threadIdx.x != 0) return; + + const double* ref = ref_sorted + (long long)col * n_ref; + const double* grp = grp_sorted + (long long)col * n_grp; + + int i = 0, j = 0; + double tie_sum = 0.0; + + while (i < n_ref || j < n_grp) { + double v; + if (j >= n_grp) v = ref[i]; + else if (i >= n_ref) v = grp[j]; + else v = (ref[i] <= grp[j]) ? ref[i] : grp[j]; + + int count = 0; + while (i < n_ref && ref[i] == v) { ++i; ++count; } + while (j < n_grp && grp[j] == v) { ++j; ++count; } + + if (count > 1) { + double t = (double)count; + tie_sum += t * t * t - t; + } + } + + int n = n_ref + n_grp; + double dn = (double)n; + double denom = dn * dn * dn - dn; + correction[col] = (denom > 0.0) ? (1.0 - tie_sum / denom) : 1.0; +} +""", + "tie_correction_merged", + options=("--use_fast_math",), +) + + +_CSR_EXTRACT_KERNEL = cp.RawKernel( + r""" +extern "C" __global__ +void csr_extract_dense( + const double* __restrict__ data, + const int* __restrict__ indices, + const long long* __restrict__ indptr, + const int* __restrict__ row_ids, + double* __restrict__ out, // F-order (n_target, n_cols) + const int n_target, + const int col_start, + const int col_stop, + const int n_cols // = col_stop - col_start +) { + /* One thread per target row. 
+ Binary-search the CSR index array for col_start, then + linear-scan through [col_start, col_stop) writing to + the dense output in column-major order. + */ + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= n_target) return; + + int row = row_ids[tid]; + long long rs = indptr[row]; + long long re = indptr[row + 1]; + + // binary search for col_start + long long lo = rs, hi = re; + while (lo < hi) { + long long m = (lo + hi) >> 1; + if (indices[m] < col_start) lo = m + 1; else hi = m; + } + + for (long long p = lo; p < re; ++p) { + int c = indices[p]; + if (c >= col_stop) break; + int lc = c - col_start; + out[(long long)lc * n_target + tid] = data[p]; + } +} +""", + "csr_extract_dense", + options=("--use_fast_math",), +) + + +# --------------------------------------------------------------------------- +# Kernel helpers +# --------------------------------------------------------------------------- + +WARP_SIZE = 32 +MAX_THREADS = 512 + + +def _round_up_to_warp(n: int) -> int: + return min(MAX_THREADS, ((n + WARP_SIZE - 1) // WARP_SIZE) * WARP_SIZE) + + +def _rank_sums_searchsorted( + ref_sorted: cp.ndarray, + grp_sorted: cp.ndarray, +) -> cp.ndarray: + """Rank sums for *grp* via binary search in pre-sorted *ref*. + + Both must be F-order float64 ``(n_rows, n_cols)``. """ - Compute average ranks for each column using GPU kernel. 
+ n_ref, n_cols = ref_sorted.shape + n_grp = grp_sorted.shape[0] + rank_sums = cp.empty(n_cols, dtype=cp.float64) + threads = _round_up_to_warp(min(n_grp, MAX_THREADS)) + _RANK_SUMS_KERNEL( + (n_cols,), + (threads,), + ( + ref_sorted, + grp_sorted, + rank_sums, + np.int32(n_ref), + np.int32(n_grp), + np.int32(n_cols), + ), + stream=cp.cuda.get_current_stream(), + ) + return rank_sums + + +def _tie_correction_merged( + ref_sorted: cp.ndarray, + grp_sorted: cp.ndarray, +) -> cp.ndarray: + """Tie-correction factor via merge of two sorted F-order arrays.""" + n_ref, n_cols = ref_sorted.shape + n_grp = grp_sorted.shape[0] + correction = cp.empty(n_cols, dtype=cp.float64) + _TIE_CORR_MERGE_KERNEL( + (n_cols,), + (1,), + ( + ref_sorted, + grp_sorted, + correction, + np.int32(n_ref), + np.int32(n_grp), + np.int32(n_cols), + ), + stream=cp.cuda.get_current_stream(), + ) + return correction + + +def _extract_dense_block_csr_gpu( + data: cp.ndarray, + indices: cp.ndarray, + indptr: cp.ndarray, + row_ids: cp.ndarray, + *, + col_start: int, + col_stop: int, +) -> cp.ndarray: + """Extract a dense F-order float64 block from GPU CSR arrays.""" + n_target = row_ids.shape[0] + n_cols = col_stop - col_start + out = cp.zeros((n_target, n_cols), dtype=cp.float64, order="F") + if n_target == 0 or n_cols == 0: + return out + threads = _round_up_to_warp(min(n_target, MAX_THREADS)) + blocks = (n_target + threads - 1) // threads + _CSR_EXTRACT_KERNEL( + (blocks,), + (threads,), + ( + data, + indices, + indptr, + row_ids, + out, + np.int32(n_target), + np.int32(col_start), + np.int32(col_stop), + np.int32(n_cols), + ), + stream=cp.cuda.get_current_stream(), + ) + return out + + +def _to_gpu_csr_arrays(X) -> tuple[cp.ndarray, cp.ndarray, cp.ndarray]: + """Return (data, indices, indptr) as float64/int32/int64 on GPU.""" + if isinstance(X, cpsp.csr_matrix): + csr = X + elif isinstance(X, cpsp.csc_matrix): + csr = X.tocsr() + elif isinstance(X, sp.spmatrix | sp.sparray): + if X.format != "csr": 
+ X = X.tocsr() + csr = cpsp.csr_matrix(X) + else: + raise TypeError(f"Expected sparse matrix, got {type(X)}") + return ( + csr.data.astype(cp.float64, copy=False), + csr.indices.astype(cp.int32, copy=False), + csr.indptr.astype(cp.int64, copy=False), + ) - Uses scipy.stats.rankdata 'average' method: ties get the average - of the ranks they would span. - Parameters - ---------- - matrix - Input matrix (n_rows, n_cols) - return_sorted - If True, also return sorted values (useful for tie correction) +def _extract_dense_block( + X, + row_ids: cp.ndarray | None, + start: int, + stop: int, + *, + csr_arrays: tuple[cp.ndarray, cp.ndarray, cp.ndarray] | None = None, +) -> cp.ndarray: + """Extract ``X[row_ids, start:stop]`` as dense F-order float64 on GPU.""" + if csr_arrays is not None: + data, indices, indptr = csr_arrays + if row_ids is None: + n_target = int(indptr.shape[0] - 1) + row_ids = cp.arange(n_target, dtype=cp.int32) + return _extract_dense_block_csr_gpu( + data, indices, indptr, row_ids, col_start=start, col_stop=stop + ) - Returns - ------- - ranks or (ranks, sorted_vals) - """ - n_rows, n_cols = matrix.shape + if isinstance(X, np.ndarray): + if row_ids is not None: + return cp.asarray( + X[cp.asnumpy(row_ids), start:stop], dtype=cp.float64, order="F" + ) + return cp.asarray(X[:, start:stop], dtype=cp.float64, order="F") - # Sort each column + if isinstance(X, cp.ndarray): + chunk = X[row_ids, start:stop] if row_ids is not None else X[:, start:stop] + return cp.asfortranarray(chunk.astype(cp.float64, copy=False)) + + if isinstance(X, sp.spmatrix | sp.sparray): + if row_ids is not None: + idx = cp.asnumpy(row_ids) + chunk = X[idx][:, start:stop].toarray() + else: + chunk = X[:, start:stop].toarray() + return cp.asarray(chunk, dtype=cp.float64, order="F") + + if cpsp.issparse(X): + if row_ids is not None: + chunk = X[row_ids][:, start:stop].toarray() + else: + chunk = X[:, start:stop].toarray() + return cp.asfortranarray(chunk.astype(cp.float64, copy=False)) + 
+ raise TypeError(f"Unsupported matrix type: {type(X)}") + + +# --------------------------------------------------------------------------- +# Existing kernels (OVR path) +# --------------------------------------------------------------------------- + + +def _average_ranks( + matrix: cp.ndarray, *, return_sorted: bool = False +) -> cp.ndarray | tuple[cp.ndarray, cp.ndarray]: + """Compute average ranks for each column using GPU kernel.""" + n_rows, n_cols = matrix.shape sorter = cp.argsort(matrix, axis=0) sorted_vals = cp.take_along_axis(matrix, sorter, axis=0) - - # Ensure F-order for kernel (columns contiguous in memory) sorted_vals = cp.asfortranarray(sorted_vals) sorter = cp.asfortranarray(sorter.astype(cp.int32)) - stream = cp.cuda.get_current_stream().ptr _wc.average_rank( sorted_vals, sorter, matrix, n_rows=n_rows, n_cols=n_cols, stream=stream ) - if return_sorted: return matrix, sorted_vals return matrix def _tie_correction(sorted_vals: cp.ndarray) -> cp.ndarray: - """ - Compute tie correction factor for Wilcoxon test. - - Takes pre-sorted values (column-wise) to avoid re-sorting. - Formula: tc = 1 - sum(t^3 - t) / (n^3 - n) - where t is the count of tied values. 
- """ + """Tie correction factor from pre-sorted values (F-order).""" n_rows, n_cols = sorted_vals.shape correction = cp.ones(n_cols, dtype=cp.float64) - if n_rows < 2: return correction - - # Ensure F-order sorted_vals = cp.asfortranarray(sorted_vals) - stream = cp.cuda.get_current_stream().ptr _wc.tie_correction( sorted_vals, correction, n_rows=n_rows, n_cols=n_cols, stream=stream ) - return correction +# --------------------------------------------------------------------------- +# Entry point +# --------------------------------------------------------------------------- + + def wilcoxon( rg: _RankGenes, *, @@ -94,14 +417,12 @@ def wilcoxon( chunk_size: int | None = None, ) -> list[tuple[int, NDArray, NDArray]]: """Compute Wilcoxon rank-sum test statistics.""" - # Compute basic stats - uses Aggregate if on GPU, else defers to chunks rg._basic_stats() X = rg.X n_cells, n_total_genes = rg.X.shape group_sizes = rg.group_sizes if rg.ireference is not None: - # Compare each group against a specific reference group return _wilcoxon_with_reference( rg, X, @@ -111,7 +432,6 @@ def wilcoxon( use_continuity=use_continuity, chunk_size=chunk_size, ) - # Compare each group against "rest" (all other cells) return _wilcoxon_vs_rest( rg, X, @@ -124,6 +444,11 @@ def wilcoxon( ) +# --------------------------------------------------------------------------- +# One-vs-rest (unchanged from main) +# --------------------------------------------------------------------------- + + def _wilcoxon_vs_rest( rg: _RankGenes, X, @@ -136,9 +461,12 @@ def _wilcoxon_vs_rest( chunk_size: int | None, ) -> list[tuple[int, NDArray, NDArray]]: """Wilcoxon test: each group vs rest of cells.""" + from rapids_singlecell._utils._csr_to_csc import _fast_csr_to_csc + + from ._utils import _get_column_block + n_groups = len(rg.groups_order) - # Warn for small groups for name, size in zip(rg.groups_order, group_sizes, strict=True): rest = n_cells - size if size <= MIN_GROUP_SIZE_WARNING or rest <= 
MIN_GROUP_SIZE_WARNING: @@ -149,7 +477,6 @@ def _wilcoxon_vs_rest( stacklevel=4, ) - # Build one-hot indicator matrix from group codes codes_gpu = cp.asarray(rg.group_codes, dtype=cp.int64) group_matrix = cp.zeros((n_cells, n_groups), dtype=cp.float64) valid_idx = cp.where(codes_gpu < n_groups)[0] @@ -160,22 +487,17 @@ def _wilcoxon_vs_rest( chunk_width = _choose_chunk_size(chunk_size) - # Accumulate results per group all_scores: dict[int, list] = {i: [] for i in range(n_groups)} all_pvals: dict[int, list] = {i: [] for i in range(n_groups)} - # One-time CSR->CSC via fast parallel Numba kernel; _get_column_block - # then uses direct indptr pointer copy for each chunk. if isinstance(X, sp.spmatrix | sp.sparray): X = _fast_csr_to_csc(X) if X.format == "csr" else X.tocsc() for start in range(0, n_total_genes, chunk_width): stop = min(start + chunk_width, n_total_genes) - # Slice and convert to dense GPU array (F-order for column ops) block = _get_column_block(X, start, stop) - # Accumulate stats for this chunk rg._accumulate_chunk_stats_vs_rest( block, start, @@ -211,13 +533,17 @@ def _wilcoxon_vs_rest( all_scores[idx].append(z_host[idx]) all_pvals[idx].append(p_host[idx]) - # Collect results per group return [ (gi, np.concatenate(all_scores[gi]), np.concatenate(all_pvals[gi])) for gi in range(n_groups) ] +# --------------------------------------------------------------------------- +# One-vs-reference (sort-once optimisation inspired by illico) +# --------------------------------------------------------------------------- + + def _wilcoxon_with_reference( rg: _RankGenes, X, @@ -228,97 +554,130 @@ def _wilcoxon_with_reference( use_continuity: bool, chunk_size: int | None, ) -> list[tuple[int, NDArray, NDArray]]: - """Wilcoxon test: each group vs a specific reference group.""" - codes = rg.group_codes - n_ref = int(group_sizes[rg.ireference]) - mask_ref = codes == rg.ireference + """Wilcoxon test: each group vs a specific reference group. 
+ + Key optimisations over the naive per-group approach: + + * **No CSR->CSC conversion** -- data is read directly from CSR via a + binary-search extraction kernel. + * **Reference sorted once per gene chunk** -- the (typically large) + reference group is extracted and column-sorted once, then reused + for every test-group comparison. + * **Rank sums via binary search** -- instead of concatenating and + re-sorting reference + group for every pair, a GPU kernel computes + rank sums by binary-searching the pre-sorted reference. This + reduces the per-group cost from O((n_ref+n_grp) log(n_ref+n_grp)) + to O(n_grp log(n_ref)). + """ + n_groups = len(rg.groups_order) + ireference = rg.ireference + n_ref = int(group_sizes[ireference]) - results: list[tuple[int, NDArray, NDArray]] = [] + # ---- build row-index arrays (GPU int32) for every group ---- + codes = rg.group_codes + ref_row_ids = cp.asarray(np.where(codes == ireference)[0], dtype=cp.int32) - for group_index in range(len(rg.groups_order)): - if group_index == rg.ireference: + group_row_ids: dict[int, cp.ndarray] = {} + for gi in range(n_groups): + if gi == ireference: continue + group_row_ids[gi] = cp.asarray(np.where(codes == gi)[0], dtype=cp.int32) - n_group = int(group_sizes[group_index]) - n_combined = n_group + n_ref + # ---- prepare CSR arrays on GPU if sparse (one-time transfer) ---- + csr_arrays = None + if sp.issparse(X) or cpsp.issparse(X): + csr_arrays = _to_gpu_csr_arrays(X) - # Warn for small groups + # ---- warn for small groups ---- + for gi in group_row_ids: + n_group = int(group_sizes[gi]) if n_group <= MIN_GROUP_SIZE_WARNING or n_ref <= MIN_GROUP_SIZE_WARNING: warnings.warn( - f"Group {rg.groups_order[group_index]} has size {n_group} " + f"Group {rg.groups_order[gi]} has size {n_group} " f"(reference {n_ref}); normal approximation " "of the Wilcoxon statistic may be inaccurate.", RuntimeWarning, stacklevel=4, ) - # Combined mask: group + reference - mask_obs = codes == group_index - 
mask_combined = mask_obs | mask_ref - - # Subset matrix ONCE before chunking (10x faster than filtering each chunk) - X_subset = X[mask_combined, :] - - # One-time CSR->CSC via fast parallel Numba kernel - if isinstance(X_subset, sp.spmatrix | sp.sparray): - X_subset = ( - _fast_csr_to_csc(X_subset) - if X_subset.format == "csr" - else X_subset.tocsc() - ) - - # Within the combined array, True = group cell, False = reference cell - group_mask_gpu = cp.asarray(mask_obs[mask_combined]) + # ---- pre-allocate outputs ---- + all_scores: dict[int, np.ndarray] = {} + all_pvals: dict[int, np.ndarray] = {} + for gi in group_row_ids: + all_scores[gi] = np.empty(n_total_genes, dtype=np.float64) + all_pvals[gi] = np.empty(n_total_genes, dtype=np.float64) - chunk_width = _choose_chunk_size(chunk_size) - - # Pre-allocate output arrays - scores = np.empty(n_total_genes, dtype=np.float64) - pvals = np.empty(n_total_genes, dtype=np.float64) - - for start in range(0, n_total_genes, chunk_width): - stop = min(start + chunk_width, n_total_genes) + chunk_width = _choose_chunk_size(chunk_size) - # Get block for combined cells only - block = _get_column_block(X_subset, start, stop) + # ---- chunk loop (outer) x group loop (inner) ---- + for start in range(0, n_total_genes, chunk_width): + stop = min(start + chunk_width, n_total_genes) + n_cols = stop - start - # Accumulate stats for this chunk - rg._accumulate_chunk_stats_with_ref( - block, - start, - stop, - group_index=group_index, - group_mask_gpu=group_mask_gpu, - n_group=n_group, - n_ref=n_ref, + # Extract & sort reference columns ONCE per chunk + ref_block = _extract_dense_block( + X, ref_row_ids, start, stop, csr_arrays=csr_arrays + ) + ref_sorted = cp.asfortranarray(cp.sort(ref_block, axis=0)) + + # Accumulate reference stats once per chunk (CPU-data path) + if rg._compute_stats_in_chunks and start not in rg._ref_chunk_computed: + rg._ref_chunk_computed.add(start) + ref_mean = ref_block.mean(axis=0) + rg.means[ireference, 
start:stop] = cp.asnumpy(ref_mean) + if n_ref > 1: + ref_var = ref_block.var(axis=0, ddof=1) + rg.vars[ireference, start:stop] = cp.asnumpy(ref_var) + if rg.comp_pts: + ref_nnz = (ref_block != 0).sum(axis=0) + rg.pts[ireference, start:stop] = cp.asnumpy(ref_nnz / n_ref) + + for gi, grp_rows in group_row_ids.items(): + n_group = int(group_sizes[gi]) + n_combined = n_group + n_ref + + # Extract & sort group columns (small, fast) + grp_block = _extract_dense_block( + X, grp_rows, start, stop, csr_arrays=csr_arrays ) - - # Ranks for combined group+reference cells + grp_sorted = cp.asfortranarray(cp.sort(grp_block, axis=0)) + + # Accumulate group stats (CPU-data path) + if rg._compute_stats_in_chunks: + grp_mean = grp_block.mean(axis=0) + rg.means[gi, start:stop] = cp.asnumpy(grp_mean) + if n_group > 1: + grp_var = grp_block.var(axis=0, ddof=1) + rg.vars[gi, start:stop] = cp.asnumpy(grp_var) + if rg.comp_pts: + grp_nnz = (grp_block != 0).sum(axis=0) + rg.pts[gi, start:stop] = cp.asnumpy(grp_nnz / n_group) + + # ---- rank sums via binary search (no combined sort) ---- + rank_sums = _rank_sums_searchsorted(ref_sorted, grp_sorted) + + # ---- tie correction (optional) ---- if tie_correct: - ranks, sorted_vals = _average_ranks(block, return_sorted=True) - tie_corr = _tie_correction(sorted_vals) + tie_corr = _tie_correction_merged(ref_sorted, grp_sorted) else: - ranks = _average_ranks(block) - tie_corr = cp.ones(ranks.shape[1], dtype=cp.float64) - - # Rank sum for the group - rank_sums = (ranks * group_mask_gpu[:, None]).sum(axis=0) + tie_corr = cp.ones(n_cols, dtype=cp.float64) - # Wilcoxon z-score formula for two groups + # ---- z-scores & p-values ---- expected = n_group * (n_combined + 1) / 2.0 variance = tie_corr * n_group * n_ref * (n_combined + 1) / 12.0 - std = cp.sqrt(variance) diff = rank_sums - expected if use_continuity: diff = cp.sign(diff) * cp.maximum(cp.abs(diff) - 0.5, 0.0) - z = diff / std + z = diff / cp.sqrt(variance) cp.nan_to_num(z, copy=False) p_values 
= cupyx_special.erfc(cp.abs(z) * cp.float64(cp.sqrt(0.5))) - # Fill pre-allocated arrays - scores[start:stop] = z.get() - pvals[start:stop] = p_values.get() + all_scores[gi][start:stop] = z.get() + all_pvals[gi][start:stop] = p_values.get() - results.append((group_index, scores, pvals)) - - return results + # ---- return in group order ---- + return [ + (gi, all_scores[gi], all_pvals[gi]) + for gi in range(n_groups) + if gi != ireference + ] From e1ef5474750a612d969151fa3197ffa29824001a Mon Sep 17 00:00:00 2001 From: Intron7 Date: Fri, 10 Apr 2026 15:50:15 +0200 Subject: [PATCH 02/21] v1 illico --- CMakeLists.txt | 1 + src/rapids_singlecell/_cuda/nb_types.h | 7 + .../_cuda/wilcoxon/kernels_wilcoxon.cuh | 226 +-- .../_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh | 235 +++ .../_cuda/wilcoxon/wilcoxon.cu | 106 +- .../wilcoxon_streaming/wilcoxon_streaming.cu | 1283 +++++++++++++++++ .../tools/_rank_genes_groups/_utils.py | 15 +- .../tools/_rank_genes_groups/_wilcoxon.py | 887 +++++------- tests/test_rank_genes_groups_wilcoxon.py | 188 +-- 9 files changed, 2108 insertions(+), 840 deletions(-) create mode 100644 src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh create mode 100644 src/rapids_singlecell/_cuda/wilcoxon_streaming/wilcoxon_streaming.cu diff --git a/CMakeLists.txt b/CMakeLists.txt index cacf9849..3c0d2e8a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -100,4 +100,5 @@ if (RSC_BUILD_EXTENSIONS) target_link_libraries(_harmony_correction_batched_cuda PRIVATE CUDA::cublas) # Wilcoxon binned histogram CUDA module add_nb_cuda_module(_wilcoxon_binned_cuda src/rapids_singlecell/_cuda/wilcoxon_binned/wilcoxon_binned.cu) + add_nb_cuda_module(_wilcoxon_streaming_cuda src/rapids_singlecell/_cuda/wilcoxon_streaming/wilcoxon_streaming.cu) endif() diff --git a/src/rapids_singlecell/_cuda/nb_types.h b/src/rapids_singlecell/_cuda/nb_types.h index 905e1e07..4cb10e44 100644 --- a/src/rapids_singlecell/_cuda/nb_types.h +++ b/src/rapids_singlecell/_cuda/nb_types.h @@ 
-42,6 +42,13 @@ using gpu_array = nb::ndarray; template using gpu_array_contig = nb::ndarray; +// Host (NumPy) array aliases +template +using host_array = nb::ndarray>; + +template +using host_array_2d = nb::ndarray; + // Register bindings for both regular CUDA and managed-memory arrays. // Usage: // template diff --git a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh index c89d913a..46ff14f0 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh @@ -3,142 +3,146 @@ #include /** - * Kernel to compute tie correction factor for Wilcoxon test. - * Formula: tc = 1 - sum(t^3 - t) / (n^3 - n) where t is the count of tied - * values. + * Fused rank-sum kernel: walk sorted data, compute per-group rank sums + * and tie correction without materializing a rank matrix. * - * Each block handles one column. Uses binary search to find tie groups. - * Assumes input is sorted column-wise (F-order). + * Each thread processes a CONTIGUOUS chunk of sorted elements, detecting + * tie groups by adjacent comparison (sequential access, no binary search). + * Cross-boundary ties are resolved via binary search at chunk boundaries. + * + * Used by the OVR streaming pipeline in wilcoxon_streaming.cu. 
*/ -__global__ void tie_correction_kernel(const double* __restrict__ sorted_vals, - double* __restrict__ correction, - const int n_rows, const int n_cols) { - // Each block handles one column +__global__ void rank_sums_from_sorted_kernel( + const float* __restrict__ sorted_vals, // F-order (n_rows, n_cols) + const int* __restrict__ sorted_row_idx, // F-order (n_rows, n_cols) + const int* __restrict__ group_codes, // (n_rows_total,) + double* __restrict__ rank_sums, // (n_groups, n_cols) row-major + double* __restrict__ tie_corr, // (n_cols,) + double* __restrict__ group_sums, // (n_groups, n_cols) or NULL + double* __restrict__ group_sq_sums, // (n_groups, n_cols) or NULL + double* __restrict__ group_nnz, // (n_groups, n_cols) or NULL + int n_rows, int n_cols, int n_groups, bool compute_tie_corr, + bool compute_stats) { int col = blockIdx.x; if (col >= n_cols) return; - const double* sv = sorted_vals + (size_t)col * n_rows; + extern __shared__ double smem[]; + double* grp_sums = smem; + double* s_sum = smem + n_groups; + double* s_sq = smem + 2 * n_groups; + double* s_nnz = smem + 3 * n_groups; + + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + grp_sums[g] = 0.0; + if (compute_stats) { + s_sum[g] = 0.0; + s_sq[g] = 0.0; + s_nnz[g] = 0.0; + } + } + __syncthreads(); + + const float* sv = sorted_vals + (size_t)col * n_rows; + const int* si = sorted_row_idx + (size_t)col * n_rows; - double local_sum = 0.0; - int tid = threadIdx.x; + int chunk = (n_rows + blockDim.x - 1) / blockDim.x; + int my_start = threadIdx.x * chunk; + int my_end = my_start + chunk; + if (my_end > n_rows) my_end = n_rows; - // Each thread processes positions where it detects END of a tie group - // Start from index 1, check if sv[i-1] != sv[i] (boundary detected) - // When at boundary, use binary search to find tie group size - for (int i = tid + 1; i <= n_rows; i += blockDim.x) { - // Detect boundary: either at the end, or value changed - bool at_boundary = (i == n_rows) || (sv[i] != 
sv[i - 1]); + double local_tie_sum = 0.0; - if (at_boundary) { - // Found end of tie group at position i-1 - // Binary search for start of this tie group - double val = sv[i - 1]; - int lo = 0, hi = i - 1; + int i = my_start; + while (i < my_end) { + double val = sv[i]; + + int tie_local_end = i + 1; + while (tie_local_end < my_end && sv[tie_local_end] == val) + ++tie_local_end; + + int tie_global_start = i; + if (i == my_start && i > 0 && sv[i - 1] == val) { + int lo = 0, hi = i; while (lo < hi) { int mid = (lo + hi) / 2; - if (sv[mid] < val) { + if (sv[mid] < val) lo = mid + 1; - } else { + else hi = mid; - } } - int tie_count = i - lo; + tie_global_start = lo; + } - // t^3 - t for this tie group - double t = (double)tie_count; - local_sum += t * t * t - t; + int tie_global_end = tie_local_end; + if (tie_local_end == my_end && tie_local_end < n_rows && + sv[tie_local_end] == val) { + int lo = tie_local_end, hi = n_rows - 1; + while (lo < hi) { + int mid = (lo + hi + 1) / 2; + if (sv[mid] > val) + hi = mid - 1; + else + lo = mid; + } + tie_global_end = lo + 1; } - } - // Warp-level reduction using shuffle -#pragma unroll - for (int offset = 16; offset > 0; offset >>= 1) { - local_sum += __shfl_down_sync(0xffffffff, local_sum, offset); - } + int total_tie = tie_global_end - tie_global_start; + double avg_rank = (double)(tie_global_start + tie_global_end + 1) / 2.0; + + for (int j = i; j < tie_local_end; ++j) { + int grp = group_codes[si[j]]; + if (grp < n_groups) { + atomicAdd(&grp_sums[grp], avg_rank); + if (compute_stats) { + double v = (double)sv[j]; + atomicAdd(&s_sum[grp], v); + atomicAdd(&s_sq[grp], v * v); + if (v != 0.0) atomicAdd(&s_nnz[grp], 1.0); + } + } + } - // Cross-warp reduction using small shared memory - __shared__ double warp_sums[32]; - int lane = tid & 31; - int warp_id = tid >> 5; + if (compute_tie_corr && tie_global_start >= my_start && total_tie > 1) { + double t = (double)total_tie; + local_tie_sum += t * t * t - t; + } - if (lane == 0) { - 
warp_sums[warp_id] = local_sum; + i = tie_local_end; } + __syncthreads(); - // Final reduction in first warp - // Note: blockDim.x must be a multiple of 32 for correct warp reduction - if (tid < 32) { - double val = (tid < (blockDim.x >> 5)) ? warp_sums[tid] : 0.0; -#pragma unroll - for (int offset = 16; offset > 0; offset >>= 1) { - val += __shfl_down_sync(0xffffffff, val, offset); - } - if (tid == 0) { - double n = (double)n_rows; - double denom = n * n * n - n; - if (denom > 0) { - correction[col] = 1.0 - val / denom; - } else { - correction[col] = 1.0; - } + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + rank_sums[(size_t)g * n_cols + col] = grp_sums[g]; + if (compute_stats) { + group_sums[(size_t)g * n_cols + col] = s_sum[g]; + group_sq_sums[(size_t)g * n_cols + col] = s_sq[g]; + group_nnz[(size_t)g * n_cols + col] = s_nnz[g]; } } -} - -/** - * Kernel to compute average ranks for each column. - * Uses scipy.stats.rankdata 'average' method: ties get the average of the ranks - * they would span. - * - * Each block handles one column. Assumes input is sorted column-wise (F-order). 
- */ -__global__ void average_rank_kernel(const double* __restrict__ sorted_vals, - const int* __restrict__ sorter, - double* __restrict__ ranks, - const int n_rows, const int n_cols) { - // Each thread block handles one column - int col = blockIdx.x; - if (col >= n_cols) return; - // Pointers to this column's data - const double* sv = sorted_vals + (size_t)col * n_rows; - const int* si = sorter + (size_t)col * n_rows; - double* rk = ranks + (size_t)col * n_rows; - - // Each thread processes multiple rows - for (int i = threadIdx.x; i < n_rows; i += blockDim.x) { - double val = sv[i]; - - // Binary search for tie_start (first element equal to val) - int lo = 0, hi = i; - while (lo < hi) { - int mid = (lo + hi) / 2; - if (sv[mid] < val) { - lo = mid + 1; - } else { - hi = mid; - } - } - int tie_start = lo; - - // Binary search for tie_end (last element equal to val) - lo = i; - hi = n_rows - 1; - while (lo < hi) { - int mid = (lo + hi + 1) / 2; - if (sv[mid] > val) { - hi = mid - 1; - } else { - lo = mid; + if (compute_tie_corr) { + double* warp_buf = smem + n_groups; +#pragma unroll + for (int off = 16; off > 0; off >>= 1) + local_tie_sum += __shfl_down_sync(0xffffffff, local_tie_sum, off); + int lane = threadIdx.x & 31; + int wid = threadIdx.x >> 5; + if (lane == 0) warp_buf[wid] = local_tie_sum; + __syncthreads(); + if (threadIdx.x < 32) { + double val = (threadIdx.x < ((blockDim.x + 31) >> 5)) + ? warp_buf[threadIdx.x] + : 0.0; +#pragma unroll + for (int off = 16; off > 0; off >>= 1) + val += __shfl_down_sync(0xffffffff, val, off); + if (threadIdx.x == 0) { + double n = (double)n_rows; + double denom = n * n * n - n; + tie_corr[col] = (denom > 0.0) ? 
(1.0 - val / denom) : 1.0; } } - int tie_end = lo; - - // Average rank for ties: (start + end + 2) / 2 (1-based ranks) - double avg_rank = (double)(tie_start + tie_end + 2) / 2.0; - - // Write rank to original position - rk[si[i]] = avg_rank; } }
diff --git a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh
new file mode 100644
index 00000000..d1583500
--- /dev/null
+++ b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh
@@ -0,0 +1,235 @@
+#pragma once
+
+#include
+
+// ============================================================================
+// CSR → dense F-order extraction
+//
+// One thread per entry of `row_ids` (threads with tid >= n_target exit).
+// Thread `tid` reads CSR row `row_ids[tid]` and copies every stored value
+// whose column index lies in [col_start, col_stop) to
+//     out[(c - col_start) * n_target + tid]
+// i.e. `out` is addressed column-major with n_target rows per column.
+//
+// Positions that have no stored CSR entry are never written by this kernel,
+// so `out` presumably has to be zero-initialised by the caller — confirm at
+// the call site.
+// NOTE(review): the binary search and the early `break` both rely on the
+// column indices of each CSR row being sorted ascending — verify that callers
+// guarantee sorted indices.
+// ============================================================================
+
+__global__ void csr_extract_dense_kernel(const double* __restrict__ data,
+                                         const int* __restrict__ indices,
+                                         const int* __restrict__ indptr,
+                                         const int* __restrict__ row_ids,
+                                         double* __restrict__ out, int n_target,
+                                         int col_start, int col_stop) {
+    int tid = blockIdx.x * blockDim.x + threadIdx.x;
+    if (tid >= n_target) return;
+
+    // Source row for this output row, and its [rs, re) slice of the CSR arrays.
+    int row = row_ids[tid];
+    int rs = indptr[row];
+    int re = indptr[row + 1];
+
+    // Lower-bound binary search: first position in [rs, re) whose column
+    // index is >= col_start.
+    int lo = rs, hi = re;
+    while (lo < hi) {
+        int m = (lo + hi) >> 1;
+        if (indices[m] < col_start)
+            lo = m + 1;
+        else
+            hi = m;
+    }
+
+    // Scatter the entries that fall inside the requested column window.
+    for (int p = lo; p < re; ++p) {
+        int c = indices[p];
+        if (c >= col_stop) break;
+        // 64-bit offset: (c - col_start) * n_target can exceed INT_MAX.
+        out[(long long)(c - col_start) * n_target + tid] = data[p];
+    }
+}
+
+// ============================================================================
+// Batched rank sums — pre-sorted (binary search, no shared memory sort)
+// Used by the OVO streaming pipeline in wilcoxon_streaming.cu.
+// ============================================================================
+
+/**
+ * Rank sum (and optional tie correction) of each group against a shared
+ * reference, for every column.
+ *
+ * Launch layout as read by the kernel: blockIdx.x selects the column,
+ * blockIdx.y selects the group; the threads of a block stride over the rows
+ * of that group.
+ *
+ * ref_sorted   column-major, n_ref values per column.
+ * grp_sorted   column-major, n_all_grp values per column; group g occupies
+ *              rows [grp_offsets[g], grp_offsets[g + 1]) of every column.
+ * rank_sums    written at [grp * n_cols + col].
+ * tie_corr     written at [grp * n_cols + col] when compute_tie_corr is set.
+ *
+ * The binary searches require each reference column and each group segment
+ * to be sorted ascending. The block reduction uses full-mask warp shuffles
+ * and a 32-slot shared buffer, so it supports at most 32 warps per block.
+ * NOTE(review): the full-mask shuffles assume blockDim.x is a multiple of
+ * 32 — confirm at the launch site.
+ */
+__global__ void batched_rank_sums_presorted_kernel(
+    const float* __restrict__ ref_sorted, const float* __restrict__ grp_sorted,
+    const int* __restrict__ grp_offsets, double* __restrict__ rank_sums,
+    double* __restrict__ tie_corr, int n_ref, int n_all_grp, int n_cols,
+    int n_groups, bool compute_tie_corr) {
+    int col = blockIdx.x;
+    int grp = blockIdx.y;
+    if (col >= n_cols || grp >= n_groups) return;
+
+    // 64-bit output index: grp * n_cols overflows a 32-bit int once
+    // n_groups * n_cols exceeds INT_MAX (many groups x many genes).
+    const long long out_idx = (long long)grp * n_cols + col;
+
+    int g_start = grp_offsets[grp];
+    int g_end = grp_offsets[grp + 1];
+    int n_grp = g_end - g_start;
+
+    // Empty group: neutral outputs. The branch is uniform across the block,
+    // so returning before the barriers below is safe.
+    if (n_grp == 0) {
+        if (threadIdx.x == 0) {
+            rank_sums[out_idx] = 0.0;
+            if (compute_tie_corr) tie_corr[out_idx] = 1.0;
+        }
+        return;
+    }
+
+    const float* ref_col = ref_sorted + (long long)col * n_ref;
+    const float* grp_col = grp_sorted + (long long)col * n_all_grp + g_start;
+
+    // Each thread accumulates the average ranks of its share of the group's
+    // values within the combined (reference + group) sample.
+    double local_sum = 0.0;
+    for (int i = threadIdx.x; i < n_grp; i += blockDim.x) {
+        double v = grp_col[i];
+        int lo, hi;
+
+        // Reference values strictly below v.
+        lo = 0;
+        hi = n_ref;
+        while (lo < hi) {
+            int m = (lo + hi) >> 1;
+            if (ref_col[m] < v)
+                lo = m + 1;
+            else
+                hi = m;
+        }
+        int n_lt_ref = lo;
+        // Reference values equal to v (upper bound minus lower bound).
+        lo = n_lt_ref;
+        hi = n_ref;
+        while (lo < hi) {
+            int m = (lo + hi) >> 1;
+            if (ref_col[m] <= v)
+                lo = m + 1;
+            else
+                hi = m;
+        }
+        int n_eq_ref = lo - n_lt_ref;
+        // Group values strictly below v.
+        lo = 0;
+        hi = n_grp;
+        while (lo < hi) {
+            int m = (lo + hi) >> 1;
+            if (grp_col[m] < v)
+                lo = m + 1;
+            else
+                hi = m;
+        }
+        int n_lt_grp = lo;
+        // Group values equal to v (includes v itself).
+        lo = n_lt_grp;
+        hi = n_grp;
+        while (lo < hi) {
+            int m = (lo + hi) >> 1;
+            if (grp_col[m] <= v)
+                lo = m + 1;
+            else
+                hi = m;
+        }
+        int n_eq_grp = lo - n_lt_grp;
+
+        // Average 1-based rank of a tie block: n_lt + (n_eq + 1) / 2.
+        local_sum += (double)(n_lt_ref + n_lt_grp) +
+                     ((double)(n_eq_ref + n_eq_grp) + 1.0) / 2.0;
+    }
+
+    // Block reduction: shuffle within each warp, stage one partial per warp
+    // in shared memory, then reduce the partials in the first warp.
+    __shared__ double warp_buf[32];
+#pragma unroll
+    for (int off = 16; off > 0; off >>= 1)
+        local_sum += __shfl_down_sync(0xffffffff, local_sum, off);
+    int lane = threadIdx.x & 31;
+    int wid = threadIdx.x >> 5;
+    if (lane == 0) warp_buf[wid] = local_sum;
+    __syncthreads();
+    if (threadIdx.x < 32) {
+        double val = (threadIdx.x < ((blockDim.x + 31) >> 5))
+                         ? warp_buf[threadIdx.x]
+                         : 0.0;
+#pragma unroll
+        for (int off = 16; off > 0; off >>= 1)
+            val += __shfl_down_sync(0xffffffff, val, off);
+        if (threadIdx.x == 0) rank_sums[out_idx] = val;
+    }
+
+    if (!compute_tie_corr) return;
+    __syncthreads();
+
+    // Tie correction: thread 0 merges the two sorted sequences serially and
+    // accumulates t^3 - t for every run of t equal values.
+    if (threadIdx.x == 0) {
+        int ri = 0, gi = 0;
+        double tie_sum = 0.0;
+        while (ri < n_ref || gi < n_grp) {
+            // Take the smaller head; remember which side supplied it.
+            bool from_ref;
+            if (gi >= n_grp)
+                from_ref = true;
+            else if (ri >= n_ref)
+                from_ref = false;
+            else
+                from_ref = (ref_col[ri] <= grp_col[gi]);
+            double v = from_ref ? ref_col[ri] : grp_col[gi];
+            int cnt = 0;
+            while (ri < n_ref && ref_col[ri] == v) {
+                ++ri;
+                ++cnt;
+            }
+            while (gi < n_grp && grp_col[gi] == v) {
+                ++gi;
+                ++cnt;
+            }
+            if (cnt == 0) {
+                // v is NaN: it compares unequal to itself, so neither loop
+                // above consumed it. Skip it explicitly, otherwise the merge
+                // never advances and the kernel hangs.
+                if (from_ref)
+                    ++ri;
+                else
+                    ++gi;
+            } else if (cnt > 1) {
+                double t = (double)cnt;
+                tie_sum += t * t * t - t;
+            }
+        }
+        int n = n_ref + n_grp;
+        double dn = (double)n;
+        double denom = dn * dn * dn - dn;
+        tie_corr[out_idx] = (denom > 0.0) ? (1.0 - tie_sum / denom) : 1.0;
+    }
+}
+
+// ============================================================================
+// Grouped statistics: sum, sum-of-squares, nnz per group
+// ============================================================================
+
+__global__ void grouped_stats_kernel(
+    const double* __restrict__ data,      // F-order (n_all_rows, n_cols)
+    const int* __restrict__ grp_offsets,  // (n_groups + 1,)
+    double* __restrict__ sums,            // (n_groups, n_cols) row-major
+    double* __restrict__ sq_sums,         // (n_groups, n_cols) row-major
+    double* __restrict__ nnz_counts,      // (n_groups, n_cols) row-major
+    int n_all_rows, int n_cols, int n_groups, bool compute_nnz) {
+    int col = blockIdx.x;
+    if (col >= n_cols) return;
+
+    // Dynamic shared memory holds three arrays of n_groups doubles each.
+    extern __shared__ double smem[];
+    double* s_sum = smem;
+    double* s_sq = smem + n_groups;
+    double* s_nnz = smem + 2 * n_groups;
+
+    for (int g = threadIdx.x; g < n_groups; g += blockDim.x) {
+        s_sum[g] = 0.0;
+        s_sq[g] = 0.0;
+        s_nnz[g] = 0.0;
+    }
+    __syncthreads();
+
+    const double* col_data = data + (long long)col * n_all_rows;
+
+    for (int g = 0; g < n_groups; g++) {
+        int g_start = grp_offsets[g];
+        int g_end = grp_offsets[g + 1];
+
+        double local_sum = 0.0;
+        double local_sq = 0.0;
+        double local_nnz = 0.0;
+
+        for (int i = g_start + threadIdx.x; i < g_end; i += blockDim.x) {
+            double v = col_data[i];
+            local_sum += v;
+            local_sq += v * v;
+            if (compute_nnz && v != 0.0) local_nnz += 1.0;
+        }
+
+        // Warp-level reduction; lane 0 of each warp then adds its partial to
+        // the group's shared-memory accumulator.
+#pragma unroll
+        for (int off = 16; off > 0; off >>= 1) {
+            local_sum += __shfl_down_sync(0xffffffff, local_sum, off);
+            local_sq += __shfl_down_sync(0xffffffff, local_sq, off);
+            if (compute_nnz)
+                local_nnz += __shfl_down_sync(0xffffffff, local_nnz, off);
+        }
+
+        if ((threadIdx.x & 31) == 0) {
+            atomicAdd(&s_sum[g], local_sum);
+            atomicAdd(&s_sq[g], local_sq);
+            if (compute_nnz) atomicAdd(&s_nnz[g], local_nnz);
+        }
+        __syncthreads();
+    }
+
+    for (int g = threadIdx.x; g < n_groups; g += blockDim.x) {
+        sums[(long long)g * n_cols + col] =
s_sum[g]; + sq_sums[(long long)g * n_cols + col] = s_sq[g]; + if (compute_nnz) nnz_counts[(long long)g * n_cols + col] = s_nnz[g]; + } +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu index d25f7d0f..e511c895 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu @@ -1,11 +1,12 @@ -#include -#include "../nb_types.h" +#include + +#include -#include "kernels_wilcoxon.cuh" +#include "../nb_types.h" +#include "kernels_wilcoxon_ovo.cuh" using namespace nb::literals; -// Constants for kernel launch configuration constexpr int WARP_SIZE = 32; constexpr int MAX_THREADS_PER_BLOCK = 512; @@ -14,57 +15,76 @@ static inline int round_up_to_warp(int n) { return (rounded < MAX_THREADS_PER_BLOCK) ? rounded : MAX_THREADS_PER_BLOCK; } -static inline void launch_tie_correction(const double* sorted_vals, - double* correction, int n_rows, - int n_cols, cudaStream_t stream) { - int threads_per_block = round_up_to_warp(n_rows); - dim3 block(threads_per_block); - dim3 grid(n_cols); - tie_correction_kernel<<>>(sorted_vals, correction, - n_rows, n_cols); - CUDA_CHECK_LAST_ERROR(tie_correction_kernel); -} - -static inline void launch_average_rank(const double* sorted_vals, - const int* sorter, double* ranks, - int n_rows, int n_cols, - cudaStream_t stream) { - int threads_per_block = round_up_to_warp(n_rows); - dim3 block(threads_per_block); - dim3 grid(n_cols); - average_rank_kernel<<>>(sorted_vals, sorter, ranks, - n_rows, n_cols); - CUDA_CHECK_LAST_ERROR(average_rank_kernel); +static size_t get_seg_sort_temp_bytes(int n_items, int n_segments) { + size_t bytes = 0; + auto* dk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys(nullptr, bytes, dk, dk, n_items, + n_segments, doff, doff + 1, 0, 32); + return bytes; } template void register_bindings(nb::module_& m) { m.doc() = "CUDA kernels for Wilcoxon rank-sum test"; 
- // Tie correction kernel + m.def("get_seg_sort_temp_bytes", &get_seg_sort_temp_bytes, "n_items"_a, + "n_segments"_a); + m.def( - "tie_correction", - [](gpu_array_f sorted_vals, - gpu_array correction, int n_rows, int n_cols, + "segmented_sort", + [](gpu_array_c keys_in, + gpu_array_c keys_out, + gpu_array_c offsets, + gpu_array_c cub_temp, int n_items, int n_segments, std::uintptr_t stream) { - launch_tie_correction(sorted_vals.data(), correction.data(), n_rows, - n_cols, (cudaStream_t)stream); + size_t temp_bytes = cub_temp.size(); + cub::DeviceSegmentedRadixSort::SortKeys( + cub_temp.data(), temp_bytes, keys_in.data(), keys_out.data(), + n_items, n_segments, offsets.data(), offsets.data() + 1, 0, 32, + (cudaStream_t)stream); + CUDA_CHECK_LAST_ERROR(DeviceSegmentedRadixSort); }, - "sorted_vals"_a, "correction"_a, nb::kw_only(), "n_rows"_a, "n_cols"_a, - "stream"_a = 0); + "keys_in"_a, "keys_out"_a, "offsets"_a, "cub_temp"_a, nb::kw_only(), + "n_items"_a, "n_segments"_a, "stream"_a = 0); - // Average rank kernel m.def( - "average_rank", - [](gpu_array_f sorted_vals, - gpu_array_f sorter, - gpu_array_f ranks, int n_rows, int n_cols, - std::uintptr_t stream) { - launch_average_rank(sorted_vals.data(), sorter.data(), ranks.data(), - n_rows, n_cols, (cudaStream_t)stream); + "csr_extract_dense", + [](gpu_array_c data, + gpu_array_c indices, + gpu_array_c indptr, + gpu_array_c row_ids, + gpu_array_f out, int n_target, int col_start, + int col_stop, std::uintptr_t stream) { + int tpb = round_up_to_warp(n_target); + int blocks = (n_target + tpb - 1) / tpb; + csr_extract_dense_kernel<<>>( + data.data(), indices.data(), indptr.data(), row_ids.data(), + out.data(), n_target, col_start, col_stop); + CUDA_CHECK_LAST_ERROR(csr_extract_dense_kernel); + }, + "data"_a, "indices"_a, "indptr"_a, "row_ids"_a, "out"_a, nb::kw_only(), + "n_target"_a, "col_start"_a, "col_stop"_a, "stream"_a = 0); + + m.def( + "grouped_stats", + [](gpu_array_f data, + gpu_array_c grp_offsets, + 
gpu_array_c sums, + gpu_array_c sq_sums, + gpu_array_c nnz_counts, int n_all_rows, int n_cols, + int n_groups, bool compute_nnz, std::uintptr_t stream) { + constexpr int THREADS = 256; + int smem = 3 * n_groups * sizeof(double); + grouped_stats_kernel<<>>( + data.data(), grp_offsets.data(), sums.data(), sq_sums.data(), + nnz_counts.data(), n_all_rows, n_cols, n_groups, compute_nnz); + CUDA_CHECK_LAST_ERROR(grouped_stats_kernel); }, - "sorted_vals"_a, "sorter"_a, "ranks"_a, nb::kw_only(), "n_rows"_a, - "n_cols"_a, "stream"_a = 0); + "data"_a, "grp_offsets"_a, "sums"_a, "sq_sums"_a, "nnz_counts"_a, + nb::kw_only(), "n_all_rows"_a, "n_cols"_a, "n_groups"_a, + "compute_nnz"_a, "stream"_a = 0); } NB_MODULE(_wilcoxon_cuda, m) { diff --git a/src/rapids_singlecell/_cuda/wilcoxon_streaming/wilcoxon_streaming.cu b/src/rapids_singlecell/_cuda/wilcoxon_streaming/wilcoxon_streaming.cu new file mode 100644 index 00000000..2b0a1798 --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon_streaming/wilcoxon_streaming.cu @@ -0,0 +1,1283 @@ +#include +#include + +#include + +#include "../nb_types.h" +#include "../wilcoxon/kernels_wilcoxon.cuh" +#include "../wilcoxon/kernels_wilcoxon_ovo.cuh" + +using namespace nb::literals; + +constexpr int WARP_SIZE = 32; +constexpr int MAX_THREADS_PER_BLOCK = 512; +constexpr int N_STREAMS = 4; +constexpr int SUB_BATCH_COLS = 32; +constexpr int BEGIN_BIT = 0; +constexpr int END_BIT = 32; + +static inline int round_up_to_warp(int n) { + int rounded = ((n + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + return (rounded < MAX_THREADS_PER_BLOCK) ? rounded : MAX_THREADS_PER_BLOCK; +} + +/** + * Extract dense F-order float32 block from CSR. + * All rows, column range [col_start, col_stop). + * One thread per row, binary search for col_start. + * Output must be pre-zeroed. 
+ */
+__global__ void csr_extract_f32_kernel(const float* __restrict__ data,
+                                       const int* __restrict__ indices,
+                                       const int* __restrict__ indptr,
+                                       float* __restrict__ out, int n_rows,
+                                       int col_start, int col_stop) {
+    int row = blockIdx.x * blockDim.x + threadIdx.x;
+    if (row >= n_rows) return;
+
+    // [rs, re) is this row's slice of the CSR data/indices arrays.
+    int rs = indptr[row];
+    int re = indptr[row + 1];
+
+    // Binary search for col_start: first position whose column index is
+    // >= col_start. Relies on the row's column indices being sorted
+    // ascending — TODO confirm callers guarantee this.
+    int lo = rs, hi = re;
+    while (lo < hi) {
+        int m = (lo + hi) >> 1;
+        if (indices[m] < col_start)
+            lo = m + 1;
+        else
+            hi = m;
+    }
+
+    // Scatter entries inside [col_start, col_stop) into the column-major
+    // output; the 64-bit product keeps the offset from overflowing int.
+    for (int p = lo; p < re; ++p) {
+        int c = indices[p];
+        if (c >= col_stop) break;
+        out[(long long)(c - col_start) * n_rows + row] = data[p];
+    }
+}
+
+/**
+ * Extract dense F-order float32 block from CSC.
+ * Columns start at col_start; the number of columns extracted is the grid
+ * size (block b handles column col_start + b), so the launch must not use
+ * more blocks than there are columns left in indptr.
+ * One block per column, threads scatter nonzeros.
+ * Output must be pre-zeroed.
+ */
+__global__ void csc_extract_f32_kernel(const float* __restrict__ data,
+                                       const int* __restrict__ indices,
+                                       const int* __restrict__ indptr,
+                                       float* __restrict__ out, int n_rows,
+                                       int col_start) {
+    int col_local = blockIdx.x;
+    int col = col_start + col_local;
+
+    // [start, end) is this column's slice of the CSC data/indices arrays.
+    int start = indptr[col];
+    int end = indptr[col + 1];
+
+    for (int p = start + threadIdx.x; p < end; p += blockDim.x) {
+        int row = indices[p];
+        out[(long long)col_local * n_rows + row] = data[p];
+    }
+}
+
+/**
+ * Fill sort values with row indices [0,1,...,n_rows-1] per column.
+ * Grid: (n_cols,); threads of a block stride over the rows, so the loop is
+ * correct for any block size.
+ */
+__global__ void fill_row_indices_kernel(int* __restrict__ vals, int n_rows,
+                                        int n_cols) {
+    int col = blockIdx.x;
+    if (col >= n_cols) return;
+    int* out = vals + (long long)col * n_rows;
+    for (int i = threadIdx.x; i < n_rows; i += blockDim.x) {
+        out[i] = i;
+    }
+}
+
+/**
+ * Streaming OVR pipeline.
+ *
+ * Takes a dense F-order float32 block (n_rows, n_cols) + int32 group_codes,
+ * splits columns into sub-batches across multiple CUDA streams, and for each:
+ *   1.
CUB SortPairs (float32 keys + int32 row indices) + * 2. Fused rank_sums_from_sorted_kernel + * + * Output: rank_sums (n_groups, n_cols) + tie_corr (n_cols), both float64. + */ +static void ovr_streaming_impl(const float* block, const int* group_codes, + double* rank_sums, double* tie_corr, int n_rows, + int n_cols, int n_groups, bool compute_tie_corr, + int sub_batch_cols) { + if (n_rows == 0 || n_cols == 0) return; + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) + n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + + size_t sub_items = (size_t)n_rows * sub_batch_cols; + size_t cub_temp_bytes = 0; + { + auto* fk = reinterpret_cast(1); + auto* iv = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortPairs( + nullptr, cub_temp_bytes, fk, fk, iv, iv, (int)sub_items, + sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); + } + + // Create streams + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + // Per-stream buffers (allocated once, reused across sub-batches) + struct StreamBuf { + float* keys_out; + int* vals_in; + int* vals_out; + int* seg_offsets; + uint8_t* cub_temp; + }; + + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + cudaMalloc(&bufs[s].keys_out, sub_items * sizeof(float)); + cudaMalloc(&bufs[s].vals_in, sub_items * sizeof(int)); + cudaMalloc(&bufs[s].vals_out, sub_items * sizeof(int)); + cudaMalloc(&bufs[s].seg_offsets, (sub_batch_cols + 1) * sizeof(int)); + cudaMalloc(&bufs[s].cub_temp, cub_temp_bytes); + } + + int tpb_rank = round_up_to_warp(n_rows); + int smem_rank = (4 * n_groups + 32) * sizeof(double); + + // Allocate sub-batch output buffers per stream + std::vector sub_rank_sums(n_streams); + std::vector sub_tie_corr(n_streams); + for (int s = 0; s < n_streams; s++) { + cudaMalloc(&sub_rank_sums[s], + (size_t)n_groups * sub_batch_cols * sizeof(double)); + cudaMalloc(&sub_tie_corr[s], sub_batch_cols * sizeof(double)); + } + + // 
Process sub-batches round-robin across streams + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int sb_items = n_rows * sb_cols; + int s = batch_idx % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + + // Fill segment offsets: [0, n_rows, 2*n_rows, ...] + { + int* h_off = new int[sb_cols + 1]; + for (int i = 0; i <= sb_cols; i++) h_off[i] = i * n_rows; + cudaMemcpyAsync(buf.seg_offsets, h_off, (sb_cols + 1) * sizeof(int), + cudaMemcpyHostToDevice, stream); + delete[] h_off; + } + + // Fill row indices + fill_row_indices_kernel<<>>(buf.vals_in, + n_rows, sb_cols); + + // Sort: keys = block columns [col, col+sb_cols), already F-order + const float* keys_in = block + (long long)col * n_rows; + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortPairs( + buf.cub_temp, temp, keys_in, buf.keys_out, buf.vals_in, + buf.vals_out, sb_items, sb_cols, buf.seg_offsets, + buf.seg_offsets + 1, BEGIN_BIT, END_BIT, stream); + + // Fused rank sums into sub-batch buffer + rank_sums_from_sorted_kernel<<>>( + buf.keys_out, buf.vals_out, group_codes, sub_rank_sums[s], + sub_tie_corr[s], nullptr, nullptr, nullptr, n_rows, sb_cols, + n_groups, compute_tie_corr, false); + + // Copy sub-batch results to global output (row-major scatter) + // rank_sums is (n_groups, n_cols) row-major: group g, col c → + // [g*n_cols+c] sub output is (n_groups, sb_cols): group g, local col lc + // → [g*sb_cols+lc] + cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), + sub_rank_sums[s], sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpyAsync(tie_corr + col, sub_tie_corr[s], + sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, + stream); + } + + col += sb_cols; + batch_idx++; + } + + // Sync all streams + for (int s = 0; s < n_streams; s++) { + cudaStreamSynchronize(streams[s]); + } + + // Cleanup + for (int s = 0; 
s < n_streams; s++) { + cudaFree(bufs[s].keys_out); + cudaFree(bufs[s].vals_in); + cudaFree(bufs[s].vals_out); + cudaFree(bufs[s].seg_offsets); + cudaFree(bufs[s].cub_temp); + cudaFree(sub_rank_sums[s]); + cudaFree(sub_tie_corr[s]); + cudaStreamDestroy(streams[s]); + } +} + +/** + * CSR-direct OVR streaming pipeline. + * + * Takes GPU CSR arrays directly — no CSR→CSC conversion needed. + * For each sub-batch: extract dense columns from CSR → sort → rank. + * Everything on one GPU with multi-stream overlap. + */ +static void ovr_streaming_csr_impl( + const float* csr_data, const int* csr_indices, const int* csr_indptr, + const int* group_codes, double* rank_sums, double* tie_corr, int n_rows, + int n_cols, int n_groups, bool compute_tie_corr, int sub_batch_cols) { + if (n_rows == 0 || n_cols == 0) return; + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) + n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + + size_t sub_items = (size_t)n_rows * sub_batch_cols; + size_t cub_temp_bytes = 0; + { + auto* fk = reinterpret_cast(1); + auto* iv = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortPairs( + nullptr, cub_temp_bytes, fk, fk, iv, iv, (int)sub_items, + sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + struct StreamBuf { + float* dense; // extracted dense sub-batch + float* keys_out; + int* vals_in; + int* vals_out; + int* seg_offsets; + uint8_t* cub_temp; + }; + + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + cudaMalloc(&bufs[s].dense, sub_items * sizeof(float)); + cudaMalloc(&bufs[s].keys_out, sub_items * sizeof(float)); + cudaMalloc(&bufs[s].vals_in, sub_items * sizeof(int)); + cudaMalloc(&bufs[s].vals_out, sub_items * sizeof(int)); + cudaMalloc(&bufs[s].seg_offsets, (sub_batch_cols + 1) * sizeof(int)); + cudaMalloc(&bufs[s].cub_temp, cub_temp_bytes); + } + + std::vector 
sub_rank_sums(n_streams); + std::vector sub_tie_corr(n_streams); + for (int s = 0; s < n_streams; s++) { + cudaMalloc(&sub_rank_sums[s], + (size_t)n_groups * sub_batch_cols * sizeof(double)); + cudaMalloc(&sub_tie_corr[s], sub_batch_cols * sizeof(double)); + } + + int tpb_rank = round_up_to_warp(n_rows); + int smem_rank = (4 * n_groups + 32) * sizeof(double); + int tpb_extract = round_up_to_warp(n_rows); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int sb_items = n_rows * sb_cols; + int s = batch_idx % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + + // Zero dense buffer + cudaMemsetAsync(buf.dense, 0, sb_items * sizeof(float), stream); + + // Extract dense columns from CSR + int extract_blocks = (n_rows + tpb_extract - 1) / tpb_extract; + csr_extract_f32_kernel<<>>( + csr_data, csr_indices, csr_indptr, buf.dense, n_rows, col, + col + sb_cols); + + // Fill segment offsets + row indices + { + int* h_off = new int[sb_cols + 1]; + for (int i = 0; i <= sb_cols; i++) h_off[i] = i * n_rows; + cudaMemcpyAsync(buf.seg_offsets, h_off, (sb_cols + 1) * sizeof(int), + cudaMemcpyHostToDevice, stream); + delete[] h_off; + } + fill_row_indices_kernel<<>>(buf.vals_in, + n_rows, sb_cols); + + // Sort + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortPairs( + buf.cub_temp, temp, buf.dense, buf.keys_out, buf.vals_in, + buf.vals_out, sb_items, sb_cols, buf.seg_offsets, + buf.seg_offsets + 1, BEGIN_BIT, END_BIT, stream); + + // Fused rank sums + rank_sums_from_sorted_kernel<<>>( + buf.keys_out, buf.vals_out, group_codes, sub_rank_sums[s], + sub_tie_corr[s], nullptr, nullptr, nullptr, n_rows, sb_cols, + n_groups, compute_tie_corr, false); + + // Scatter to global output + cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), + sub_rank_sums[s], sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + 
cudaMemcpyAsync(tie_corr + col, sub_tie_corr[s], + sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, + stream); + } + + col += sb_cols; + batch_idx++; + } + + for (int s = 0; s < n_streams; s++) cudaStreamSynchronize(streams[s]); + + for (int s = 0; s < n_streams; s++) { + cudaFree(bufs[s].dense); + cudaFree(bufs[s].keys_out); + cudaFree(bufs[s].vals_in); + cudaFree(bufs[s].vals_out); + cudaFree(bufs[s].seg_offsets); + cudaFree(bufs[s].cub_temp); + cudaFree(sub_rank_sums[s]); + cudaFree(sub_tie_corr[s]); + cudaStreamDestroy(streams[s]); + } +} + +/** + * CSC-direct OVR streaming pipeline. + * + * Takes GPU CSC arrays directly — no format conversion needed. + * For each sub-batch: extract dense columns from CSC → sort → rank. + * CSC extraction is a simple scatter (no binary search), faster than CSR. + */ +static void ovr_streaming_csc_impl( + const float* csc_data, const int* csc_indices, const int* csc_indptr, + const int* group_codes, double* rank_sums, double* tie_corr, int n_rows, + int n_cols, int n_groups, bool compute_tie_corr, int sub_batch_cols) { + if (n_rows == 0 || n_cols == 0) return; + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) + n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + + size_t sub_items = (size_t)n_rows * sub_batch_cols; + size_t cub_temp_bytes = 0; + { + auto* fk = reinterpret_cast(1); + auto* iv = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortPairs( + nullptr, cub_temp_bytes, fk, fk, iv, iv, (int)sub_items, + sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + struct StreamBuf { + float* dense; + float* keys_out; + int* vals_in; + int* vals_out; + int* seg_offsets; + uint8_t* cub_temp; + }; + + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + cudaMalloc(&bufs[s].dense, sub_items * sizeof(float)); + cudaMalloc(&bufs[s].keys_out, sub_items * sizeof(float)); 
+ cudaMalloc(&bufs[s].vals_in, sub_items * sizeof(int)); + cudaMalloc(&bufs[s].vals_out, sub_items * sizeof(int)); + cudaMalloc(&bufs[s].seg_offsets, (sub_batch_cols + 1) * sizeof(int)); + cudaMalloc(&bufs[s].cub_temp, cub_temp_bytes); + } + + std::vector sub_rank_sums(n_streams); + std::vector sub_tie_corr(n_streams); + for (int s = 0; s < n_streams; s++) { + cudaMalloc(&sub_rank_sums[s], + (size_t)n_groups * sub_batch_cols * sizeof(double)); + cudaMalloc(&sub_tie_corr[s], sub_batch_cols * sizeof(double)); + } + + int tpb_rank = round_up_to_warp(n_rows); + int smem_rank = (4 * n_groups + 32) * sizeof(double); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int sb_items = n_rows * sb_cols; + int s = batch_idx % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + + // Zero dense buffer + cudaMemsetAsync(buf.dense, 0, sb_items * sizeof(float), stream); + + // Extract dense columns from CSC — simple scatter, no binary search + csc_extract_f32_kernel<<>>( + csc_data, csc_indices, csc_indptr, buf.dense, n_rows, col); + + // Fill segment offsets + row indices + { + int* h_off = new int[sb_cols + 1]; + for (int i = 0; i <= sb_cols; i++) h_off[i] = i * n_rows; + cudaMemcpyAsync(buf.seg_offsets, h_off, (sb_cols + 1) * sizeof(int), + cudaMemcpyHostToDevice, stream); + delete[] h_off; + } + fill_row_indices_kernel<<>>(buf.vals_in, + n_rows, sb_cols); + + // Sort + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortPairs( + buf.cub_temp, temp, buf.dense, buf.keys_out, buf.vals_in, + buf.vals_out, sb_items, sb_cols, buf.seg_offsets, + buf.seg_offsets + 1, BEGIN_BIT, END_BIT, stream); + + // Fused rank sums + rank_sums_from_sorted_kernel<<>>( + buf.keys_out, buf.vals_out, group_codes, sub_rank_sums[s], + sub_tie_corr[s], nullptr, nullptr, nullptr, n_rows, sb_cols, + n_groups, compute_tie_corr, false); + + // Scatter to global output + cudaMemcpy2DAsync(rank_sums + col, n_cols * 
sizeof(double), + sub_rank_sums[s], sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpyAsync(tie_corr + col, sub_tie_corr[s], + sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, + stream); + } + + col += sb_cols; + batch_idx++; + } + + for (int s = 0; s < n_streams; s++) cudaStreamSynchronize(streams[s]); + + for (int s = 0; s < n_streams; s++) { + cudaFree(bufs[s].dense); + cudaFree(bufs[s].keys_out); + cudaFree(bufs[s].vals_in); + cudaFree(bufs[s].vals_out); + cudaFree(bufs[s].seg_offsets); + cudaFree(bufs[s].cub_temp); + cudaFree(sub_rank_sums[s]); + cudaFree(sub_tie_corr[s]); + cudaStreamDestroy(streams[s]); + } +} + +/** + * Host-streaming CSC OVR pipeline. + * + * CSC arrays live on host. Only the sparse data for each sub-batch of + * columns is transferred to GPU, so GPU memory is O(sub_batch * n_rows). + * H2D of sub-batch N+1 overlaps compute of sub-batch N via multi-stream. + */ +static void ovr_streaming_csc_host_impl( + const float* h_data, const int* h_indices, const int* h_indptr, + const int* h_group_codes, double* h_rank_sums, double* h_tie_corr, + int n_rows, int n_cols, int n_groups, bool compute_tie_corr, + int sub_batch_cols) { + if (n_rows == 0 || n_cols == 0) return; + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) + n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + + size_t sub_items = (size_t)n_rows * sub_batch_cols; + size_t cub_temp_bytes = 0; + { + auto* fk = reinterpret_cast(1); + auto* iv = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortPairs( + nullptr, cub_temp_bytes, fk, fk, iv, iv, (int)sub_items, + sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); + } + + // Find max nnz across any sub-batch to size the sparse transfer buffers + size_t max_nnz = 0; + for (int col = 0; col < n_cols; col += sub_batch_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + size_t nnz = (size_t)(h_indptr[col 
+ sb_cols] - h_indptr[col]); + if (nnz > max_nnz) max_nnz = nnz; + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + // Group codes on GPU (transferred once) + int* d_group_codes; + cudaMalloc(&d_group_codes, n_rows * sizeof(int)); + cudaMemcpy(d_group_codes, h_group_codes, n_rows * sizeof(int), + cudaMemcpyHostToDevice); + + struct StreamBuf { + float* d_sparse_data; // H2D sparse values + int* d_sparse_indices; // H2D sparse row indices + int* d_indptr; // H2D indptr slice (sb_cols + 1) + float* dense; // extracted dense + float* keys_out; + int* vals_in; + int* vals_out; + int* seg_offsets; + uint8_t* cub_temp; + double* d_rank_sums; + double* d_tie_corr; + }; + + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + cudaMalloc(&bufs[s].d_sparse_data, max_nnz * sizeof(float)); + cudaMalloc(&bufs[s].d_sparse_indices, max_nnz * sizeof(int)); + cudaMalloc(&bufs[s].d_indptr, (sub_batch_cols + 1) * sizeof(int)); + cudaMalloc(&bufs[s].dense, sub_items * sizeof(float)); + cudaMalloc(&bufs[s].keys_out, sub_items * sizeof(float)); + cudaMalloc(&bufs[s].vals_in, sub_items * sizeof(int)); + cudaMalloc(&bufs[s].vals_out, sub_items * sizeof(int)); + cudaMalloc(&bufs[s].seg_offsets, (sub_batch_cols + 1) * sizeof(int)); + cudaMalloc(&bufs[s].cub_temp, cub_temp_bytes); + cudaMalloc(&bufs[s].d_rank_sums, + (size_t)n_groups * sub_batch_cols * sizeof(double)); + cudaMalloc(&bufs[s].d_tie_corr, sub_batch_cols * sizeof(double)); + } + + int tpb_rank = round_up_to_warp(n_rows); + int smem_rank = (4 * n_groups + 32) * sizeof(double); + + // Pin host memory for async transfers + cudaHostRegister(const_cast(h_data), + (size_t)h_indptr[n_cols] * sizeof(float), 0); + cudaHostRegister(const_cast(h_indices), + (size_t)h_indptr[n_cols] * sizeof(int), 0); + cudaHostRegister(h_rank_sums, (size_t)n_groups * n_cols * sizeof(double), + 0); + cudaHostRegister(h_tie_corr, n_cols * sizeof(double), 0); + + int col = 0; + 
int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int sb_items = n_rows * sb_cols; + int s = batch_idx % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + + // H2D: transfer sparse data for this column range + int ptr_start = h_indptr[col]; + int ptr_end = h_indptr[col + sb_cols]; + size_t nnz = (size_t)(ptr_end - ptr_start); + cudaMemcpyAsync(buf.d_sparse_data, h_data + ptr_start, + nnz * sizeof(float), cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(buf.d_sparse_indices, h_indices + ptr_start, + nnz * sizeof(int), cudaMemcpyHostToDevice, stream); + + // Transfer adjusted indptr (rebased to 0) + // h_indptr[col..col+sb_cols] - h_indptr[col] + { + int* h_adj = new int[sb_cols + 1]; + for (int i = 0; i <= sb_cols; i++) + h_adj[i] = h_indptr[col + i] - ptr_start; + cudaMemcpyAsync(buf.d_indptr, h_adj, (sb_cols + 1) * sizeof(int), + cudaMemcpyHostToDevice, stream); + delete[] h_adj; + } + + // Zero dense buffer + cudaMemsetAsync(buf.dense, 0, sb_items * sizeof(float), stream); + + // CSC extract from transferred sparse data (col_start=0 because + // indptr is rebased and data/indices are for this sub-batch only) + csc_extract_f32_kernel<<>>( + buf.d_sparse_data, buf.d_sparse_indices, buf.d_indptr, buf.dense, + n_rows, 0); + + // Fill segment offsets + row indices + { + int* h_off = new int[sb_cols + 1]; + for (int i = 0; i <= sb_cols; i++) h_off[i] = i * n_rows; + cudaMemcpyAsync(buf.seg_offsets, h_off, (sb_cols + 1) * sizeof(int), + cudaMemcpyHostToDevice, stream); + delete[] h_off; + } + fill_row_indices_kernel<<>>(buf.vals_in, + n_rows, sb_cols); + + // Sort + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortPairs( + buf.cub_temp, temp, buf.dense, buf.keys_out, buf.vals_in, + buf.vals_out, sb_items, sb_cols, buf.seg_offsets, + buf.seg_offsets + 1, BEGIN_BIT, END_BIT, stream); + + // Fused rank sums + rank_sums_from_sorted_kernel<<>>( + buf.keys_out, buf.vals_out, d_group_codes, 
buf.d_rank_sums, + buf.d_tie_corr, nullptr, nullptr, nullptr, n_rows, sb_cols, + n_groups, compute_tie_corr, false); + + // D2H: scatter results to host output + cudaMemcpy2DAsync(h_rank_sums + col, n_cols * sizeof(double), + buf.d_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToHost, stream); + if (compute_tie_corr) { + cudaMemcpyAsync(h_tie_corr + col, buf.d_tie_corr, + sb_cols * sizeof(double), cudaMemcpyDeviceToHost, + stream); + } + + col += sb_cols; + batch_idx++; + } + + for (int s = 0; s < n_streams; s++) cudaStreamSynchronize(streams[s]); + + cudaHostUnregister(const_cast(h_data)); + cudaHostUnregister(const_cast(h_indices)); + cudaHostUnregister(h_rank_sums); + cudaHostUnregister(h_tie_corr); + + cudaFree(d_group_codes); + for (int s = 0; s < n_streams; s++) { + cudaFree(bufs[s].d_sparse_data); + cudaFree(bufs[s].d_sparse_indices); + cudaFree(bufs[s].d_indptr); + cudaFree(bufs[s].dense); + cudaFree(bufs[s].keys_out); + cudaFree(bufs[s].vals_in); + cudaFree(bufs[s].vals_out); + cudaFree(bufs[s].seg_offsets); + cudaFree(bufs[s].cub_temp); + cudaFree(bufs[s].d_rank_sums); + cudaFree(bufs[s].d_tie_corr); + cudaStreamDestroy(streams[s]); + } +} + +/** + * Host-streaming dense OVR pipeline. + * + * Dense F-order float32 block lives on host. Sub-batches of 64 columns + * are transferred to GPU per stream, so GPU memory is O(sub_batch * n_rows). 
+ */ +static void ovr_streaming_dense_host_impl( + const float* h_block, const int* h_group_codes, double* h_rank_sums, + double* h_tie_corr, int n_rows, int n_cols, int n_groups, + bool compute_tie_corr, int sub_batch_cols) { + if (n_rows == 0 || n_cols == 0) return; + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) + n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + + size_t sub_items = (size_t)n_rows * sub_batch_cols; + size_t cub_temp_bytes = 0; + { + auto* fk = reinterpret_cast(1); + auto* iv = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortPairs( + nullptr, cub_temp_bytes, fk, fk, iv, iv, (int)sub_items, + sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + int* d_group_codes; + cudaMalloc(&d_group_codes, n_rows * sizeof(int)); + cudaMemcpy(d_group_codes, h_group_codes, n_rows * sizeof(int), + cudaMemcpyHostToDevice); + + struct StreamBuf { + float* d_block; // H2D dense sub-batch + float* keys_out; + int* vals_in; + int* vals_out; + int* seg_offsets; + uint8_t* cub_temp; + double* d_rank_sums; + double* d_tie_corr; + }; + + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + cudaMalloc(&bufs[s].d_block, sub_items * sizeof(float)); + cudaMalloc(&bufs[s].keys_out, sub_items * sizeof(float)); + cudaMalloc(&bufs[s].vals_in, sub_items * sizeof(int)); + cudaMalloc(&bufs[s].vals_out, sub_items * sizeof(int)); + cudaMalloc(&bufs[s].seg_offsets, (sub_batch_cols + 1) * sizeof(int)); + cudaMalloc(&bufs[s].cub_temp, cub_temp_bytes); + cudaMalloc(&bufs[s].d_rank_sums, + (size_t)n_groups * sub_batch_cols * sizeof(double)); + cudaMalloc(&bufs[s].d_tie_corr, sub_batch_cols * sizeof(double)); + } + + int tpb_rank = round_up_to_warp(n_rows); + int smem_rank = (4 * n_groups + 32) * sizeof(double); + + // Pin host memory + cudaHostRegister(const_cast(h_block), + (size_t)n_rows * n_cols * sizeof(float), 
0); + cudaHostRegister(h_rank_sums, (size_t)n_groups * n_cols * sizeof(double), + 0); + cudaHostRegister(h_tie_corr, n_cols * sizeof(double), 0); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int sb_items = n_rows * sb_cols; + int s = batch_idx % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + + // H2D: column sub-batch (F-order → contiguous) + cudaMemcpyAsync(buf.d_block, h_block + (long long)col * n_rows, + sb_items * sizeof(float), cudaMemcpyHostToDevice, + stream); + + // Fill segment offsets + row indices + { + int* h_off = new int[sb_cols + 1]; + for (int i = 0; i <= sb_cols; i++) h_off[i] = i * n_rows; + cudaMemcpyAsync(buf.seg_offsets, h_off, (sb_cols + 1) * sizeof(int), + cudaMemcpyHostToDevice, stream); + delete[] h_off; + } + fill_row_indices_kernel<<>>(buf.vals_in, + n_rows, sb_cols); + + // Sort + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortPairs( + buf.cub_temp, temp, buf.d_block, buf.keys_out, buf.vals_in, + buf.vals_out, sb_items, sb_cols, buf.seg_offsets, + buf.seg_offsets + 1, BEGIN_BIT, END_BIT, stream); + + // Fused rank sums + rank_sums_from_sorted_kernel<<>>( + buf.keys_out, buf.vals_out, d_group_codes, buf.d_rank_sums, + buf.d_tie_corr, nullptr, nullptr, nullptr, n_rows, sb_cols, + n_groups, compute_tie_corr, false); + + // D2H: scatter results + cudaMemcpy2DAsync(h_rank_sums + col, n_cols * sizeof(double), + buf.d_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToHost, stream); + if (compute_tie_corr) { + cudaMemcpyAsync(h_tie_corr + col, buf.d_tie_corr, + sb_cols * sizeof(double), cudaMemcpyDeviceToHost, + stream); + } + + col += sb_cols; + batch_idx++; + } + + for (int s = 0; s < n_streams; s++) cudaStreamSynchronize(streams[s]); + + cudaHostUnregister(const_cast(h_block)); + cudaHostUnregister(h_rank_sums); + cudaHostUnregister(h_tie_corr); + + cudaFree(d_group_codes); + for (int s 
= 0; s < n_streams; s++) { + cudaFree(bufs[s].d_block); + cudaFree(bufs[s].keys_out); + cudaFree(bufs[s].vals_in); + cudaFree(bufs[s].vals_out); + cudaFree(bufs[s].seg_offsets); + cudaFree(bufs[s].cub_temp); + cudaFree(bufs[s].d_rank_sums); + cudaFree(bufs[s].d_tie_corr); + cudaStreamDestroy(streams[s]); + } +} + +/** + * Build segment offsets for CUB segmented sort of group data within a + * sub-batch. offset[c * n_groups + g] = c * n_all_grp + grp_offsets[g]. + * One thread per entry. + */ +__global__ void build_seg_offsets_kernel( + const int* __restrict__ grp_offsets, // (n_groups + 1,) + int* __restrict__ out, // (sb_cols * n_groups + 1,) + int n_all_grp, int n_groups, int sb_cols) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int total = sb_cols * n_groups + 1; + if (idx >= total) return; + if (idx == sb_cols * n_groups) { + out[idx] = sb_cols * n_all_grp; + } else { + int c = idx / n_groups; + int g = idx % n_groups; + out[idx] = c * n_all_grp + grp_offsets[g]; + } +} + +/** + * Streaming OVO pipeline. + * + * Takes pre-sorted reference (float32 F-order), unsorted group data (float32 + * F-order with group offsets), and produces rank_sums + tie_corr. + * + * For each sub-batch of columns: + * 1. CUB segmented sort-keys of group data (one segment per group per col) + * 2. 
batched_rank_sums_presorted_kernel (binary search in sorted ref) + */ +static void ovo_streaming_impl(const float* ref_sorted, const float* grp_data, + const int* grp_offsets, double* rank_sums, + double* tie_corr, int n_ref, int n_all_grp, + int n_cols, int n_groups, bool compute_tie_corr, + int sub_batch_cols) { + if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) + n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + + size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols; + int max_n_seg = n_groups * sub_batch_cols; + size_t cub_temp_bytes = 0; + { + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys(nullptr, cub_temp_bytes, fk, fk, + (int)sub_grp_items, max_n_seg, + doff, doff + 1, 0, 32); + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + struct StreamBuf { + float* grp_sorted; + int* seg_offsets; + uint8_t* cub_temp; + double* sub_rank_sums; + double* sub_tie_corr; + }; + + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + cudaMalloc(&bufs[s].grp_sorted, sub_grp_items * sizeof(float)); + cudaMalloc(&bufs[s].seg_offsets, (max_n_seg + 1) * sizeof(int)); + cudaMalloc(&bufs[s].cub_temp, cub_temp_bytes); + cudaMalloc(&bufs[s].sub_rank_sums, + (size_t)n_groups * sub_batch_cols * sizeof(double)); + cudaMalloc(&bufs[s].sub_tie_corr, + (size_t)n_groups * sub_batch_cols * sizeof(double)); + } + + // Import the presorted kernel from the OVO header + // (included via kernels_wilcoxon_ovo.cuh) + int tpb_rank = round_up_to_warp(std::min(n_all_grp, MAX_THREADS_PER_BLOCK)); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int sb_n_seg = n_groups * sb_cols; + int sb_grp_items = n_all_grp * sb_cols; + int s = batch_idx % n_streams; + auto stream = streams[s]; + auto& buf = 
bufs[s]; + + // Build segment offsets on device + { + int total = sb_n_seg + 1; + int blk = (total + 255) / 256; + build_seg_offsets_kernel<<>>( + grp_offsets, buf.seg_offsets, n_all_grp, n_groups, sb_cols); + } + + // Sort group data for this sub-batch + const float* grp_in = grp_data + (long long)col * n_all_grp; + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.cub_temp, temp, grp_in, buf.grp_sorted, sb_grp_items, sb_n_seg, + buf.seg_offsets, buf.seg_offsets + 1, 0, 32, stream); + + // Rank sums: binary search sorted ref for each group element + const float* ref_sub = ref_sorted + (long long)col * n_ref; + dim3 grid(sb_cols, n_groups); + batched_rank_sums_presorted_kernel<<>>( + ref_sub, buf.grp_sorted, grp_offsets, buf.sub_rank_sums, + buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, + compute_tie_corr); + + // Scatter sub-batch results to global output + cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpy2DAsync(tie_corr + col, n_cols * sizeof(double), + buf.sub_tie_corr, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + + col += sb_cols; + batch_idx++; + } + + for (int s = 0; s < n_streams; s++) cudaStreamSynchronize(streams[s]); + + for (int s = 0; s < n_streams; s++) { + cudaFree(bufs[s].grp_sorted); + cudaFree(bufs[s].seg_offsets); + cudaFree(bufs[s].cub_temp); + cudaFree(bufs[s].sub_rank_sums); + cudaFree(bufs[s].sub_tie_corr); + cudaStreamDestroy(streams[s]); + } +} + +/** + * Multi-GPU OVO streaming pipeline with host data. + * + * Ref block is sorted on GPU 0, then P2P copied to other GPUs. + * Group data is streamed from host to each GPU's streams. 
+ */ +static void ovo_streaming_multigpu_impl( + const float* h_ref_sorted, const float* h_grp_data, + const int* h_grp_offsets, double* h_rank_sums, double* h_tie_corr, + int n_ref, int n_all_grp, int n_cols, int n_groups, bool compute_tie_corr, + int sub_batch_cols, const int* h_device_ids, int n_devices) { + if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; + + constexpr int STREAMS_PER_GPU = 2; + + // CUB temp for segmented sort of group data + int max_n_seg = n_groups * sub_batch_cols; + size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols; + size_t cub_temp_bytes = 0; + { + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys(nullptr, cub_temp_bytes, fk, fk, + (int)sub_grp_items, max_n_seg, + doff, doff + 1, 0, 32); + } + + int tpb = round_up_to_warp(std::min(n_all_grp, MAX_THREADS_PER_BLOCK)); + + struct GpuCtx { + int device_id; + cudaStream_t streams[2]; + float* d_ref_sorted; // full ref, copied once + int* d_grp_offsets; + struct { + float* d_grp_data; + float* d_grp_sorted; + int* d_seg_offsets; + uint8_t* d_cub_temp; + double* d_rank_sums; + double* d_tie_corr; + } buf[2]; + }; + + std::vector gpus(n_devices); + + // Phase 1: allocate + upload ref + offsets to each GPU + for (int d = 0; d < n_devices; d++) { + auto& g = gpus[d]; + g.device_id = h_device_ids[d]; + cudaSetDevice(g.device_id); + + for (int s = 0; s < STREAMS_PER_GPU; s++) + cudaStreamCreate(&g.streams[s]); + + size_t ref_size = (size_t)n_ref * n_cols; + cudaMalloc(&g.d_ref_sorted, ref_size * sizeof(float)); + cudaMemcpyAsync(g.d_ref_sorted, h_ref_sorted, ref_size * sizeof(float), + cudaMemcpyHostToDevice, g.streams[0]); + + cudaMalloc(&g.d_grp_offsets, (n_groups + 1) * sizeof(int)); + cudaMemcpyAsync(g.d_grp_offsets, h_grp_offsets, + (n_groups + 1) * sizeof(int), cudaMemcpyHostToDevice, + g.streams[0]); + + for (int s = 0; s < STREAMS_PER_GPU; s++) { + cudaMalloc(&g.buf[s].d_grp_data, sub_grp_items * sizeof(float)); + 
cudaMalloc(&g.buf[s].d_grp_sorted, sub_grp_items * sizeof(float)); + cudaMalloc(&g.buf[s].d_seg_offsets, (max_n_seg + 1) * sizeof(int)); + cudaMalloc(&g.buf[s].d_cub_temp, cub_temp_bytes); + cudaMalloc(&g.buf[s].d_rank_sums, + (size_t)n_groups * sub_batch_cols * sizeof(double)); + cudaMalloc(&g.buf[s].d_tie_corr, + (size_t)n_groups * sub_batch_cols * sizeof(double)); + } + } + + // Phase 2: process sub-batches + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int sb_n_seg = n_groups * sb_cols; + int sb_grp_items = n_all_grp * sb_cols; + + int d = (batch_idx / STREAMS_PER_GPU) % n_devices; + int s = batch_idx % STREAMS_PER_GPU; + auto& g = gpus[d]; + auto stream = g.streams[s]; + auto& buf = g.buf[s]; + + cudaSetDevice(g.device_id); + + // H2D: group data sub-batch + cudaMemcpyAsync(buf.d_grp_data, h_grp_data + (long long)col * n_all_grp, + sb_grp_items * sizeof(float), cudaMemcpyHostToDevice, + stream); + + // Build segment offsets on device + { + int total = sb_n_seg + 1; + int blk = (total + 255) / 256; + build_seg_offsets_kernel<<>>( + g.d_grp_offsets, buf.d_seg_offsets, n_all_grp, n_groups, + sb_cols); + } + + // Sort group data + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.d_cub_temp, temp, buf.d_grp_data, buf.d_grp_sorted, + sb_grp_items, sb_n_seg, buf.d_seg_offsets, buf.d_seg_offsets + 1, 0, + 32, stream); + + // Rank sums + const float* ref_sub = g.d_ref_sorted + (long long)col * n_ref; + dim3 grid(sb_cols, n_groups); + batched_rank_sums_presorted_kernel<<>>( + ref_sub, buf.d_grp_sorted, g.d_grp_offsets, buf.d_rank_sums, + buf.d_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, + compute_tie_corr); + + // D2H: scatter results + cudaMemcpy2DAsync(h_rank_sums + col, n_cols * sizeof(double), + buf.d_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToHost, stream); + if (compute_tie_corr) { + 
cudaMemcpy2DAsync(h_tie_corr + col, n_cols * sizeof(double), + buf.d_tie_corr, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToHost, stream); + } + + col += sb_cols; + batch_idx++; + } + + // Phase 3: sync + cleanup + for (int d = 0; d < n_devices; d++) { + cudaSetDevice(gpus[d].device_id); + for (int s = 0; s < STREAMS_PER_GPU; s++) + cudaStreamSynchronize(gpus[d].streams[s]); + } + for (int d = 0; d < n_devices; d++) { + cudaSetDevice(gpus[d].device_id); + cudaFree(gpus[d].d_ref_sorted); + cudaFree(gpus[d].d_grp_offsets); + for (int s = 0; s < STREAMS_PER_GPU; s++) { + cudaFree(gpus[d].buf[s].d_grp_data); + cudaFree(gpus[d].buf[s].d_grp_sorted); + cudaFree(gpus[d].buf[s].d_seg_offsets); + cudaFree(gpus[d].buf[s].d_cub_temp); + cudaFree(gpus[d].buf[s].d_rank_sums); + cudaFree(gpus[d].buf[s].d_tie_corr); + cudaStreamDestroy(gpus[d].streams[s]); + } + } + cudaSetDevice(h_device_ids[0]); +} + +// ============================================================================ +// Nanobind module +// ============================================================================ + +template +void register_bindings(nb::module_& m) { + m.doc() = "Streaming Wilcoxon pipeline with multi-stream overlap"; + + m.def( + "ovr_streaming", + [](gpu_array_f block, + gpu_array_c group_codes, + gpu_array_c rank_sums, + gpu_array_c tie_corr, int n_rows, int n_cols, + int n_groups, bool compute_tie_corr, int sub_batch_cols) { + ovr_streaming_impl(block.data(), group_codes.data(), + rank_sums.data(), tie_corr.data(), n_rows, + n_cols, n_groups, compute_tie_corr, + sub_batch_cols); + }, + "block"_a, "group_codes"_a, "rank_sums"_a, "tie_corr"_a, nb::kw_only(), + "n_rows"_a, "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, + "sub_batch_cols"_a = SUB_BATCH_COLS); + + m.def( + "ovr_streaming_csr", + [](gpu_array_c csr_data, + gpu_array_c csr_indices, + gpu_array_c csr_indptr, + gpu_array_c group_codes, + gpu_array_c rank_sums, + gpu_array_c tie_corr, int n_rows, 
int n_cols, + int n_groups, bool compute_tie_corr, int sub_batch_cols) { + ovr_streaming_csr_impl( + csr_data.data(), csr_indices.data(), csr_indptr.data(), + group_codes.data(), rank_sums.data(), tie_corr.data(), n_rows, + n_cols, n_groups, compute_tie_corr, sub_batch_cols); + }, + "csr_data"_a, "csr_indices"_a, "csr_indptr"_a, "group_codes"_a, + "rank_sums"_a, "tie_corr"_a, nb::kw_only(), "n_rows"_a, "n_cols"_a, + "n_groups"_a, "compute_tie_corr"_a, + "sub_batch_cols"_a = SUB_BATCH_COLS); + + m.def( + "ovr_streaming_csc", + [](gpu_array_c csc_data, + gpu_array_c csc_indices, + gpu_array_c csc_indptr, + gpu_array_c group_codes, + gpu_array_c rank_sums, + gpu_array_c tie_corr, int n_rows, int n_cols, + int n_groups, bool compute_tie_corr, int sub_batch_cols) { + ovr_streaming_csc_impl( + csc_data.data(), csc_indices.data(), csc_indptr.data(), + group_codes.data(), rank_sums.data(), tie_corr.data(), n_rows, + n_cols, n_groups, compute_tie_corr, sub_batch_cols); + }, + "csc_data"_a, "csc_indices"_a, "csc_indptr"_a, "group_codes"_a, + "rank_sums"_a, "tie_corr"_a, nb::kw_only(), "n_rows"_a, "n_cols"_a, + "n_groups"_a, "compute_tie_corr"_a, + "sub_batch_cols"_a = SUB_BATCH_COLS); + + m.def( + "ovo_streaming", + [](gpu_array_f ref_sorted, + gpu_array_f grp_data, + gpu_array_c grp_offsets, + gpu_array_c rank_sums, + gpu_array_c tie_corr, int n_ref, int n_all_grp, + int n_cols, int n_groups, bool compute_tie_corr, + int sub_batch_cols) { + ovo_streaming_impl(ref_sorted.data(), grp_data.data(), + grp_offsets.data(), rank_sums.data(), + tie_corr.data(), n_ref, n_all_grp, n_cols, + n_groups, compute_tie_corr, sub_batch_cols); + }, + "ref_sorted"_a, "grp_data"_a, "grp_offsets"_a, "rank_sums"_a, + "tie_corr"_a, nb::kw_only(), "n_ref"_a, "n_all_grp"_a, "n_cols"_a, + "n_groups"_a, "compute_tie_corr"_a, + "sub_batch_cols"_a = SUB_BATCH_COLS); +} + +NB_MODULE(_wilcoxon_streaming_cuda, m) { + REGISTER_GPU_BINDINGS(register_bindings, m); + + m.def( + "ovr_streaming_csc_host", + 
[](host_array h_data, host_array h_indices, + host_array h_indptr, host_array h_group_codes, + host_array_2d h_rank_sums, host_array h_tie_corr, + int n_rows, int n_cols, int n_groups, bool compute_tie_corr, + int sub_batch_cols) { + ovr_streaming_csc_host_impl( + h_data.data(), h_indices.data(), h_indptr.data(), + h_group_codes.data(), h_rank_sums.data(), h_tie_corr.data(), + n_rows, n_cols, n_groups, compute_tie_corr, sub_batch_cols); + }, + "h_data"_a, "h_indices"_a, "h_indptr"_a, "h_group_codes"_a, + "h_rank_sums"_a, "h_tie_corr"_a, nb::kw_only(), "n_rows"_a, "n_cols"_a, + "n_groups"_a, "compute_tie_corr"_a, + "sub_batch_cols"_a = SUB_BATCH_COLS); + + m.def( + "ovr_streaming_dense_host", + [](host_array_2d h_block, + host_array h_group_codes, + host_array_2d h_rank_sums, host_array h_tie_corr, + int n_rows, int n_cols, int n_groups, bool compute_tie_corr, + int sub_batch_cols) { + ovr_streaming_dense_host_impl(h_block.data(), h_group_codes.data(), + h_rank_sums.data(), h_tie_corr.data(), + n_rows, n_cols, n_groups, + compute_tie_corr, sub_batch_cols); + }, + "h_block"_a, "h_group_codes"_a, "h_rank_sums"_a, "h_tie_corr"_a, + nb::kw_only(), "n_rows"_a, "n_cols"_a, "n_groups"_a, + "compute_tie_corr"_a, "sub_batch_cols"_a = SUB_BATCH_COLS); + + m.def( + "ovo_streaming_multigpu", + [](host_array_2d h_ref_sorted, + host_array_2d h_grp_data, + host_array h_grp_offsets, + host_array_2d h_rank_sums, host_array_2d h_tie_corr, + int n_ref, int n_all_grp, int n_cols, int n_groups, + bool compute_tie_corr, int sub_batch_cols, + host_array device_ids) { + // Pin host arrays + size_t ref_bytes = (size_t)n_ref * n_cols * sizeof(float); + size_t grp_bytes = (size_t)n_all_grp * n_cols * sizeof(float); + size_t rs_bytes = (size_t)n_groups * n_cols * sizeof(double); + cudaHostRegister(const_cast(h_ref_sorted.data()), ref_bytes, + 0); + cudaHostRegister(const_cast(h_grp_data.data()), grp_bytes, + 0); + cudaHostRegister(const_cast(h_grp_offsets.data()), + (n_groups + 1) * 
sizeof(int), 0); + cudaHostRegister(h_rank_sums.data(), rs_bytes, 0); + cudaHostRegister(h_tie_corr.data(), + (size_t)n_groups * n_cols * sizeof(double), 0); + + ovo_streaming_multigpu_impl( + h_ref_sorted.data(), h_grp_data.data(), h_grp_offsets.data(), + h_rank_sums.data(), h_tie_corr.data(), n_ref, n_all_grp, n_cols, + n_groups, compute_tie_corr, sub_batch_cols, device_ids.data(), + static_cast(device_ids.size())); + + cudaHostUnregister(const_cast(h_ref_sorted.data())); + cudaHostUnregister(const_cast(h_grp_data.data())); + cudaHostUnregister(const_cast(h_grp_offsets.data())); + cudaHostUnregister(h_rank_sums.data()); + cudaHostUnregister(h_tie_corr.data()); + }, + "h_ref_sorted"_a, "h_grp_data"_a, "h_grp_offsets"_a, "h_rank_sums"_a, + "h_tie_corr"_a, nb::kw_only(), "n_ref"_a, "n_all_grp"_a, "n_cols"_a, + "n_groups"_a, "compute_tie_corr"_a, "sub_batch_cols"_a = SUB_BATCH_COLS, + "device_ids"_a); +} diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py b/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py index c4f2c601..9dea8f11 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py @@ -102,11 +102,14 @@ def _select_top_n(scores: NDArray, n_top: int) -> NDArray: return global_indices +DEFAULT_CHUNK_SIZE = 512 + + def _choose_chunk_size(requested: int | None) -> int: """Choose chunk size for gene processing.""" if requested is not None: return int(requested) - return 128 + return DEFAULT_CHUNK_SIZE def _csc_columns_to_gpu(X_csc, start: int, stop: int, n_rows: int) -> cp.ndarray: @@ -124,22 +127,22 @@ def _csc_columns_to_gpu(X_csc, start: int, stop: int, n_rows: int) -> cp.ndarray csc_chunk = cpsp.csc_matrix( (chunk_data, chunk_indices, chunk_indptr), shape=(n_rows, stop - start) ) - return _sparse_to_dense(csc_chunk, order="F").astype(cp.float64) + return _sparse_to_dense(csc_chunk, order="F") def _get_column_block(X, start: int, stop: int) -> cp.ndarray: - 
"""Extract a column block as a dense F-order float64 CuPy array.""" + """Extract a column block as a dense F-order CuPy array (native dtype).""" match X: case sp.csc_matrix() | sp.csc_array(): return _csc_columns_to_gpu(X, start, stop, X.shape[0]) case sp.spmatrix() | sp.sparray(): chunk = cpsp.csc_matrix(X[:, start:stop].tocsc()) - return _sparse_to_dense(chunk, order="F").astype(cp.float64) + return _sparse_to_dense(chunk, order="F") case cpsp.csc_matrix(): return _csc_columns_to_gpu(X, start, stop, X.shape[0]) case cpsp.spmatrix(): - return _sparse_to_dense(X[:, start:stop], order="F").astype(cp.float64) + return _sparse_to_dense(X[:, start:stop], order="F") case np.ndarray() | cp.ndarray(): - return cp.asarray(X[:, start:stop], dtype=cp.float64, order="F") + return cp.asfortranarray(cp.asarray(X[:, start:stop])) case _: raise ValueError(f"Unsupported matrix type: {type(X)}") diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py index f24da8f2..5f46f110 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py @@ -11,313 +11,47 @@ from rapids_singlecell._cuda import _wilcoxon_cuda as _wc -from ._utils import _choose_chunk_size - if TYPE_CHECKING: from numpy.typing import NDArray from ._core import _RankGenes MIN_GROUP_SIZE_WARNING = 25 - -# --------------------------------------------------------------------------- -# CuPy RawKernels for sort-once OVO -# --------------------------------------------------------------------------- - -_RANK_SUMS_KERNEL = cp.RawKernel( - r""" -extern "C" __global__ -void rank_sums_from_sorted( - const double* __restrict__ ref_sorted, // (n_ref, n_cols) F-order - const double* __restrict__ grp_sorted, // (n_grp, n_cols) F-order - double* __restrict__ rank_sums, // (n_cols,) - const int n_ref, - const int n_grp, - const int n_cols -) { - /* One block per gene (column). 
- Threads cooperatively process group elements. - For each group element, binary-search the sorted reference - and the sorted group to compute the average rank in the - combined (group + reference) set. - */ - int col = blockIdx.x; - if (col >= n_cols) return; - - const double* ref = ref_sorted + (long long)col * n_ref; - const double* grp = grp_sorted + (long long)col * n_grp; - - double local_sum = 0.0; - - for (int i = threadIdx.x; i < n_grp; i += blockDim.x) { - double v = grp[i]; - - // --- count of ref values < v --- - int lo = 0, hi = n_ref; - while (lo < hi) { int m = (lo + hi) >> 1; if (ref[m] < v) lo = m + 1; else hi = m; } - int n_lt_ref = lo; - - // --- count of ref values <= v --- - lo = n_lt_ref; hi = n_ref; - while (lo < hi) { int m = (lo + hi) >> 1; if (ref[m] <= v) lo = m + 1; else hi = m; } - int n_eq_ref = lo - n_lt_ref; - - // --- count of grp values < v --- - lo = 0; hi = n_grp; - while (lo < hi) { int m = (lo + hi) >> 1; if (grp[m] < v) lo = m + 1; else hi = m; } - int n_lt_grp = lo; - - // --- count of grp values <= v --- - lo = n_lt_grp; hi = n_grp; - while (lo < hi) { int m = (lo + hi) >> 1; if (grp[m] <= v) lo = m + 1; else hi = m; } - int n_eq_grp = lo - n_lt_grp; - - int n_lt = n_lt_ref + n_lt_grp; - int n_eq = n_eq_ref + n_eq_grp; - double avg_rank = (double)n_lt + ((double)n_eq + 1.0) / 2.0; - local_sum += avg_rank; - } - - // --- warp-level reduction --- - #pragma unroll - for (int off = 16; off > 0; off >>= 1) - local_sum += __shfl_down_sync(0xffffffff, local_sum, off); - - __shared__ double warp_sums[32]; - int lane = threadIdx.x & 31; - int wid = threadIdx.x >> 5; - if (lane == 0) warp_sums[wid] = local_sum; - __syncthreads(); - - if (threadIdx.x < 32) { - double val = (threadIdx.x < ((blockDim.x + 31) >> 5)) - ? 
warp_sums[threadIdx.x] : 0.0; - #pragma unroll - for (int off = 16; off > 0; off >>= 1) - val += __shfl_down_sync(0xffffffff, val, off); - if (threadIdx.x == 0) rank_sums[col] = val; - } -} -""", - "rank_sums_from_sorted", - options=("--use_fast_math",), -) - - -_TIE_CORR_MERGE_KERNEL = cp.RawKernel( - r""" -extern "C" __global__ -void tie_correction_merged( - const double* __restrict__ ref_sorted, - const double* __restrict__ grp_sorted, - double* __restrict__ correction, - const int n_ref, - const int n_grp, - const int n_cols -) { - /* One block per gene column. Thread 0 merges the two sorted - arrays and accumulates the tie-correction term - sum(t^3 - t) over all tie groups of size t. - */ - int col = blockIdx.x; - if (col >= n_cols || threadIdx.x != 0) return; - - const double* ref = ref_sorted + (long long)col * n_ref; - const double* grp = grp_sorted + (long long)col * n_grp; - - int i = 0, j = 0; - double tie_sum = 0.0; - - while (i < n_ref || j < n_grp) { - double v; - if (j >= n_grp) v = ref[i]; - else if (i >= n_ref) v = grp[j]; - else v = (ref[i] <= grp[j]) ? ref[i] : grp[j]; - - int count = 0; - while (i < n_ref && ref[i] == v) { ++i; ++count; } - while (j < n_grp && grp[j] == v) { ++j; ++count; } - - if (count > 1) { - double t = (double)count; - tie_sum += t * t * t - t; - } - } - - int n = n_ref + n_grp; - double dn = (double)n; - double denom = dn * dn * dn - dn; - correction[col] = (denom > 0.0) ? (1.0 - tie_sum / denom) : 1.0; -} -""", - "tie_correction_merged", - options=("--use_fast_math",), -) - - -_CSR_EXTRACT_KERNEL = cp.RawKernel( - r""" -extern "C" __global__ -void csr_extract_dense( - const double* __restrict__ data, - const int* __restrict__ indices, - const long long* __restrict__ indptr, - const int* __restrict__ row_ids, - double* __restrict__ out, // F-order (n_target, n_cols) - const int n_target, - const int col_start, - const int col_stop, - const int n_cols // = col_stop - col_start -) { - /* One thread per target row. 
- Binary-search the CSR index array for col_start, then - linear-scan through [col_start, col_stop) writing to - the dense output in column-major order. - */ - int tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= n_target) return; - - int row = row_ids[tid]; - long long rs = indptr[row]; - long long re = indptr[row + 1]; - - // binary search for col_start - long long lo = rs, hi = re; - while (lo < hi) { - long long m = (lo + hi) >> 1; - if (indices[m] < col_start) lo = m + 1; else hi = m; - } - - for (long long p = lo; p < re; ++p) { - int c = indices[p]; - if (c >= col_stop) break; - int lc = c - col_start; - out[(long long)lc * n_target + tid] = data[p]; - } -} -""", - "csr_extract_dense", - options=("--use_fast_math",), -) +STREAMING_SUB_BATCH = 64 # --------------------------------------------------------------------------- -# Kernel helpers +# Helpers # --------------------------------------------------------------------------- -WARP_SIZE = 32 -MAX_THREADS = 512 - - -def _round_up_to_warp(n: int) -> int: - return min(MAX_THREADS, ((n + WARP_SIZE - 1) // WARP_SIZE) * WARP_SIZE) - - -def _rank_sums_searchsorted( - ref_sorted: cp.ndarray, - grp_sorted: cp.ndarray, -) -> cp.ndarray: - """Rank sums for *grp* via binary search in pre-sorted *ref*. - - Both must be F-order float64 ``(n_rows, n_cols)``. 
- """ - n_ref, n_cols = ref_sorted.shape - n_grp = grp_sorted.shape[0] - rank_sums = cp.empty(n_cols, dtype=cp.float64) - threads = _round_up_to_warp(min(n_grp, MAX_THREADS)) - _RANK_SUMS_KERNEL( - (n_cols,), - (threads,), - ( - ref_sorted, - grp_sorted, - rank_sums, - np.int32(n_ref), - np.int32(n_grp), - np.int32(n_cols), - ), - stream=cp.cuda.get_current_stream(), - ) - return rank_sums +def _to_gpu_native(X, n_rows: int, n_cols: int): + """Move *X* to GPU, preserving its format (CSR / CSC / dense).""" + # Already on GPU + if isinstance(X, cp.ndarray): + return X + if cpsp.issparse(X): + return X -def _tie_correction_merged( - ref_sorted: cp.ndarray, - grp_sorted: cp.ndarray, -) -> cp.ndarray: - """Tie-correction factor via merge of two sorted F-order arrays.""" - n_ref, n_cols = ref_sorted.shape - n_grp = grp_sorted.shape[0] - correction = cp.empty(n_cols, dtype=cp.float64) - _TIE_CORR_MERGE_KERNEL( - (n_cols,), - (1,), - ( - ref_sorted, - grp_sorted, - correction, - np.int32(n_ref), - np.int32(n_grp), - np.int32(n_cols), - ), - stream=cp.cuda.get_current_stream(), - ) - return correction + # Host sparse → GPU sparse, same format + if isinstance(X, sp.spmatrix | sp.sparray): + if sp.issparse(X) and X.format == "csc": + csc = X if X.format == "csc" else X.tocsc() + return cpsp.csc_matrix( + (cp.asarray(csc.data), cp.asarray(csc.indices), cp.asarray(csc.indptr)), + shape=(n_rows, n_cols), + ) + csr = X.tocsr() if X.format != "csr" else X + return cpsp.csr_matrix( + (cp.asarray(csr.data), cp.asarray(csr.indices), cp.asarray(csr.indptr)), + shape=(n_rows, n_cols), + ) + # Host dense → GPU dense + if isinstance(X, np.ndarray): + return cp.asarray(X) -def _extract_dense_block_csr_gpu( - data: cp.ndarray, - indices: cp.ndarray, - indptr: cp.ndarray, - row_ids: cp.ndarray, - *, - col_start: int, - col_stop: int, -) -> cp.ndarray: - """Extract a dense F-order float64 block from GPU CSR arrays.""" - n_target = row_ids.shape[0] - n_cols = col_stop - col_start - out = 
cp.zeros((n_target, n_cols), dtype=cp.float64, order="F") - if n_target == 0 or n_cols == 0: - return out - threads = _round_up_to_warp(min(n_target, MAX_THREADS)) - blocks = (n_target + threads - 1) // threads - _CSR_EXTRACT_KERNEL( - (blocks,), - (threads,), - ( - data, - indices, - indptr, - row_ids, - out, - np.int32(n_target), - np.int32(col_start), - np.int32(col_stop), - np.int32(n_cols), - ), - stream=cp.cuda.get_current_stream(), - ) - return out - - -def _to_gpu_csr_arrays(X) -> tuple[cp.ndarray, cp.ndarray, cp.ndarray]: - """Return (data, indices, indptr) as float64/int32/int64 on GPU.""" - if isinstance(X, cpsp.csr_matrix): - csr = X - elif isinstance(X, cpsp.csc_matrix): - csr = X.tocsr() - elif isinstance(X, sp.spmatrix | sp.sparray): - if X.format != "csr": - X = X.tocsr() - csr = cpsp.csr_matrix(X) - else: - raise TypeError(f"Expected sparse matrix, got {type(X)}") - return ( - csr.data.astype(cp.float64, copy=False), - csr.indices.astype(cp.int32, copy=False), - csr.indptr.astype(cp.int64, copy=False), - ) + raise TypeError(f"Unsupported matrix type: {type(X)}") def _extract_dense_block( @@ -328,26 +62,41 @@ def _extract_dense_block( *, csr_arrays: tuple[cp.ndarray, cp.ndarray, cp.ndarray] | None = None, ) -> cp.ndarray: - """Extract ``X[row_ids, start:stop]`` as dense F-order float64 on GPU.""" + """Extract ``X[row_ids, start:stop]`` as dense F-order on GPU (native dtype). + + The CSR kernel path outputs float64 (kernel writes double*). + All other paths preserve the input dtype. 
+ """ if csr_arrays is not None: data, indices, indptr = csr_arrays if row_ids is None: n_target = int(indptr.shape[0] - 1) row_ids = cp.arange(n_target, dtype=cp.int32) - return _extract_dense_block_csr_gpu( - data, indices, indptr, row_ids, col_start=start, col_stop=stop - ) + n_target = row_ids.shape[0] + n_cols = stop - start + out = cp.zeros((n_target, n_cols), dtype=cp.float64, order="F") + if n_target > 0 and n_cols > 0: + _wc.csr_extract_dense( + data, + indices, + indptr, + row_ids, + out, + n_target=n_target, + col_start=start, + col_stop=stop, + stream=cp.cuda.get_current_stream().ptr, + ) + return out if isinstance(X, np.ndarray): if row_ids is not None: - return cp.asarray( - X[cp.asnumpy(row_ids), start:stop], dtype=cp.float64, order="F" - ) - return cp.asarray(X[:, start:stop], dtype=cp.float64, order="F") + return cp.asarray(X[cp.asnumpy(row_ids), start:stop], order="F") + return cp.asarray(X[:, start:stop], order="F") if isinstance(X, cp.ndarray): chunk = X[row_ids, start:stop] if row_ids is not None else X[:, start:stop] - return cp.asfortranarray(chunk.astype(cp.float64, copy=False)) + return cp.asfortranarray(chunk) if isinstance(X, sp.spmatrix | sp.sparray): if row_ids is not None: @@ -355,53 +104,58 @@ def _extract_dense_block( chunk = X[idx][:, start:stop].toarray() else: chunk = X[:, start:stop].toarray() - return cp.asarray(chunk, dtype=cp.float64, order="F") + return cp.asarray(chunk, order="F") if cpsp.issparse(X): if row_ids is not None: chunk = X[row_ids][:, start:stop].toarray() else: chunk = X[:, start:stop].toarray() - return cp.asfortranarray(chunk.astype(cp.float64, copy=False)) + return cp.asfortranarray(chunk) raise TypeError(f"Unsupported matrix type: {type(X)}") -# --------------------------------------------------------------------------- -# Existing kernels (OVR path) -# --------------------------------------------------------------------------- - +def _segmented_sort_columns( + data: cp.ndarray, + offsets_host: np.ndarray, + 
n_rows: int, + n_cols: int, + n_groups: int, +) -> cp.ndarray: + """Sort each group segment within each column using CUB radix sort. -def _average_ranks( - matrix: cp.ndarray, *, return_sorted: bool = False -) -> cp.ndarray | tuple[cp.ndarray, cp.ndarray]: - """Compute average ranks for each column using GPU kernel.""" - n_rows, n_cols = matrix.shape - sorter = cp.argsort(matrix, axis=0) - sorted_vals = cp.take_along_axis(matrix, sorter, axis=0) - sorted_vals = cp.asfortranarray(sorted_vals) - sorter = cp.asfortranarray(sorter.astype(cp.int32)) - stream = cp.cuda.get_current_stream().ptr - _wc.average_rank( - sorted_vals, sorter, matrix, n_rows=n_rows, n_cols=n_cols, stream=stream + Sorts in float32 for half the bandwidth. Returns float32 F-order. + """ + n_items = n_rows * n_cols + n_segments = n_cols * n_groups + + col_bases = np.arange(n_cols, dtype=np.int32) * n_rows + seg_starts = col_bases[:, None] + offsets_host[None, :n_groups] + seg_arr = np.empty(n_segments + 1, dtype=np.int32) + seg_arr[:n_segments] = seg_starts.ravel() + seg_arr[n_segments] = n_items + seg_offsets_gpu = cp.asarray(seg_arr) + + temp_bytes = _wc.get_seg_sort_temp_bytes(n_items=n_items, n_segments=n_segments) + cub_temp = cp.empty(temp_bytes, dtype=cp.uint8) + + keys_in = cp.ascontiguousarray(data.astype(cp.float32).ravel(order="F")) + keys_out = cp.empty_like(keys_in) + + _wc.segmented_sort( + keys_in, + keys_out, + seg_offsets_gpu, + cub_temp, + n_items=n_items, + n_segments=n_segments, + stream=cp.cuda.get_current_stream().ptr, ) - if return_sorted: - return matrix, sorted_vals - return matrix - - -def _tie_correction(sorted_vals: cp.ndarray) -> cp.ndarray: - """Tie correction factor from pre-sorted values (F-order).""" - n_rows, n_cols = sorted_vals.shape - correction = cp.ones(n_cols, dtype=cp.float64) - if n_rows < 2: - return correction - sorted_vals = cp.asfortranarray(sorted_vals) - stream = cp.cuda.get_current_stream().ptr - _wc.tie_correction( - sorted_vals, correction, 
n_rows=n_rows, n_cols=n_cols, stream=stream + + return cp.ndarray( + (n_rows, n_cols), dtype=cp.float32, memptr=keys_out.data, order="F" ) - return correction # --------------------------------------------------------------------------- @@ -417,12 +171,12 @@ def wilcoxon( chunk_size: int | None = None, ) -> list[tuple[int, NDArray, NDArray]]: """Compute Wilcoxon rank-sum test statistics.""" - rg._basic_stats() X = rg.X - n_cells, n_total_genes = rg.X.shape + n_cells, n_total_genes = X.shape group_sizes = rg.group_sizes if rg.ireference is not None: + rg._init_stats_arrays(n_total_genes) return _wilcoxon_with_reference( rg, X, @@ -432,6 +186,7 @@ def wilcoxon( use_continuity=use_continuity, chunk_size=chunk_size, ) + rg._basic_stats() return _wilcoxon_vs_rest( rg, X, @@ -445,7 +200,7 @@ def wilcoxon( # --------------------------------------------------------------------------- -# One-vs-rest (unchanged from main) +# One-vs-rest # --------------------------------------------------------------------------- @@ -460,10 +215,12 @@ def _wilcoxon_vs_rest( use_continuity: bool, chunk_size: int | None, ) -> list[tuple[int, NDArray, NDArray]]: - """Wilcoxon test: each group vs rest of cells.""" - from rapids_singlecell._utils._csr_to_csc import _fast_csr_to_csc + """Wilcoxon test: each group vs rest of cells. - from ._utils import _get_column_block + Dispatches to CSR, CSC, or dense streaming kernel based on input format. + No unnecessary format conversions. 
+ """ + from rapids_singlecell._cuda import _wilcoxon_streaming_cuda as _ws n_groups = len(rg.groups_order) @@ -477,70 +234,134 @@ def _wilcoxon_vs_rest( stacklevel=4, ) - codes_gpu = cp.asarray(rg.group_codes, dtype=cp.int64) - group_matrix = cp.zeros((n_cells, n_groups), dtype=cp.float64) - valid_idx = cp.where(codes_gpu < n_groups)[0] - group_matrix[valid_idx, codes_gpu[valid_idx]] = 1.0 - + group_codes = rg.group_codes.astype(np.int32, copy=False) group_sizes_dev = cp.asarray(group_sizes, dtype=cp.float64) rest_sizes = n_cells - group_sizes_dev - chunk_width = _choose_chunk_size(chunk_size) - - all_scores: dict[int, list] = {i: [] for i in range(n_groups)} - all_pvals: dict[int, list] = {i: [] for i in range(n_groups)} - - if isinstance(X, sp.spmatrix | sp.sparray): - X = _fast_csr_to_csc(X) if X.format == "csr" else X.tocsc() + # Determine host-streaming eligibility BEFORE transferring + host_csc = isinstance(X, sp.spmatrix | sp.sparray) and X.format == "csc" + host_dense = isinstance(X, np.ndarray) + + if host_csc or host_dense: + # Host-streaming: sort+rank stays on host→GPU per sub-batch. + # Stats still need Aggregate on GPU — cheap one-time transfer. + # _basic_stats was already called by wilcoxon() which set + # _compute_stats_in_chunks=True for host data. Transfer a + # lightweight GPU copy just for Aggregate, then discard it. 
+ if rg._compute_stats_in_chunks: + X_gpu_tmp = _to_gpu_native(X, n_cells, n_total_genes) + rg.X = X_gpu_tmp + rg._compute_stats_in_chunks = False + rg._basic_stats() + del X_gpu_tmp + + rank_sums_np = np.empty((n_groups, n_total_genes), dtype=np.float64) + tie_corr_np = np.ones(n_total_genes, dtype=np.float64) + + if host_csc: + _ws.ovr_streaming_csc_host( + X.data.astype(np.float32, copy=False), + X.indices.astype(np.int32, copy=False), + X.indptr.astype(np.int32, copy=False), + group_codes, + rank_sums_np, + tie_corr_np, + n_rows=n_cells, + n_cols=n_total_genes, + n_groups=n_groups, + compute_tie_corr=tie_correct, + sub_batch_cols=STREAMING_SUB_BATCH, + ) + else: + _ws.ovr_streaming_dense_host( + np.asfortranarray(X.astype(np.float32, copy=False)), + group_codes, + rank_sums_np, + tie_corr_np, + n_rows=n_cells, + n_cols=n_total_genes, + n_groups=n_groups, + compute_tie_corr=tie_correct, + sub_batch_cols=STREAMING_SUB_BATCH, + ) - for start in range(0, n_total_genes, chunk_width): - stop = min(start + chunk_width, n_total_genes) + rank_sums = cp.asarray(rank_sums_np) + tie_corr = cp.asarray(tie_corr_np) + else: + # GPU data or host CSR → transfer to GPU, use GPU kernels + X_gpu = _to_gpu_native(X, n_cells, n_total_genes) + + if rg._compute_stats_in_chunks: + rg.X = X_gpu + rg._compute_stats_in_chunks = False + rg._basic_stats() + + group_codes_gpu = cp.asarray(group_codes) + rank_sums = cp.empty((n_groups, n_total_genes), dtype=cp.float64) + tie_corr = cp.ones(n_total_genes, dtype=cp.float64) + + if cpsp.isspmatrix_csr(X_gpu): + _ws.ovr_streaming_csr( + X_gpu.data.astype(cp.float32, copy=False), + X_gpu.indices.astype(cp.int32, copy=False), + X_gpu.indptr.astype(cp.int32, copy=False), + group_codes_gpu, + rank_sums, + tie_corr, + n_rows=n_cells, + n_cols=n_total_genes, + n_groups=n_groups, + compute_tie_corr=tie_correct, + sub_batch_cols=STREAMING_SUB_BATCH, + ) + elif cpsp.isspmatrix_csc(X_gpu): + _ws.ovr_streaming_csc( + X_gpu.data.astype(cp.float32, 
copy=False), + X_gpu.indices.astype(cp.int32, copy=False), + X_gpu.indptr.astype(cp.int32, copy=False), + group_codes_gpu, + rank_sums, + tie_corr, + n_rows=n_cells, + n_cols=n_total_genes, + n_groups=n_groups, + compute_tie_corr=tie_correct, + sub_batch_cols=STREAMING_SUB_BATCH, + ) + else: + dense_f32 = cp.asfortranarray(X_gpu.astype(cp.float32, copy=False)) + _ws.ovr_streaming( + dense_f32, + group_codes_gpu, + rank_sums, + tie_corr, + n_rows=n_cells, + n_cols=n_total_genes, + n_groups=n_groups, + compute_tie_corr=tie_correct, + sub_batch_cols=STREAMING_SUB_BATCH, + ) - block = _get_column_block(X, start, stop) + # Z-scores + p-values (vectorised) + expected = group_sizes_dev[:, None] * (n_cells + 1) / 2.0 + variance = tie_corr[None, :] * group_sizes_dev[:, None] * rest_sizes[:, None] + variance *= (n_cells + 1) / 12.0 + std = cp.sqrt(variance) + diff = rank_sums - expected + if use_continuity: + diff = cp.sign(diff) * cp.maximum(cp.abs(diff) - 0.5, 0.0) + z = diff / std + cp.nan_to_num(z, copy=False) + p_values = cupyx_special.erfc(cp.abs(z) * cp.float64(cp.sqrt(0.5))) - rg._accumulate_chunk_stats_vs_rest( - block, - start, - stop, - group_matrix=group_matrix, - group_sizes_dev=group_sizes_dev, - n_cells=n_cells, - ) + all_z = z.get() + all_p = p_values.get() - if tie_correct: - ranks, sorted_vals = _average_ranks(block, return_sorted=True) - tie_corr = _tie_correction(sorted_vals) - else: - ranks = _average_ranks(block) - tie_corr = cp.ones(ranks.shape[1], dtype=cp.float64) - - rank_sums = group_matrix.T @ ranks - expected = group_sizes_dev[:, None] * (n_cells + 1) / 2.0 - variance = tie_corr[None, :] * group_sizes_dev[:, None] * rest_sizes[:, None] - variance *= (n_cells + 1) / 12.0 - std = cp.sqrt(variance) - diff = rank_sums - expected - if use_continuity: - diff = cp.sign(diff) * cp.maximum(cp.abs(diff) - 0.5, 0.0) - z = diff / std - cp.nan_to_num(z, copy=False) - p_values = cupyx_special.erfc(cp.abs(z) * cp.float64(cp.sqrt(0.5))) - - z_host = z.get() - 
p_host = p_values.get() - - for idx in range(n_groups): - all_scores[idx].append(z_host[idx]) - all_pvals[idx].append(p_host[idx]) - - return [ - (gi, np.concatenate(all_scores[gi]), np.concatenate(all_pvals[gi])) - for gi in range(n_groups) - ] + return [(gi, all_z[gi], all_p[gi]) for gi in range(n_groups)] # --------------------------------------------------------------------------- -# One-vs-reference (sort-once optimisation inspired by illico) +# One-vs-reference # --------------------------------------------------------------------------- @@ -556,40 +377,50 @@ def _wilcoxon_with_reference( ) -> list[tuple[int, NDArray, NDArray]]: """Wilcoxon test: each group vs a specific reference group. - Key optimisations over the naive per-group approach: - - * **No CSR->CSC conversion** -- data is read directly from CSR via a - binary-search extraction kernel. - * **Reference sorted once per gene chunk** -- the (typically large) - reference group is extracted and column-sorted once, then reused - for every test-group comparison. - * **Rank sums via binary search** -- instead of concatenating and - re-sorting reference + group for every pair, a GPU kernel computes - rank sums by binary-searching the pre-sorted reference. This - reduces the per-group cost from O((n_ref+n_grp) log(n_ref+n_grp)) - to O(n_grp log(n_ref)). + All test groups are processed in a single batched streaming kernel, + eliminating per-group kernel launch overhead. 
""" + from rapids_singlecell._cuda import _wilcoxon_streaming_cuda as _ws + + n_cells = X.shape[0] n_groups = len(rg.groups_order) ireference = rg.ireference n_ref = int(group_sizes[ireference]) - - # ---- build row-index arrays (GPU int32) for every group ---- codes = rg.group_codes + + # ---- build row-index arrays ---- ref_row_ids = cp.asarray(np.where(codes == ireference)[0], dtype=cp.int32) - group_row_ids: dict[int, cp.ndarray] = {} + test_group_indices: list[int] = [] + all_grp_rows: list[np.ndarray] = [] + offsets = [0] for gi in range(n_groups): if gi == ireference: continue - group_row_ids[gi] = cp.asarray(np.where(codes == gi)[0], dtype=cp.int32) + rows = np.where(codes == gi)[0] + test_group_indices.append(gi) + all_grp_rows.append(rows) + offsets.append(offsets[-1] + len(rows)) + + all_grp_row_ids = cp.asarray(np.concatenate(all_grp_rows), dtype=cp.int32) + grp_offsets_gpu = cp.asarray(offsets, dtype=cp.int32) + n_test = len(test_group_indices) - # ---- prepare CSR arrays on GPU if sparse (one-time transfer) ---- + # ---- move data to GPU ---- + X_gpu = _to_gpu_native(X, n_cells, n_total_genes) + + # For row extraction, CSR kernel is optimal. Dense uses cupy indexing. 
csr_arrays = None - if sp.issparse(X) or cpsp.issparse(X): - csr_arrays = _to_gpu_csr_arrays(X) + if cpsp.issparse(X_gpu): + csr_gpu = X_gpu.tocsr() if not cpsp.isspmatrix_csr(X_gpu) else X_gpu + csr_arrays = ( + csr_gpu.data.astype(cp.float64, copy=False), + csr_gpu.indices.astype(cp.int32, copy=False), + csr_gpu.indptr.astype(cp.int32, copy=False), + ) # ---- warn for small groups ---- - for gi in group_row_ids: + for gi in test_group_indices: n_group = int(group_sizes[gi]) if n_group <= MIN_GROUP_SIZE_WARNING or n_ref <= MIN_GROUP_SIZE_WARNING: warnings.warn( @@ -600,84 +431,154 @@ def _wilcoxon_with_reference( stacklevel=4, ) - # ---- pre-allocate outputs ---- - all_scores: dict[int, np.ndarray] = {} - all_pvals: dict[int, np.ndarray] = {} - for gi in group_row_ids: - all_scores[gi] = np.empty(n_total_genes, dtype=np.float64) - all_pvals[gi] = np.empty(n_total_genes, dtype=np.float64) + test_sizes = cp.asarray( + [group_sizes[gi] for gi in test_group_indices], dtype=cp.float64 + ) - chunk_width = _choose_chunk_size(chunk_size) + # ---- extract ref + grp blocks (one-shot, all genes) ---- + ref_block = _extract_dense_block( + X_gpu, ref_row_ids, 0, n_total_genes, csr_arrays=csr_arrays + ) + grp_block = _extract_dense_block( + X_gpu, all_grp_row_ids, 0, n_total_genes, csr_arrays=csr_arrays + ) + n_all_grp = grp_block.shape[0] - # ---- chunk loop (outer) x group loop (inner) ---- - for start in range(0, n_total_genes, chunk_width): - stop = min(start + chunk_width, n_total_genes) - n_cols = stop - start + # ---- stats via fused kernel ---- + _compute_grouped_stats( + rg, + ireference, + ref_block, + n_ref, + test_group_indices=test_group_indices, + grp_block=grp_block, + grp_offsets_gpu=grp_offsets_gpu, + n_test=n_test, + n_cols=n_total_genes, + start=0, + stop=n_total_genes, + ) - # Extract & sort reference columns ONCE per chunk - ref_block = _extract_dense_block( - X, ref_row_ids, start, stop, csr_arrays=csr_arrays - ) - ref_sorted = 
cp.asfortranarray(cp.sort(ref_block, axis=0)) - - # Accumulate reference stats once per chunk (CPU-data path) - if rg._compute_stats_in_chunks and start not in rg._ref_chunk_computed: - rg._ref_chunk_computed.add(start) - ref_mean = ref_block.mean(axis=0) - rg.means[ireference, start:stop] = cp.asnumpy(ref_mean) - if n_ref > 1: - ref_var = ref_block.var(axis=0, ddof=1) - rg.vars[ireference, start:stop] = cp.asnumpy(ref_var) - if rg.comp_pts: - ref_nnz = (ref_block != 0).sum(axis=0) - rg.pts[ireference, start:stop] = cp.asnumpy(ref_nnz / n_ref) - - for gi, grp_rows in group_row_ids.items(): - n_group = int(group_sizes[gi]) - n_combined = n_group + n_ref - - # Extract & sort group columns (small, fast) - grp_block = _extract_dense_block( - X, grp_rows, start, stop, csr_arrays=csr_arrays - ) - grp_sorted = cp.asfortranarray(cp.sort(grp_block, axis=0)) - - # Accumulate group stats (CPU-data path) - if rg._compute_stats_in_chunks: - grp_mean = grp_block.mean(axis=0) - rg.means[gi, start:stop] = cp.asnumpy(grp_mean) - if n_group > 1: - grp_var = grp_block.var(axis=0, ddof=1) - rg.vars[gi, start:stop] = cp.asnumpy(grp_var) - if rg.comp_pts: - grp_nnz = (grp_block != 0).sum(axis=0) - rg.pts[gi, start:stop] = cp.asnumpy(grp_nnz / n_group) - - # ---- rank sums via binary search (no combined sort) ---- - rank_sums = _rank_sums_searchsorted(ref_sorted, grp_sorted) - - # ---- tie correction (optional) ---- - if tie_correct: - tie_corr = _tie_correction_merged(ref_sorted, grp_sorted) - else: - tie_corr = cp.ones(n_cols, dtype=cp.float64) - - # ---- z-scores & p-values ---- - expected = n_group * (n_combined + 1) / 2.0 - variance = tie_corr * n_group * n_ref * (n_combined + 1) / 12.0 - diff = rank_sums - expected - if use_continuity: - diff = cp.sign(diff) * cp.maximum(cp.abs(diff) - 0.5, 0.0) - z = diff / cp.sqrt(variance) - cp.nan_to_num(z, copy=False) - p_values = cupyx_special.erfc(cp.abs(z) * cp.float64(cp.sqrt(0.5))) - - all_scores[gi][start:stop] = z.get() - 
all_pvals[gi][start:stop] = p_values.get() - - # ---- return in group order ---- - return [ - (gi, all_scores[gi], all_pvals[gi]) - for gi in range(n_groups) - if gi != ireference - ] + # ---- sort reference once ---- + ref_sorted = _segmented_sort_columns( + ref_block, np.array([0, n_ref], dtype=np.int32), n_ref, n_total_genes, 1 + ) + + # ---- streaming OVO: sort groups + binary search rank sums ---- + grp_f32 = cp.asfortranarray(grp_block.astype(cp.float32)) + rank_sums = cp.empty((n_test, n_total_genes), dtype=cp.float64) + tie_corr_arr = cp.empty((n_test, n_total_genes), dtype=cp.float64) + + _ws.ovo_streaming( + ref_sorted, + grp_f32, + grp_offsets_gpu, + rank_sums, + tie_corr_arr, + n_ref=n_ref, + n_all_grp=n_all_grp, + n_cols=n_total_genes, + n_groups=n_test, + compute_tie_corr=tie_correct, + ) + + # ---- z-scores & p-values (vectorised) ---- + n_combined = test_sizes + n_ref + expected = test_sizes * (n_combined + 1) / 2.0 + variance = test_sizes * n_ref * (n_combined + 1) / 12.0 + if tie_correct: + variance = variance[:, None] * tie_corr_arr + else: + variance = cp.broadcast_to(variance[:, None], (n_test, n_total_genes)).copy() + + diff = rank_sums - expected[:, None] + if use_continuity: + diff = cp.sign(diff) * cp.maximum(cp.abs(diff) - 0.5, 0.0) + z = diff / cp.sqrt(variance) + cp.nan_to_num(z, copy=False) + p_values = cupyx_special.erfc(cp.abs(z) * cp.float64(cp.sqrt(0.5))) + + all_z = z.get() + all_p = p_values.get() + + return [(gi, all_z[ti], all_p[ti]) for ti, gi in enumerate(test_group_indices)] + + +def _compute_grouped_stats( + rg: _RankGenes, + ireference: int, + ref_block: cp.ndarray, + n_ref: int, + *, + test_group_indices: list[int], + grp_block: cp.ndarray, + grp_offsets_gpu: cp.ndarray, + n_test: int, + n_cols: int, + start: int, + stop: int, +) -> None: + """Compute mean/var/pts for ref + all test groups via fused C++ kernel.""" + s = slice(start, stop) + stream = cp.cuda.get_current_stream().ptr + + # Reference stats (single "group") + 
ref_offsets = cp.array([0, n_ref], dtype=cp.int32) + ref_sums = cp.empty((1, n_cols), dtype=cp.float64) + ref_sq = cp.empty((1, n_cols), dtype=cp.float64) + ref_nnz = cp.empty((1, n_cols), dtype=cp.float64) + _wc.grouped_stats( + ref_block, + ref_offsets, + ref_sums, + ref_sq, + ref_nnz, + n_all_rows=n_ref, + n_cols=n_cols, + n_groups=1, + compute_nnz=rg.comp_pts, + stream=stream, + ) + + rg.means[ireference, s] = cp.asnumpy(ref_sums[0] / n_ref) + if n_ref > 1: + var = (ref_sq[0] - ref_sums[0] ** 2 / n_ref) / (n_ref - 1) + rg.vars[ireference, s] = cp.asnumpy(cp.maximum(var, 0)) + if rg.comp_pts: + rg.pts[ireference, s] = cp.asnumpy(ref_nnz[0] / n_ref) + + # All test groups in one kernel launch + n_all_grp = grp_block.shape[0] + grp_sums = cp.empty((n_test, n_cols), dtype=cp.float64) + grp_sq = cp.empty((n_test, n_cols), dtype=cp.float64) + grp_nnz = cp.empty((n_test, n_cols), dtype=cp.float64) + _wc.grouped_stats( + grp_block, + grp_offsets_gpu, + grp_sums, + grp_sq, + grp_nnz, + n_all_rows=n_all_grp, + n_cols=n_cols, + n_groups=n_test, + compute_nnz=rg.comp_pts, + stream=stream, + ) + + # Vectorised mean/var computation on GPU, single D2H transfer + sizes = cp.asarray( + [rg.group_sizes[gi] for gi in test_group_indices], dtype=cp.float64 + )[:, None] + means = grp_sums / sizes + vars_ = cp.maximum((grp_sq - grp_sums**2 / sizes) / cp.maximum(sizes - 1, 1), 0) + + means_host = cp.asnumpy(means) + vars_host = cp.asnumpy(vars_) + for ti, gi in enumerate(test_group_indices): + rg.means[gi, s] = means_host[ti] + rg.vars[gi, s] = vars_host[ti] + + if rg.comp_pts: + pts_host = cp.asnumpy(grp_nnz / sizes) + for ti, gi in enumerate(test_group_indices): + rg.pts[gi, s] = pts_host[ti] diff --git a/tests/test_rank_genes_groups_wilcoxon.py b/tests/test_rank_genes_groups_wilcoxon.py index 0c6844da..455441e1 100644 --- a/tests/test_rank_genes_groups_wilcoxon.py +++ b/tests/test_rank_genes_groups_wilcoxon.py @@ -1,12 +1,11 @@ from __future__ import annotations -import cupy as cp 
import numpy as np import pandas as pd import pytest import scanpy as sc import scipy.sparse as sp -from scipy.stats import mannwhitneyu, rankdata, tiecorrect +from scipy.stats import mannwhitneyu import rapids_singlecell as rsc @@ -441,188 +440,3 @@ def test_sparse_matches_dense(self, perturbation_adata, sparse): np.testing.assert_array_equal( dense_df["pvals"].values, sparse_df["pvals"].values ) - - -# ============================================================================ -# Tests for ranking and tie correction kernels (edge cases from scipy) -# ============================================================================ - - -class TestRankingKernel: - """Tests for _average_ranks based on scipy.stats.rankdata edge cases.""" - - @pytest.fixture - def average_ranks(self): - """Import the ranking function.""" - from rapids_singlecell.tools._rank_genes_groups._wilcoxon import ( - _average_ranks, - ) - - return _average_ranks - - @staticmethod - def _to_gpu(values): - """Convert 1D values to GPU column matrix with F-order.""" - arr = np.asarray(values, dtype=np.float64).reshape(-1, 1) - return cp.asarray(arr, order="F") - - def test_basic_ranking(self, average_ranks): - """Test basic average ranking on simple data.""" - values = [3.0, 1.0, 2.0] - result = average_ranks(self._to_gpu(values)) - expected = rankdata(values, method="average") - np.testing.assert_allclose(result.get().flatten(), expected) - - def test_all_ties(self, average_ranks): - """All identical values should get the average rank.""" - values = [5.0, 5.0, 5.0, 5.0] - result = average_ranks(self._to_gpu(values)) - expected = rankdata(values, method="average") - np.testing.assert_allclose(result.get().flatten(), expected) - - def test_no_ties(self, average_ranks): - """All unique values should get sequential ranks.""" - values = [1.0, 2.0, 3.0, 4.0, 5.0] - result = average_ranks(self._to_gpu(values)) - expected = rankdata(values, method="average") - 
np.testing.assert_allclose(result.get().flatten(), expected) - - def test_mixed_ties(self, average_ranks): - """Mix of ties and unique values.""" - values = [1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0] - result = average_ranks(self._to_gpu(values)) - expected = rankdata(values, method="average") - np.testing.assert_allclose(result.get().flatten(), expected) - - def test_negative_values(self, average_ranks): - """Test with negative values.""" - values = [-3.0, -1.0, -2.0, 0.0, 1.0] - result = average_ranks(self._to_gpu(values)) - expected = rankdata(values, method="average") - np.testing.assert_allclose(result.get().flatten(), expected) - - def test_single_element(self, average_ranks): - """Single element should have rank 1.""" - values = [42.0] - result = average_ranks(self._to_gpu(values)) - np.testing.assert_allclose(result.get().flatten(), [1.0]) - - def test_two_elements_tied(self, average_ranks): - """Two tied elements should both have rank 1.5.""" - values = [7.0, 7.0] - result = average_ranks(self._to_gpu(values)) - np.testing.assert_allclose(result.get().flatten(), [1.5, 1.5]) - - def test_multiple_columns(self, average_ranks): - """Test ranking across multiple columns independently.""" - col0 = [3.0, 1.0, 2.0] - col1 = [1.0, 1.0, 2.0] - data = np.column_stack([col0, col1]).astype(np.float64) - result = average_ranks(cp.asarray(data, order="F")) - - np.testing.assert_allclose(result.get()[:, 0], rankdata(col0, method="average")) - np.testing.assert_allclose(result.get()[:, 1], rankdata(col1, method="average")) - - -class TestTieCorrectionKernel: - """Tests for _tie_correction based on scipy.stats.tiecorrect edge cases.""" - - @pytest.fixture - def tie_correction(self): - """Import the tie correction function and ranking function.""" - from rapids_singlecell.tools._rank_genes_groups._wilcoxon import ( - _average_ranks, - _tie_correction, - ) - - return _tie_correction, _average_ranks - - @staticmethod - def _to_gpu(values): - """Convert 1D values to GPU column matrix 
with F-order.""" - arr = np.asarray(values, dtype=np.float64).reshape(-1, 1) - return cp.asarray(arr, order="F") - - def test_no_ties(self, tie_correction): - """No ties should give correction factor 1.0.""" - _tie_correction, _average_ranks = tie_correction - - values = [1.0, 2.0, 3.0, 4.0, 5.0] - _, sorted_vals = _average_ranks(self._to_gpu(values), return_sorted=True) - result = _tie_correction(sorted_vals) - - expected = tiecorrect(rankdata(values)) - np.testing.assert_allclose(result.get()[0], expected, rtol=1e-10) - - def test_all_ties(self, tie_correction): - """All tied values should give correction factor 0.0.""" - _tie_correction, _average_ranks = tie_correction - - values = [5.0, 5.0, 5.0, 5.0] - _, sorted_vals = _average_ranks(self._to_gpu(values), return_sorted=True) - result = _tie_correction(sorted_vals) - - expected = tiecorrect(rankdata(values)) - np.testing.assert_allclose(result.get()[0], expected, rtol=1e-10) - - def test_mixed_ties(self, tie_correction): - """Mix of ties should give intermediate correction factor.""" - _tie_correction, _average_ranks = tie_correction - - values = [1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0] - _, sorted_vals = _average_ranks(self._to_gpu(values), return_sorted=True) - result = _tie_correction(sorted_vals) - - expected = tiecorrect(rankdata(values)) - np.testing.assert_allclose(result.get()[0], expected, rtol=1e-10) - - def test_two_elements_tied(self, tie_correction): - """Two tied elements.""" - _tie_correction, _average_ranks = tie_correction - - values = [7.0, 7.0] - _, sorted_vals = _average_ranks(self._to_gpu(values), return_sorted=True) - result = _tie_correction(sorted_vals) - - expected = tiecorrect(rankdata(values)) - np.testing.assert_allclose(result.get()[0], expected, rtol=1e-10) - - def test_single_element(self, tie_correction): - """Single element should give correction factor 1.0.""" - _tie_correction, _average_ranks = tie_correction - - values = [42.0] - _, sorted_vals = 
_average_ranks(self._to_gpu(values), return_sorted=True) - result = _tie_correction(sorted_vals) - - # Single element: n^3 - n = 0, so formula gives 1.0 - np.testing.assert_allclose(result.get()[0], 1.0, rtol=1e-10) - - def test_multiple_columns(self, tie_correction): - """Test tie correction across multiple columns independently.""" - _tie_correction, _average_ranks = tie_correction - - col0 = [1.0, 2.0, 3.0] # No ties - col1 = [5.0, 5.0, 5.0] # All ties - data = np.column_stack([col0, col1]).astype(np.float64) - _, sorted_vals = _average_ranks(cp.asarray(data, order="F"), return_sorted=True) - result = _tie_correction(sorted_vals) - - np.testing.assert_allclose( - result.get()[0], tiecorrect(rankdata(col0)), rtol=1e-10 - ) - np.testing.assert_allclose( - result.get()[1], tiecorrect(rankdata(col1)), rtol=1e-10 - ) - - def test_large_tie_groups(self, tie_correction): - """Test with large tie groups.""" - _tie_correction, _average_ranks = tie_correction - - # 50 values of 1, 50 values of 2 (non-multiple of 32 to test warp handling) - values = [1.0] * 50 + [2.0] * 50 - _, sorted_vals = _average_ranks(self._to_gpu(values), return_sorted=True) - result = _tie_correction(sorted_vals) - - expected = tiecorrect(rankdata(values)) - np.testing.assert_allclose(result.get()[0], expected, rtol=1e-10) From 9fa9c981b8b03b6e5edc64445106802dd717f9e7 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Mon, 13 Apr 2026 19:28:15 +0200 Subject: [PATCH 03/21] try add rmm --- .github/workflows/publish.yml | 5 +- CMakeLists.txt | 41 +- pyproject.toml | 18 +- src/rapids_singlecell/_cuda/__init__.py | 15 +- .../_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh | 401 +++-- .../_cuda/wilcoxon/wilcoxon.cu | 92 - .../_cuda/wilcoxon/wilcoxon_common.cuh | 87 + .../_cuda/wilcoxon/wilcoxon_ovo.cu | 1484 +++++++++++++++++ .../wilcoxon_ovr.cu} | 747 ++------- .../tools/_rank_genes_groups/_wilcoxon.py | 378 +++-- 10 files changed, 2263 insertions(+), 1005 deletions(-) delete mode 100644 
src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu create mode 100644 src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_common.cuh create mode 100644 src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo.cu rename src/rapids_singlecell/_cuda/{wilcoxon_streaming/wilcoxon_streaming.cu => wilcoxon/wilcoxon_ovr.cu} (50%) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 3f2e4447..d7ed8fd9 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -112,14 +112,15 @@ jobs: CIBW_ENVIRONMENT_PASS_LINUX: SETUPTOOLS_SCM_PRETEND_VERSION CIBW_ENVIRONMENT: > CUDA_PATH=/usr/local/cuda - LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH + LD_LIBRARY_PATH=/usr/local/cuda/lib64:$(python3 -c "import sysconfig,os;sp=sysconfig.get_path('purelib');print(os.path.join(sp,'librmm','lib64')+':'+os.path.join(sp,'rapids_logger','lib64'))" 2>/dev/null || echo ""):$LD_LIBRARY_PATH PATH=/usr/local/cuda/bin:$PATH CIBW_BEFORE_BUILD: > python -m pip install -U pip scikit-build-core cmake ninja nanobind + librmm-cu${{ matrix.cuda_major }} CIBW_TEST_SKIP: "*" CIBW_TEST_COMMAND: "" - CIBW_REPAIR_WHEEL_COMMAND: "auditwheel repair --exclude libcublas.so.${{ matrix.cuda_major }} --exclude libcublasLt.so.${{ matrix.cuda_major }} --exclude libcudart.so.${{ matrix.cuda_major }} -w {dest_dir} {wheel}" + CIBW_REPAIR_WHEEL_COMMAND: "auditwheel repair --exclude libcublas.so.${{ matrix.cuda_major }} --exclude libcublasLt.so.${{ matrix.cuda_major }} --exclude libcudart.so.${{ matrix.cuda_major }} --exclude librmm.so --exclude librapids_logger.so -w {dest_dir} {wheel}" CIBW_BUILD_VERBOSITY: "1" - uses: actions/upload-artifact@v4 diff --git a/CMakeLists.txt b/CMakeLists.txt index 3c0d2e8a..3bf8fe14 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -14,6 +14,39 @@ if (RSC_BUILD_EXTENSIONS) find_package(Python REQUIRED COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT}) find_package(nanobind CONFIG REQUIRED) find_package(CUDAToolkit REQUIRED) + + # 
Find librmm cmake config. + # Works with conda, pixi, uv, venv — uses env root to find site-packages. + # Priority: LIBRMM_DIR env var > CONDA_PREFIX > VIRTUAL_ENV > PIXI_PROJECT_ROOT > Python prefix. + set(_env_roots "") + if(DEFINED ENV{LIBRMM_DIR}) + list(APPEND _env_roots "$ENV{LIBRMM_DIR}/..") + endif() + foreach(_var CONDA_PREFIX VIRTUAL_ENV PIXI_PROJECT_ROOT) + if(DEFINED ENV{${_var}}) + list(APPEND _env_roots "$ENV{${_var}}") + endif() + endforeach() + # Fallback: Python prefix (works for any env manager) + execute_process( + COMMAND "${Python_EXECUTABLE}" -c "import sys; print(sys.prefix)" + OUTPUT_VARIABLE _py_prefix OUTPUT_STRIP_TRAILING_WHITESPACE ERROR_QUIET) + if(_py_prefix) + list(APPEND _env_roots "${_py_prefix}") + endif() + foreach(_root ${_env_roots}) + file(GLOB _hints "${_root}/lib/cmake/rmm" + "${_root}/lib/python*/site-packages/librmm/lib*/cmake/rmm" + "${_root}/lib/python*/site-packages/rapids_logger/lib*/cmake/rapids_logger") + foreach(_h ${_hints}) + get_filename_component(_dir "${_h}" DIRECTORY) + list(APPEND CMAKE_PREFIX_PATH "${_dir}") + endforeach() + endforeach() + find_package(rmm CONFIG) + if(NOT rmm_FOUND) + message(WARNING "librmm not found — wilcoxon will use cudaMalloc fallback") + endif() message(STATUS "Building for CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}") else() message(STATUS "RSC_BUILD_EXTENSIONS=OFF -> skipping compiled extensions for docs") @@ -84,7 +117,8 @@ if (RSC_BUILD_EXTENSIONS) add_nb_cuda_module(_edistance_cuda src/rapids_singlecell/_cuda/edistance/edistance.cu) add_nb_cuda_module(_hvg_cuda src/rapids_singlecell/_cuda/hvg/hvg.cu) add_nb_cuda_module(_kde_cuda src/rapids_singlecell/_cuda/kde/kde.cu) - add_nb_cuda_module(_wilcoxon_cuda src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu) + add_nb_cuda_module(_wilcoxon_ovr_cuda src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr.cu) + add_nb_cuda_module(_wilcoxon_ovo_cuda src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo.cu) # Harmony CUDA modules
add_nb_cuda_module(_harmony_scatter_cuda src/rapids_singlecell/_cuda/harmony/scatter/scatter.cu) add_nb_cuda_module(_harmony_outer_cuda src/rapids_singlecell/_cuda/harmony/outer/outer.cu) @@ -100,5 +134,8 @@ if (RSC_BUILD_EXTENSIONS) target_link_libraries(_harmony_correction_batched_cuda PRIVATE CUDA::cublas) # Wilcoxon binned histogram CUDA module add_nb_cuda_module(_wilcoxon_binned_cuda src/rapids_singlecell/_cuda/wilcoxon_binned/wilcoxon_binned.cu) - add_nb_cuda_module(_wilcoxon_streaming_cuda src/rapids_singlecell/_cuda/wilcoxon_streaming/wilcoxon_streaming.cu) + if(rmm_FOUND) + target_link_libraries(_wilcoxon_ovr_cuda PRIVATE rmm::rmm) + target_link_libraries(_wilcoxon_ovo_cuda PRIVATE rmm::rmm) + endif() endif() diff --git a/pyproject.toml b/pyproject.toml index c38e1d00..7449bacf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,8 +32,22 @@ dependencies = [ ] [project.optional-dependencies] -rapids-cu13 = [ "cupy-cuda13x", "cudf-cu13>=25.10", "cuml-cu13>=25.10", "cugraph-cu13>=25.10", "cuvs-cu13>=25.10" ] -rapids-cu12 = [ "cupy-cuda12x", "cudf-cu12>=25.10", "cuml-cu12>=25.10", "cugraph-cu12>=25.10", "cuvs-cu12>=25.10" ] +rapids-cu13 = [ + "cupy-cuda13x", + "librmm-cu13>=25.10", + "cudf-cu13>=25.10", + "cuml-cu13>=25.10", + "cugraph-cu13>=25.10", + "cuvs-cu13>=25.10", +] +rapids-cu12 = [ + "cupy-cuda12x", + "librmm-cu12>=25.10", + "cudf-cu12>=25.10", + "cuml-cu12>=25.10", + "cugraph-cu12>=25.10", + "cuvs-cu12>=25.10", +] doc = [ "sphinx>=4.5.0", diff --git a/src/rapids_singlecell/_cuda/__init__.py b/src/rapids_singlecell/_cuda/__init__.py index 35e82a0d..b11f342a 100644 --- a/src/rapids_singlecell/_cuda/__init__.py +++ b/src/rapids_singlecell/_cuda/__init__.py @@ -13,6 +13,18 @@ import importlib +# Pre-load librmm.so + deps so the dynamic linker can resolve them when +# our nanobind extensions (which link rmm) are imported. This is the same +# pattern used by cuml, cuvs, and other RAPIDS packages. 
+try: + import librmm + + librmm.load_library() +except (ImportError, OSError): + pass + +_RMM_MODULES = {"_wilcoxon_ovo_cuda", "_wilcoxon_ovr_cuda"} + __all__ = [ "_aggr_cuda", "_aucell_cuda", @@ -44,7 +56,8 @@ "_sparse2dense_cuda", "_spca_cuda", "_wilcoxon_binned_cuda", - "_wilcoxon_cuda", + "_wilcoxon_ovo_cuda", + "_wilcoxon_ovr_cuda", ] diff --git a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh index d1583500..fac8816a 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh @@ -3,41 +3,154 @@ #include // ============================================================================ -// CSR → dense F-order extraction +// Warp reduction helper (sum doubles across block via warp_buf) // ============================================================================ -__global__ void csr_extract_dense_kernel(const double* __restrict__ data, - const int* __restrict__ indices, - const int* __restrict__ indptr, - const int* __restrict__ row_ids, - double* __restrict__ out, int n_target, - int col_start, int col_stop) { - int tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= n_target) return; - - int row = row_ids[tid]; - int rs = indptr[row]; - int re = indptr[row + 1]; - - int lo = rs, hi = re; - while (lo < hi) { - int m = (lo + hi) >> 1; - if (indices[m] < col_start) - lo = m + 1; - else - hi = m; +__device__ __forceinline__ double block_reduce_sum(double val, + double* warp_buf) { +#pragma unroll + for (int off = 16; off > 0; off >>= 1) + val += __shfl_down_sync(0xffffffff, val, off); + int lane = threadIdx.x & 31; + int wid = threadIdx.x >> 5; + if (lane == 0) warp_buf[wid] = val; + __syncthreads(); + if (threadIdx.x < 32) { + double v2 = (threadIdx.x < ((blockDim.x + 31) >> 5)) + ? 
warp_buf[threadIdx.x] + : 0.0; +#pragma unroll + for (int off = 16; off > 0; off >>= 1) + v2 += __shfl_down_sync(0xffffffff, v2, off); + return v2; // only lane 0 of warp 0 has the final result + } + return 0.0; +} + +// ============================================================================ +// Parallel tie correction — all threads collaborate. +// +// For each unique value in the combined sorted (ref, grp) arrays, accumulate +// t^3 - t where t = count of that value. Uses two passes: +// 1. Iterate unique values in ref_col, count in both arrays. +// 2. Iterate unique values in grp_col that do NOT appear in ref_col. +// +// Incremental binary search bounds exploit monotonicity within each thread's +// stride to reduce total search work. +// +// Caller must __syncthreads() before calling. warp_buf is reused for +// reduction (32 doubles, shared memory). +// ============================================================================ + +__device__ __forceinline__ void compute_tie_correction_parallel( + const float* ref_col, int n_ref, const float* grp_col, int n_grp, + double* warp_buf, double* out) { + double local_tie = 0.0; + + // Pass 1: unique values in ref_col + int grp_lb = 0, grp_ub = 0; + for (int i = threadIdx.x; i < n_ref; i += blockDim.x) { + if (i == 0 || ref_col[i] != ref_col[i - 1]) { + float v = ref_col[i]; + + // Count in ref: upper_bound from i+1 + int lo = i + 1, hi = n_ref; + while (lo < hi) { + int m = (lo + hi) >> 1; + if (ref_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int cnt_ref = lo - i; + + // Count in grp: incremental lower/upper bound + lo = grp_lb; + hi = n_grp; + while (lo < hi) { + int m = (lo + hi) >> 1; + if (grp_col[m] < v) + lo = m + 1; + else + hi = m; + } + int lb = lo; + grp_lb = lb; + + lo = (grp_ub > lb) ? 
grp_ub : lb; + hi = n_grp; + while (lo < hi) { + int m = (lo + hi) >> 1; + if (grp_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int cnt_grp = lo - lb; + grp_ub = lo; + + int cnt = cnt_ref + cnt_grp; + if (cnt > 1) { + double t = (double)cnt; + local_tie += t * t * t - t; + } + } + } + + // Pass 2: unique values in grp_col that are absent from ref_col + int ref_lb = 0; + for (int i = threadIdx.x; i < n_grp; i += blockDim.x) { + if (i == 0 || grp_col[i] != grp_col[i - 1]) { + float v = grp_col[i]; + + // Incremental lower_bound in ref + int lo = ref_lb, hi = n_ref; + while (lo < hi) { + int m = (lo + hi) >> 1; + if (ref_col[m] < v) + lo = m + 1; + else + hi = m; + } + ref_lb = lo; + + if (lo >= n_ref || ref_col[lo] != v) { + // Value not in ref — count in grp only (upper_bound from i+1) + lo = i + 1; + hi = n_grp; + while (lo < hi) { + int m = (lo + hi) >> 1; + if (grp_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int cnt = lo - i; + if (cnt > 1) { + double t = (double)cnt; + local_tie += t * t * t - t; + } + } + } } - for (int p = lo; p < re; ++p) { - int c = indices[p]; - if (c >= col_stop) break; - out[(long long)(c - col_start) * n_target + tid] = data[p]; + // Block-wide reduction + double tie_sum = block_reduce_sum(local_tie, warp_buf); + if (threadIdx.x == 0) { + int n = n_ref + n_grp; + double dn = (double)n; + double denom = dn * dn * dn - dn; + *out = (denom > 0.0) ? (1.0 - tie_sum / denom) : 1.0; } } // ============================================================================ // Batched rank sums — pre-sorted (binary search, no shared memory sort) // Used by the OVO streaming pipeline in wilcoxon_streaming.cu. +// +// Incremental binary search: each thread carries forward lower/upper bound +// positions across loop iterations, exploiting the monotonicity of the +// sorted grp_col values within each thread's stride. 
// ============================================================================ __global__ void batched_rank_sums_presorted_kernel( @@ -64,12 +177,17 @@ __global__ void batched_rank_sums_presorted_kernel( const float* ref_col = ref_sorted + (long long)col * n_ref; const float* grp_col = grp_sorted + (long long)col * n_all_grp + g_start; + // Incremental binary search bounds (advance monotonically per thread) + int ref_lb = 0, ref_ub = 0; + int grp_lb = 0, grp_ub = 0; double local_sum = 0.0; + for (int i = threadIdx.x; i < n_grp; i += blockDim.x) { - double v = grp_col[i]; + float v = grp_col[i]; int lo, hi; - lo = 0; + // Lower bound in ref (from ref_lb) + lo = ref_lb; hi = n_ref; while (lo < hi) { int m = (lo + hi) >> 1; @@ -79,7 +197,10 @@ __global__ void batched_rank_sums_presorted_kernel( hi = m; } int n_lt_ref = lo; - lo = n_lt_ref; + ref_lb = n_lt_ref; + + // Upper bound in ref (from max(ref_ub, n_lt_ref)) + lo = (ref_ub > n_lt_ref) ? ref_ub : n_lt_ref; hi = n_ref; while (lo < hi) { int m = (lo + hi) >> 1; @@ -89,7 +210,10 @@ __global__ void batched_rank_sums_presorted_kernel( hi = m; } int n_eq_ref = lo - n_lt_ref; - lo = 0; + ref_ub = lo; + + // Lower bound in grp (from grp_lb) + lo = grp_lb; hi = n_grp; while (lo < hi) { int m = (lo + hi) >> 1; @@ -99,7 +223,10 @@ __global__ void batched_rank_sums_presorted_kernel( hi = m; } int n_lt_grp = lo; - lo = n_lt_grp; + grp_lb = n_lt_grp; + + // Upper bound in grp (from max(grp_ub, n_lt_grp)) + lo = (grp_ub > n_lt_grp) ? 
grp_ub : n_lt_grp; hi = n_grp; while (lo < hi) { int m = (lo + hi) >> 1; @@ -109,127 +236,157 @@ __global__ void batched_rank_sums_presorted_kernel( hi = m; } int n_eq_grp = lo - n_lt_grp; + grp_ub = lo; local_sum += (double)(n_lt_ref + n_lt_grp) + ((double)(n_eq_ref + n_eq_grp) + 1.0) / 2.0; } __shared__ double warp_buf[32]; -#pragma unroll - for (int off = 16; off > 0; off >>= 1) - local_sum += __shfl_down_sync(0xffffffff, local_sum, off); - int lane = threadIdx.x & 31; - int wid = threadIdx.x >> 5; - if (lane == 0) warp_buf[wid] = local_sum; - __syncthreads(); - if (threadIdx.x < 32) { - double val = (threadIdx.x < ((blockDim.x + 31) >> 5)) - ? warp_buf[threadIdx.x] - : 0.0; -#pragma unroll - for (int off = 16; off > 0; off >>= 1) - val += __shfl_down_sync(0xffffffff, val, off); - if (threadIdx.x == 0) rank_sums[grp * n_cols + col] = val; - } + double total = block_reduce_sum(local_sum, warp_buf); + if (threadIdx.x == 0) rank_sums[grp * n_cols + col] = total; if (!compute_tie_corr) return; __syncthreads(); - if (threadIdx.x == 0) { - int ri = 0, gi = 0; - double tie_sum = 0.0; - while (ri < n_ref || gi < n_grp) { - double v; - if (gi >= n_grp) - v = ref_col[ri]; - else if (ri >= n_ref) - v = grp_col[gi]; - else - v = (ref_col[ri] <= grp_col[gi]) ? ref_col[ri] : grp_col[gi]; - int cnt = 0; - while (ri < n_ref && ref_col[ri] == v) { - ++ri; - ++cnt; - } - while (gi < n_grp && grp_col[gi] == v) { - ++gi; - ++cnt; - } - if (cnt > 1) { - double t = (double)cnt; - tie_sum += t * t * t - t; - } - } - int n = n_ref + n_grp; - double dn = (double)n; - double denom = dn * dn * dn - dn; - tie_corr[grp * n_cols + col] = - (denom > 0.0) ? 
(1.0 - tie_sum / denom) : 1.0; - } + compute_tie_correction_parallel(ref_col, n_ref, grp_col, n_grp, warp_buf, + &tie_corr[grp * n_cols + col]); } // ============================================================================ -// Grouped statistics: sum, sum-of-squares, nnz per group +// Tier 1 fused kernel: smem bitonic sort + binary search rank sums +// For small groups (< ~2K cells). No CUB, no global memory sort buffers. +// Grid: (n_cols, n_groups), Block: min(padded_grp_size, 512) +// Shared memory: padded_grp_size floats + 32 doubles (warp reduction) // ============================================================================ -__global__ void grouped_stats_kernel( - const double* __restrict__ data, // F-order (n_all_rows, n_cols) - const int* __restrict__ grp_offsets, // (n_groups + 1,) - double* __restrict__ sums, // (n_groups, n_cols) row-major - double* __restrict__ sq_sums, // (n_groups, n_cols) row-major - double* __restrict__ nnz_counts, // (n_groups, n_cols) row-major - int n_all_rows, int n_cols, int n_groups, bool compute_nnz) { +__global__ void ovo_fused_sort_rank_kernel( + const float* __restrict__ ref_sorted, // F-order (n_ref, n_cols) sorted + const float* __restrict__ grp_dense, // F-order (n_all_grp, n_cols) + // unsorted + const int* __restrict__ grp_offsets, // (n_groups + 1,) + double* __restrict__ rank_sums, // (n_groups, n_cols) row-major + double* __restrict__ tie_corr, // (n_groups, n_cols) row-major + int n_ref, int n_all_grp, int n_cols, int n_groups, bool compute_tie_corr, + int padded_grp_size) { int col = blockIdx.x; - if (col >= n_cols) return; + int grp = blockIdx.y; + if (col >= n_cols || grp >= n_groups) return; - extern __shared__ double smem[]; - double* s_sum = smem; - double* s_sq = smem + n_groups; - double* s_nnz = smem + 2 * n_groups; + int g_start = grp_offsets[grp]; + int g_end = grp_offsets[grp + 1]; + int n_grp = g_end - g_start; - for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { - s_sum[g] = 0.0; - 
s_sq[g] = 0.0; - s_nnz[g] = 0.0; + if (n_grp == 0) { + if (threadIdx.x == 0) { + rank_sums[grp * n_cols + col] = 0.0; + if (compute_tie_corr) tie_corr[grp * n_cols + col] = 1.0; + } + return; } + + // Shared memory: [padded_grp_size floats | 32 doubles for warp reduction] + extern __shared__ char smem_raw[]; + float* grp_smem = (float*)smem_raw; + double* warp_buf = (double*)(smem_raw + padded_grp_size * sizeof(float)); + + // Load group data into shared memory, pad with +INF + const float* grp_col = grp_dense + (long long)col * n_all_grp + g_start; + for (int i = threadIdx.x; i < n_grp; i += blockDim.x) + grp_smem[i] = grp_col[i]; + for (int i = n_grp + threadIdx.x; i < padded_grp_size; i += blockDim.x) + grp_smem[i] = __int_as_float(0x7f800000); // +INF __syncthreads(); - const double* col_data = data + (long long)col * n_all_rows; + // Bitonic sort in shared memory + for (int k = 2; k <= padded_grp_size; k <<= 1) { + for (int j = k >> 1; j > 0; j >>= 1) { + for (int i = threadIdx.x; i < padded_grp_size; i += blockDim.x) { + int ixj = i ^ j; + if (ixj > i) { + bool asc = ((i & k) == 0); + float a = grp_smem[i], b = grp_smem[ixj]; + if (asc ? 
(a > b) : (a < b)) { + grp_smem[i] = b; + grp_smem[ixj] = a; + } + } + } + __syncthreads(); + } + } - for (int g = 0; g < n_groups; g++) { - int g_start = grp_offsets[g]; - int g_end = grp_offsets[g + 1]; + // Binary search each sorted grp element against sorted ref + // Incremental bounds: values are monotonic within each thread's stride + const float* ref_col = ref_sorted + (long long)col * n_ref; + int ref_lb = 0, ref_ub = 0; + int grp_lb = 0, grp_ub = 0; + double local_sum = 0.0; - double local_sum = 0.0; - double local_sq = 0.0; - double local_nnz = 0.0; + for (int i = threadIdx.x; i < n_grp; i += blockDim.x) { + float v = grp_smem[i]; + int lo, hi; - for (int i = g_start + threadIdx.x; i < g_end; i += blockDim.x) { - double v = col_data[i]; - local_sum += v; - local_sq += v * v; - if (compute_nnz && v != 0.0) local_nnz += 1.0; + lo = ref_lb; + hi = n_ref; + while (lo < hi) { + int m = (lo + hi) >> 1; + if (ref_col[m] < v) + lo = m + 1; + else + hi = m; } + int n_lt_ref = lo; + ref_lb = n_lt_ref; -#pragma unroll - for (int off = 16; off > 0; off >>= 1) { - local_sum += __shfl_down_sync(0xffffffff, local_sum, off); - local_sq += __shfl_down_sync(0xffffffff, local_sq, off); - if (compute_nnz) - local_nnz += __shfl_down_sync(0xffffffff, local_nnz, off); + lo = (ref_ub > n_lt_ref) ? 
ref_ub : n_lt_ref; + hi = n_ref; + while (lo < hi) { + int m = (lo + hi) >> 1; + if (ref_col[m] <= v) + lo = m + 1; + else + hi = m; } + int n_eq_ref = lo - n_lt_ref; + ref_ub = lo; - if ((threadIdx.x & 31) == 0) { - atomicAdd(&s_sum[g], local_sum); - atomicAdd(&s_sq[g], local_sq); - if (compute_nnz) atomicAdd(&s_nnz[g], local_nnz); + lo = grp_lb; + hi = n_grp; + while (lo < hi) { + int m = (lo + hi) >> 1; + if (grp_smem[m] < v) + lo = m + 1; + else + hi = m; } - __syncthreads(); - } + int n_lt_grp = lo; + grp_lb = n_lt_grp; - for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { - sums[(long long)g * n_cols + col] = s_sum[g]; - sq_sums[(long long)g * n_cols + col] = s_sq[g]; - if (compute_nnz) nnz_counts[(long long)g * n_cols + col] = s_nnz[g]; + lo = (grp_ub > n_lt_grp) ? grp_ub : n_lt_grp; + hi = n_grp; + while (lo < hi) { + int m = (lo + hi) >> 1; + if (grp_smem[m] <= v) + lo = m + 1; + else + hi = m; + } + int n_eq_grp = lo - n_lt_grp; + grp_ub = lo; + + local_sum += (double)(n_lt_ref + n_lt_grp) + + ((double)(n_eq_ref + n_eq_grp) + 1.0) / 2.0; } + + // Block reduction → write rank_sums + double total = block_reduce_sum(local_sum, warp_buf); + if (threadIdx.x == 0) rank_sums[grp * n_cols + col] = total; + + if (!compute_tie_corr) return; + __syncthreads(); + + // Parallel tie correction (grp_smem is sorted shared memory) + compute_tie_correction_parallel(ref_col, n_ref, grp_smem, n_grp, warp_buf, + &tie_corr[grp * n_cols + col]); } diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu deleted file mode 100644 index e511c895..00000000 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu +++ /dev/null @@ -1,92 +0,0 @@ -#include - -#include - -#include "../nb_types.h" -#include "kernels_wilcoxon_ovo.cuh" - -using namespace nb::literals; - -constexpr int WARP_SIZE = 32; -constexpr int MAX_THREADS_PER_BLOCK = 512; - -static inline int round_up_to_warp(int n) { - int rounded = ((n + WARP_SIZE - 1) / 
WARP_SIZE) * WARP_SIZE; - return (rounded < MAX_THREADS_PER_BLOCK) ? rounded : MAX_THREADS_PER_BLOCK; -} - -static size_t get_seg_sort_temp_bytes(int n_items, int n_segments) { - size_t bytes = 0; - auto* dk = reinterpret_cast(1); - auto* doff = reinterpret_cast(1); - cub::DeviceSegmentedRadixSort::SortKeys(nullptr, bytes, dk, dk, n_items, - n_segments, doff, doff + 1, 0, 32); - return bytes; -} - -template -void register_bindings(nb::module_& m) { - m.doc() = "CUDA kernels for Wilcoxon rank-sum test"; - - m.def("get_seg_sort_temp_bytes", &get_seg_sort_temp_bytes, "n_items"_a, - "n_segments"_a); - - m.def( - "segmented_sort", - [](gpu_array_c keys_in, - gpu_array_c keys_out, - gpu_array_c offsets, - gpu_array_c cub_temp, int n_items, int n_segments, - std::uintptr_t stream) { - size_t temp_bytes = cub_temp.size(); - cub::DeviceSegmentedRadixSort::SortKeys( - cub_temp.data(), temp_bytes, keys_in.data(), keys_out.data(), - n_items, n_segments, offsets.data(), offsets.data() + 1, 0, 32, - (cudaStream_t)stream); - CUDA_CHECK_LAST_ERROR(DeviceSegmentedRadixSort); - }, - "keys_in"_a, "keys_out"_a, "offsets"_a, "cub_temp"_a, nb::kw_only(), - "n_items"_a, "n_segments"_a, "stream"_a = 0); - - m.def( - "csr_extract_dense", - [](gpu_array_c data, - gpu_array_c indices, - gpu_array_c indptr, - gpu_array_c row_ids, - gpu_array_f out, int n_target, int col_start, - int col_stop, std::uintptr_t stream) { - int tpb = round_up_to_warp(n_target); - int blocks = (n_target + tpb - 1) / tpb; - csr_extract_dense_kernel<<>>( - data.data(), indices.data(), indptr.data(), row_ids.data(), - out.data(), n_target, col_start, col_stop); - CUDA_CHECK_LAST_ERROR(csr_extract_dense_kernel); - }, - "data"_a, "indices"_a, "indptr"_a, "row_ids"_a, "out"_a, nb::kw_only(), - "n_target"_a, "col_start"_a, "col_stop"_a, "stream"_a = 0); - - m.def( - "grouped_stats", - [](gpu_array_f data, - gpu_array_c grp_offsets, - gpu_array_c sums, - gpu_array_c sq_sums, - gpu_array_c nnz_counts, int n_all_rows, int 
n_cols, - int n_groups, bool compute_nnz, std::uintptr_t stream) { - constexpr int THREADS = 256; - int smem = 3 * n_groups * sizeof(double); - grouped_stats_kernel<<>>( - data.data(), grp_offsets.data(), sums.data(), sq_sums.data(), - nnz_counts.data(), n_all_rows, n_cols, n_groups, compute_nnz); - CUDA_CHECK_LAST_ERROR(grouped_stats_kernel); - }, - "data"_a, "grp_offsets"_a, "sums"_a, "sq_sums"_a, "nnz_counts"_a, - nb::kw_only(), "n_all_rows"_a, "n_cols"_a, "n_groups"_a, - "compute_nnz"_a, "stream"_a = 0); -} - -NB_MODULE(_wilcoxon_cuda, m) { - REGISTER_GPU_BINDINGS(register_bindings, m); -} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_common.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_common.cuh new file mode 100644 index 00000000..98d26971 --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_common.cuh @@ -0,0 +1,87 @@ +#pragma once + +#include +#include + +#include +#include +#if __has_include() +#include // rmm >= 26.02 +#else +#include // rmm 25.x +#endif + +constexpr int WARP_SIZE = 32; +constexpr int MAX_THREADS_PER_BLOCK = 512; +constexpr int N_STREAMS = 4; +constexpr int SUB_BATCH_COLS = 32; +constexpr int BEGIN_BIT = 0; +constexpr int END_BIT = 32; + +// --------------------------------------------------------------------------- +// RMM pool helper — allocate GPU buffers through the current RMM memory +// resource. Buffers are stored in a vector and freed (RAII) when the vector +// is destroyed. 
+// --------------------------------------------------------------------------- +struct RmmPool { + std::vector bufs; + rmm::device_async_resource_ref mr; + + RmmPool() : mr(rmm::mr::get_current_device_resource()) { + } + + template + T* alloc(size_t count) { + bufs.emplace_back(count * sizeof(T), rmm::cuda_stream_default, mr); + return static_cast(bufs.back().data()); + } +}; + +static inline int round_up_to_warp(int n) { + int rounded = ((n + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + return (rounded < MAX_THREADS_PER_BLOCK) ? rounded : MAX_THREADS_PER_BLOCK; +} + +/** Upload linear segment offsets [0, stride, 2*stride, ...] to device. + * Uses synchronous copy — the buffer is small (a few hundred bytes). */ +static inline void upload_linear_offsets(int* d_offsets, int n_segments, + int stride, cudaStream_t stream) { + std::vector h(n_segments + 1); + for (int i = 0; i <= n_segments; i++) h[i] = i * stride; + cudaMemcpy(d_offsets, h.data(), (n_segments + 1) * sizeof(int), + cudaMemcpyHostToDevice); +} + +// ============================================================================ +// CSR → dense F-order extraction (templated on data type) +// ============================================================================ + +template +__global__ void csr_extract_dense_kernel(const T* __restrict__ data, + const int* __restrict__ indices, + const int* __restrict__ indptr, + const int* __restrict__ row_ids, + T* __restrict__ out, int n_target, + int col_start, int col_stop) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= n_target) return; + + int row = row_ids[tid]; + int rs = indptr[row]; + int re = indptr[row + 1]; + + int lo = rs, hi = re; + while (lo < hi) { + int m = (lo + hi) >> 1; + if (indices[m] < col_start) + lo = m + 1; + else + hi = m; + } + + for (int p = lo; p < re; ++p) { + int c = indices[p]; + if (c >= col_stop) break; + out[(long long)(c - col_start) * n_target + tid] = data[p]; + } +} diff --git 
a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo.cu new file mode 100644 index 00000000..1f64ed53 --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo.cu @@ -0,0 +1,1484 @@ +#include +#include + +#include + +#include "../nb_types.h" +#include "wilcoxon_common.cuh" +#include "kernels_wilcoxon_ovo.cuh" + +using namespace nb::literals; + +/** + * Build segment offsets for CUB segmented sort of group data within a + * sub-batch. offset[c * n_groups + g] = c * n_all_grp + grp_offsets[g]. + * One thread per entry. + */ +__global__ void build_seg_offsets_kernel( + const int* __restrict__ grp_offsets, // (n_groups + 1,) + int* __restrict__ out, // (sb_cols * n_groups + 1,) + int n_all_grp, int n_groups, int sb_cols) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int total = sb_cols * n_groups + 1; + if (idx >= total) return; + if (idx == sb_cols * n_groups) { + out[idx] = sb_cols * n_all_grp; + } else { + int c = idx / n_groups; + int g = idx % n_groups; + out[idx] = c * n_all_grp + grp_offsets[g]; + } +} + +/** + * Extract specific rows from CSC into dense F-order, using a row lookup map. + * row_map[original_row] = output_row_index (or -1 to skip). + * One block per column, threads scatter matching nonzeros. + * Output must be pre-zeroed. 
+ */ +__global__ void csc_extract_mapped_kernel(const float* __restrict__ data, + const int* __restrict__ indices, + const int* __restrict__ indptr, + const int* __restrict__ row_map, + float* __restrict__ out, int n_target, + int col_start) { + int col_local = blockIdx.x; + int col = col_start + col_local; + + int start = indptr[col]; + int end = indptr[col + 1]; + + for (int p = start + threadIdx.x; p < end; p += blockDim.x) { + int out_row = row_map[indices[p]]; + if (out_row >= 0) { + out[(long long)col_local * n_target + out_row] = data[p]; + } + } +} + +static size_t get_seg_sort_temp_bytes(int n_items, int n_segments) { + size_t bytes = 0; + auto* dk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys(nullptr, bytes, dk, dk, n_items, + n_segments, doff, doff + 1, 0, 32); + return bytes; +} + +/** + * Streaming OVO pipeline. + * + * Takes pre-sorted reference (float32 F-order), unsorted group data (float32 + * F-order with group offsets), and produces rank_sums + tie_corr. + * + * For each sub-batch of columns: + * 1. CUB segmented sort-keys of group data (one segment per group per col) + * 2. 
batched_rank_sums_presorted_kernel (binary search in sorted ref) + */ +static void ovo_streaming_impl(const float* ref_sorted, const float* grp_data, + const int* grp_offsets, double* rank_sums, + double* tie_corr, int n_ref, int n_all_grp, + int n_cols, int n_groups, bool compute_tie_corr, + int sub_batch_cols) { + if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) + n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + + size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols; + int max_n_seg = n_groups * sub_batch_cols; + size_t cub_temp_bytes = 0; + { + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys(nullptr, cub_temp_bytes, fk, fk, + (int)sub_grp_items, max_n_seg, + doff, doff + 1, 0, 32); + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + // Allocate per-stream buffers via RMM pool + RmmPool pool; + struct StreamBuf { + float* grp_sorted; + int* seg_offsets; + uint8_t* cub_temp; + double* sub_rank_sums; + double* sub_tie_corr; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].grp_sorted = pool.alloc(sub_grp_items); + bufs[s].seg_offsets = pool.alloc(max_n_seg + 1); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = + pool.alloc((size_t)n_groups * sub_batch_cols); + } + + // Compute max individual group size for accurate thread count + std::vector h_off(n_groups + 1); + cudaMemcpy(h_off.data(), grp_offsets, (n_groups + 1) * sizeof(int), + cudaMemcpyDeviceToHost); + int max_grp_size = 0; + for (int g = 0; g < n_groups; g++) { + int sz = h_off[g + 1] - h_off[g]; + if (sz > max_grp_size) max_grp_size = sz; + } + int tpb_rank = + round_up_to_warp(std::min(max_grp_size, MAX_THREADS_PER_BLOCK)); + + int col = 0; + int 
batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int sb_n_seg = n_groups * sb_cols; + int sb_grp_items = n_all_grp * sb_cols; + int s = batch_idx % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + + // Build segment offsets on device + { + int total = sb_n_seg + 1; + int blk = (total + 255) / 256; + build_seg_offsets_kernel<<>>( + grp_offsets, buf.seg_offsets, n_all_grp, n_groups, sb_cols); + } + + // Sort group data for this sub-batch + const float* grp_in = grp_data + (long long)col * n_all_grp; + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.cub_temp, temp, grp_in, buf.grp_sorted, sb_grp_items, sb_n_seg, + buf.seg_offsets, buf.seg_offsets + 1, 0, 32, stream); + + // Rank sums: binary search sorted ref for each group element + const float* ref_sub = ref_sorted + (long long)col * n_ref; + dim3 grid(sb_cols, n_groups); + batched_rank_sums_presorted_kernel<<>>( + ref_sub, buf.grp_sorted, grp_offsets, buf.sub_rank_sums, + buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, + compute_tie_corr); + + // Scatter sub-batch results to global output + cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpy2DAsync(tie_corr + col, n_cols * sizeof(double), + buf.sub_tie_corr, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + + col += sb_cols; + batch_idx++; + } + + for (int s = 0; s < n_streams; s++) cudaStreamSynchronize(streams[s]); + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} + +/** + * CSR-direct OVO streaming pipeline. + * + * One C++ call does everything: extract rows from CSR → sort → rank. + * Per sub-batch of columns: + * 1. Extract ref rows → dense f32 → CUB sort + * 2. 
Extract grp rows → dense f32 → CUB sort (segmented by group) + * 3. Binary search rank sums + * Only ~(n_ref + n_all_grp) × sub_batch × 4B on GPU at a time. + */ +static void ovo_streaming_csr_impl( + const float* csr_data, const int* csr_indices, const int* csr_indptr, + const int* ref_row_ids, const int* grp_row_ids, const int* grp_offsets, + double* rank_sums, double* tie_corr, int n_ref, int n_all_grp, int n_cols, + int n_groups, bool compute_tie_corr, int sub_batch_cols) { + if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; + + // ---- Tier dispatch: read group offsets to determine max group size ---- + constexpr int TIER1_THRESHOLD = 2500; + std::vector h_offsets(n_groups + 1); + cudaMemcpy(h_offsets.data(), grp_offsets, (n_groups + 1) * sizeof(int), + cudaMemcpyDeviceToHost); + int max_grp_size = 0; + for (int g = 0; g < n_groups; g++) { + int sz = h_offsets[g + 1] - h_offsets[g]; + if (sz > max_grp_size) max_grp_size = sz; + } + bool use_tier1 = (max_grp_size <= TIER1_THRESHOLD); + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) + n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + + size_t sub_ref_items = (size_t)n_ref * sub_batch_cols; + size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols; + + // CUB temp for ref sort (always needed) + grp sort (Tier 3 only) + size_t cub_ref_bytes = 0; + { + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, cub_ref_bytes, fk, fk, (int)sub_ref_items, sub_batch_cols, + doff, doff + 1, BEGIN_BIT, END_BIT); + } + size_t cub_temp_bytes = cub_ref_bytes; + + if (!use_tier1) { + size_t cub_grp_bytes = 0; + int max_grp_seg = n_groups * sub_batch_cols; + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, cub_grp_bytes, fk, fk, (int)sub_grp_items, max_grp_seg, + doff, doff + 1, BEGIN_BIT, END_BIT); + cub_temp_bytes = std::max(cub_ref_bytes, 
cub_grp_bytes); + } + + // Tier 1 precomputation + int padded_grp_size = 0; + int tier1_tpb = 0; + size_t tier1_smem = 0; + if (use_tier1) { + padded_grp_size = 1; + while (padded_grp_size < max_grp_size) padded_grp_size <<= 1; + tier1_tpb = std::min(padded_grp_size, MAX_THREADS_PER_BLOCK); + tier1_smem = padded_grp_size * sizeof(float) + 32 * sizeof(double); + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + // Allocate per-stream buffers via RMM pool + RmmPool pool; + struct StreamBuf { + float* ref_dense; + float* ref_sorted; + float* grp_dense; + float* grp_sorted; + int* ref_seg_offsets; + int* grp_seg_offsets; + uint8_t* cub_temp; + double* sub_rank_sums; + double* sub_tie_corr; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].ref_dense = pool.alloc(sub_ref_items); + bufs[s].ref_sorted = pool.alloc(sub_ref_items); + bufs[s].grp_dense = pool.alloc(sub_grp_items); + bufs[s].ref_seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = + pool.alloc((size_t)n_groups * sub_batch_cols); + if (!use_tier1) { + bufs[s].grp_sorted = pool.alloc(sub_grp_items); + int max_seg = n_groups * sub_batch_cols; + bufs[s].grp_seg_offsets = pool.alloc(max_seg + 1); + } else { + bufs[s].grp_sorted = nullptr; + bufs[s].grp_seg_offsets = nullptr; + } + } + + int tpb_extract = round_up_to_warp(std::max(n_ref, n_all_grp)); + int tpb_rank = + round_up_to_warp(std::min(max_grp_size, MAX_THREADS_PER_BLOCK)); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int sb_ref_items_actual = n_ref * sb_cols; + int sb_grp_items_actual = n_all_grp * sb_cols; + int s = batch_idx % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + + // ---- Extract + sort ref (always CUB) ---- + 
cudaMemsetAsync(buf.ref_dense, 0, sb_ref_items_actual * sizeof(float), + stream); + { + int blk = (n_ref + tpb_extract - 1) / tpb_extract; + csr_extract_dense_kernel<<>>( + csr_data, csr_indices, csr_indptr, ref_row_ids, buf.ref_dense, + n_ref, col, col + sb_cols); + } + upload_linear_offsets(buf.ref_seg_offsets, sb_cols, n_ref, stream); + { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.cub_temp, temp, buf.ref_dense, buf.ref_sorted, + sb_ref_items_actual, sb_cols, buf.ref_seg_offsets, + buf.ref_seg_offsets + 1, BEGIN_BIT, END_BIT, stream); + } + + // ---- Extract grp rows ---- + cudaMemsetAsync(buf.grp_dense, 0, sb_grp_items_actual * sizeof(float), + stream); + { + int blk = (n_all_grp + tpb_extract - 1) / tpb_extract; + csr_extract_dense_kernel<<>>( + csr_data, csr_indices, csr_indptr, grp_row_ids, buf.grp_dense, + n_all_grp, col, col + sb_cols); + } + + if (use_tier1) { + // ---- Tier 1: fused smem sort + binary search ---- + dim3 grid(sb_cols, n_groups); + ovo_fused_sort_rank_kernel<<>>( + buf.ref_sorted, buf.grp_dense, grp_offsets, buf.sub_rank_sums, + buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, + compute_tie_corr, padded_grp_size); + } else { + // ---- Tier 3: CUB segmented sort + binary search ---- + int sb_grp_seg = n_groups * sb_cols; + { + int total = sb_grp_seg + 1; + int blk = (total + 255) / 256; + build_seg_offsets_kernel<<>>( + grp_offsets, buf.grp_seg_offsets, n_all_grp, n_groups, + sb_cols); + } + { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.cub_temp, temp, buf.grp_dense, buf.grp_sorted, + sb_grp_items_actual, sb_grp_seg, buf.grp_seg_offsets, + buf.grp_seg_offsets + 1, BEGIN_BIT, END_BIT, stream); + } + { + dim3 grid(sb_cols, n_groups); + batched_rank_sums_presorted_kernel<<>>( + buf.ref_sorted, buf.grp_sorted, grp_offsets, + buf.sub_rank_sums, buf.sub_tie_corr, n_ref, n_all_grp, + sb_cols, n_groups, compute_tie_corr); + } + } + + // ---- Scatter to global output 
---- + cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpy2DAsync(tie_corr + col, n_cols * sizeof(double), + buf.sub_tie_corr, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + + col += sb_cols; + batch_idx++; + } + + for (int s = 0; s < n_streams; s++) cudaStreamSynchronize(streams[s]); + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} + +/** + * CSC-direct OVO streaming pipeline. + * + * Like the CSR variant but extracts rows via a row-lookup map, avoiding + * CSC→CSR conversion. row_map_ref[row] = output index in ref block (-1 if + * not a ref row); likewise for row_map_grp. + */ +static void ovo_streaming_csc_impl( + const float* csc_data, const int* csc_indices, const int* csc_indptr, + const int* ref_row_map, const int* grp_row_map, const int* grp_offsets, + double* rank_sums, double* tie_corr, int n_ref, int n_all_grp, int n_cols, + int n_groups, bool compute_tie_corr, int sub_batch_cols) { + if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; + + // ---- Tier dispatch ---- + constexpr int TIER1_THRESHOLD = 2500; + std::vector h_offsets(n_groups + 1); + cudaMemcpy(h_offsets.data(), grp_offsets, (n_groups + 1) * sizeof(int), + cudaMemcpyDeviceToHost); + int max_grp_size = 0; + for (int g = 0; g < n_groups; g++) { + int sz = h_offsets[g + 1] - h_offsets[g]; + if (sz > max_grp_size) max_grp_size = sz; + } + bool use_tier1 = (max_grp_size <= TIER1_THRESHOLD); + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) + n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + + size_t sub_ref_items = (size_t)n_ref * sub_batch_cols; + size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols; + + // CUB temp + size_t cub_ref_bytes = 0; + { + auto* fk = reinterpret_cast(1); + auto* doff = 
reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, cub_ref_bytes, fk, fk, (int)sub_ref_items, sub_batch_cols, + doff, doff + 1, BEGIN_BIT, END_BIT); + } + size_t cub_temp_bytes = cub_ref_bytes; + if (!use_tier1) { + size_t cub_grp_bytes = 0; + int max_grp_seg = n_groups * sub_batch_cols; + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, cub_grp_bytes, fk, fk, (int)sub_grp_items, max_grp_seg, + doff, doff + 1, BEGIN_BIT, END_BIT); + cub_temp_bytes = std::max(cub_ref_bytes, cub_grp_bytes); + } + + // Tier 1 precomputation + int padded_grp_size = 0; + int tier1_tpb = 0; + size_t tier1_smem = 0; + if (use_tier1) { + padded_grp_size = 1; + while (padded_grp_size < max_grp_size) padded_grp_size <<= 1; + tier1_tpb = std::min(padded_grp_size, MAX_THREADS_PER_BLOCK); + tier1_smem = padded_grp_size * sizeof(float) + 32 * sizeof(double); + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + RmmPool pool; + struct StreamBuf { + float* ref_dense; + float* ref_sorted; + float* grp_dense; + float* grp_sorted; + int* ref_seg_offsets; + int* grp_seg_offsets; + uint8_t* cub_temp; + double* sub_rank_sums; + double* sub_tie_corr; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].ref_dense = pool.alloc(sub_ref_items); + bufs[s].ref_sorted = pool.alloc(sub_ref_items); + bufs[s].grp_dense = pool.alloc(sub_grp_items); + bufs[s].ref_seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = + pool.alloc((size_t)n_groups * sub_batch_cols); + if (!use_tier1) { + bufs[s].grp_sorted = pool.alloc(sub_grp_items); + int max_grp_seg = n_groups * sub_batch_cols; + bufs[s].grp_seg_offsets = pool.alloc(max_grp_seg + 1); + } else { + bufs[s].grp_sorted = nullptr; + 
bufs[s].grp_seg_offsets = nullptr; + } + } + + int tpb_rank = + round_up_to_warp(std::min(max_grp_size, MAX_THREADS_PER_BLOCK)); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int sb_ref_items_actual = n_ref * sb_cols; + int sb_grp_items_actual = n_all_grp * sb_cols; + int s = batch_idx % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + + // ---- Extract ref from CSC via row_map, then sort ---- + cudaMemsetAsync(buf.ref_dense, 0, sb_ref_items_actual * sizeof(float), + stream); + csc_extract_mapped_kernel<<>>( + csc_data, csc_indices, csc_indptr, ref_row_map, buf.ref_dense, + n_ref, col); + upload_linear_offsets(buf.ref_seg_offsets, sb_cols, n_ref, stream); + { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.cub_temp, temp, buf.ref_dense, buf.ref_sorted, + sb_ref_items_actual, sb_cols, buf.ref_seg_offsets, + buf.ref_seg_offsets + 1, BEGIN_BIT, END_BIT, stream); + } + + // ---- Extract grp from CSC via row_map ---- + cudaMemsetAsync(buf.grp_dense, 0, sb_grp_items_actual * sizeof(float), + stream); + csc_extract_mapped_kernel<<>>( + csc_data, csc_indices, csc_indptr, grp_row_map, buf.grp_dense, + n_all_grp, col); + + if (use_tier1) { + dim3 grid(sb_cols, n_groups); + ovo_fused_sort_rank_kernel<<>>( + buf.ref_sorted, buf.grp_dense, grp_offsets, buf.sub_rank_sums, + buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, + compute_tie_corr, padded_grp_size); + } else { + int sb_grp_seg = n_groups * sb_cols; + { + int total = sb_grp_seg + 1; + int blk = (total + 255) / 256; + build_seg_offsets_kernel<<>>( + grp_offsets, buf.grp_seg_offsets, n_all_grp, n_groups, + sb_cols); + } + { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.cub_temp, temp, buf.grp_dense, buf.grp_sorted, + sb_grp_items_actual, sb_grp_seg, buf.grp_seg_offsets, + buf.grp_seg_offsets + 1, BEGIN_BIT, END_BIT, stream); + } + { + dim3 grid(sb_cols, 
n_groups); + batched_rank_sums_presorted_kernel<<>>( + buf.ref_sorted, buf.grp_sorted, grp_offsets, + buf.sub_rank_sums, buf.sub_tie_corr, n_ref, n_all_grp, + sb_cols, n_groups, compute_tie_corr); + } + } + + // ---- Scatter to global output ---- + cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpy2DAsync(tie_corr + col, n_cols * sizeof(double), + buf.sub_tie_corr, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + + col += sb_cols; + batch_idx++; + } + + for (int s = 0; s < n_streams; s++) cudaStreamSynchronize(streams[s]); + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} + +/** + * Host-streaming CSC OVO pipeline. + * + * CSC arrays live on host. Only the sparse data for each sub-batch of + * columns is transferred to GPU. Row maps + group offsets are uploaded once. + * Results are written back to host per sub-batch. 
+ */ +static void ovo_streaming_csc_host_impl( + const float* h_data, const int* h_indices, const int* h_indptr, + const int* h_ref_row_map, const int* h_grp_row_map, + const int* h_grp_offsets, double* h_rank_sums, double* h_tie_corr, + int n_ref, int n_all_grp, int n_rows, int n_cols, int n_groups, + bool compute_tie_corr, int sub_batch_cols) { + if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; + + // ---- Tier dispatch from host offsets ---- + constexpr int TIER1_THRESHOLD = 2500; + int max_grp_size = 0; + for (int g = 0; g < n_groups; g++) { + int sz = h_grp_offsets[g + 1] - h_grp_offsets[g]; + if (sz > max_grp_size) max_grp_size = sz; + } + bool use_tier1 = (max_grp_size <= TIER1_THRESHOLD); + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) + n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + + size_t sub_ref_items = (size_t)n_ref * sub_batch_cols; + size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols; + + // CUB temp + size_t cub_ref_bytes = 0; + { + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, cub_ref_bytes, fk, fk, (int)sub_ref_items, sub_batch_cols, + doff, doff + 1, BEGIN_BIT, END_BIT); + } + size_t cub_temp_bytes = cub_ref_bytes; + if (!use_tier1) { + size_t cub_grp_bytes = 0; + int max_grp_seg = n_groups * sub_batch_cols; + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, cub_grp_bytes, fk, fk, (int)sub_grp_items, max_grp_seg, + doff, doff + 1, BEGIN_BIT, END_BIT); + cub_temp_bytes = std::max(cub_ref_bytes, cub_grp_bytes); + } + + int padded_grp_size = 0; + int tier1_tpb = 0; + size_t tier1_smem = 0; + if (use_tier1) { + padded_grp_size = 1; + while (padded_grp_size < max_grp_size) padded_grp_size <<= 1; + tier1_tpb = std::min(padded_grp_size, MAX_THREADS_PER_BLOCK); + tier1_smem = padded_grp_size * sizeof(float) + 32 * sizeof(double); + } + + // Max nnz 
across any sub-batch for sparse transfer buffer sizing + size_t max_nnz = 0; + for (int c = 0; c < n_cols; c += sub_batch_cols) { + int sb = std::min(sub_batch_cols, n_cols - c); + size_t nnz = (size_t)(h_indptr[c + sb] - h_indptr[c]); + if (nnz > max_nnz) max_nnz = nnz; + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + RmmPool pool; + + // GPU copies of row maps + group offsets (uploaded once) + int* d_ref_row_map = pool.alloc(n_rows); + int* d_grp_row_map = pool.alloc(n_rows); + int* d_grp_offsets = pool.alloc(n_groups + 1); + cudaMemcpy(d_ref_row_map, h_ref_row_map, n_rows * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_grp_row_map, h_grp_row_map, n_rows * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_grp_offsets, h_grp_offsets, (n_groups + 1) * sizeof(int), + cudaMemcpyHostToDevice); + + struct StreamBuf { + float* d_sparse_data; + int* d_sparse_indices; + int* d_indptr; + float* ref_dense; + float* ref_sorted; + float* grp_dense; + float* grp_sorted; + int* ref_seg_offsets; + int* grp_seg_offsets; + uint8_t* cub_temp; + double* d_rank_sums; + double* d_tie_corr; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].d_sparse_data = pool.alloc(max_nnz); + bufs[s].d_sparse_indices = pool.alloc(max_nnz); + bufs[s].d_indptr = pool.alloc(sub_batch_cols + 1); + bufs[s].ref_dense = pool.alloc(sub_ref_items); + bufs[s].ref_sorted = pool.alloc(sub_ref_items); + bufs[s].grp_dense = pool.alloc(sub_grp_items); + bufs[s].ref_seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].d_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].d_tie_corr = + pool.alloc((size_t)n_groups * sub_batch_cols); + if (!use_tier1) { + bufs[s].grp_sorted = pool.alloc(sub_grp_items); + int max_grp_seg = n_groups * sub_batch_cols; + bufs[s].grp_seg_offsets = pool.alloc(max_grp_seg + 1); + } else { + bufs[s].grp_sorted 
= nullptr; + bufs[s].grp_seg_offsets = nullptr; + } + } + + int tpb_rank = + round_up_to_warp(std::min(max_grp_size, MAX_THREADS_PER_BLOCK)); + + // Pin host memory for async transfers + cudaHostRegister(const_cast(h_data), + (size_t)h_indptr[n_cols] * sizeof(float), 0); + cudaHostRegister(const_cast(h_indices), + (size_t)h_indptr[n_cols] * sizeof(int), 0); + cudaHostRegister(h_rank_sums, (size_t)n_groups * n_cols * sizeof(double), + 0); + cudaHostRegister(h_tie_corr, (size_t)n_groups * n_cols * sizeof(double), 0); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int sb_ref_actual = n_ref * sb_cols; + int sb_grp_actual = n_all_grp * sb_cols; + int s = batch_idx % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + + // ---- H2D: sparse data for this column range ---- + int ptr_start = h_indptr[col]; + int ptr_end = h_indptr[col + sb_cols]; + size_t nnz = (size_t)(ptr_end - ptr_start); + cudaMemcpyAsync(buf.d_sparse_data, h_data + ptr_start, + nnz * sizeof(float), cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(buf.d_sparse_indices, h_indices + ptr_start, + nnz * sizeof(int), cudaMemcpyHostToDevice, stream); + { + std::vector h_adj(sb_cols + 1); + for (int i = 0; i <= sb_cols; i++) + h_adj[i] = h_indptr[col + i] - ptr_start; + cudaMemcpy(buf.d_indptr, h_adj.data(), (sb_cols + 1) * sizeof(int), + cudaMemcpyHostToDevice); + } + + // ---- Extract ref from CSC via row_map, sort ---- + cudaMemsetAsync(buf.ref_dense, 0, sb_ref_actual * sizeof(float), + stream); + csc_extract_mapped_kernel<<>>( + buf.d_sparse_data, buf.d_sparse_indices, buf.d_indptr, + d_ref_row_map, buf.ref_dense, n_ref, 0); + upload_linear_offsets(buf.ref_seg_offsets, sb_cols, n_ref, stream); + { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.cub_temp, temp, buf.ref_dense, buf.ref_sorted, + sb_ref_actual, sb_cols, buf.ref_seg_offsets, + buf.ref_seg_offsets + 1, BEGIN_BIT, END_BIT, 
stream); + } + + // ---- Extract grp from CSC via row_map ---- + cudaMemsetAsync(buf.grp_dense, 0, sb_grp_actual * sizeof(float), + stream); + csc_extract_mapped_kernel<<>>( + buf.d_sparse_data, buf.d_sparse_indices, buf.d_indptr, + d_grp_row_map, buf.grp_dense, n_all_grp, 0); + + // ---- Tier dispatch: sort grp + rank ---- + if (use_tier1) { + dim3 grid(sb_cols, n_groups); + ovo_fused_sort_rank_kernel<<>>( + buf.ref_sorted, buf.grp_dense, d_grp_offsets, buf.d_rank_sums, + buf.d_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, + compute_tie_corr, padded_grp_size); + } else { + int sb_grp_seg = n_groups * sb_cols; + { + int total = sb_grp_seg + 1; + int blk = (total + 255) / 256; + build_seg_offsets_kernel<<>>( + d_grp_offsets, buf.grp_seg_offsets, n_all_grp, n_groups, + sb_cols); + } + { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.cub_temp, temp, buf.grp_dense, buf.grp_sorted, + sb_grp_actual, sb_grp_seg, buf.grp_seg_offsets, + buf.grp_seg_offsets + 1, BEGIN_BIT, END_BIT, stream); + } + { + dim3 grid(sb_cols, n_groups); + batched_rank_sums_presorted_kernel<<>>( + buf.ref_sorted, buf.grp_sorted, d_grp_offsets, + buf.d_rank_sums, buf.d_tie_corr, n_ref, n_all_grp, sb_cols, + n_groups, compute_tie_corr); + } + } + + // ---- D2H: scatter results ---- + cudaMemcpy2DAsync(h_rank_sums + col, n_cols * sizeof(double), + buf.d_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToHost, stream); + if (compute_tie_corr) { + cudaMemcpy2DAsync(h_tie_corr + col, n_cols * sizeof(double), + buf.d_tie_corr, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToHost, stream); + } + + col += sb_cols; + batch_idx++; + } + + for (int s = 0; s < n_streams; s++) cudaStreamSynchronize(streams[s]); + + cudaHostUnregister(const_cast(h_data)); + cudaHostUnregister(const_cast(h_indices)); + cudaHostUnregister(h_rank_sums); + cudaHostUnregister(h_tie_corr); + + for (int s = 0; s < n_streams; 
s++) cudaStreamDestroy(streams[s]); +} + +/** + * Host CSR OVO pipeline — preload reference, stream perturbations. + * + * Two-phase approach: + * Phase 1: Transfer CSR to GPU, extract ref rows for ALL columns, sort once. + * Phase 2: For each column sub-batch, extract only grp rows, sort, rank + * against the pre-sorted reference. + * + * The reference is sorted once (not per sub-batch), saving ~50% of the + * per-sub-batch extraction + sort work. + */ +static void ovo_streaming_csr_host_impl( + const float* h_data, const int* h_indices, const int* h_indptr, + const int* h_ref_row_ids, const int* h_grp_row_ids, + const int* h_grp_offsets, double* h_rank_sums, double* h_tie_corr, + int n_ref, int n_all_grp, int n_rows, int n_cols, int n_groups, int nnz, + bool compute_tie_corr, int sub_batch_cols) { + if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; + + // ---- Tier dispatch from host offsets ---- + constexpr int TIER1_THRESHOLD = 2500; + int max_grp_size = 0; + for (int g = 0; g < n_groups; g++) { + int sz = h_grp_offsets[g + 1] - h_grp_offsets[g]; + if (sz > max_grp_size) max_grp_size = sz; + } + bool use_tier1 = (max_grp_size <= TIER1_THRESHOLD); + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) + n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + + size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols; + + // CUB temp — sized for the larger of ref (full) or grp (sub-batch) + size_t ref_total = (size_t)n_ref * n_cols; + size_t cub_ref_bytes = 0; + { + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys(nullptr, cub_ref_bytes, fk, fk, + (int)ref_total, n_cols, doff, + doff + 1, BEGIN_BIT, END_BIT); + } + size_t cub_temp_bytes = cub_ref_bytes; + if (!use_tier1) { + size_t cub_grp_bytes = 0; + int max_grp_seg = n_groups * sub_batch_cols; + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, 
cub_grp_bytes, fk, fk, (int)sub_grp_items, max_grp_seg, + doff, doff + 1, BEGIN_BIT, END_BIT); + cub_temp_bytes = std::max(cub_ref_bytes, cub_grp_bytes); + } + + int padded_grp_size = 0; + int tier1_tpb = 0; + size_t tier1_smem = 0; + if (use_tier1) { + padded_grp_size = 1; + while (padded_grp_size < max_grp_size) padded_grp_size <<= 1; + tier1_tpb = std::min(padded_grp_size, MAX_THREADS_PER_BLOCK); + tier1_smem = padded_grp_size * sizeof(float) + 32 * sizeof(double); + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + RmmPool pool; + + // ---- Phase 1: Transfer CSR, extract + sort reference (all columns) ---- + float* d_data = pool.alloc(nnz); + int* d_indices = pool.alloc(nnz); + int* d_indptr = pool.alloc(n_rows + 1); + int* d_ref_row_ids = pool.alloc(n_ref); + int* d_grp_row_ids = pool.alloc(n_all_grp); + int* d_grp_offsets = pool.alloc(n_groups + 1); + + cudaHostRegister(const_cast(h_data), (size_t)nnz * sizeof(float), + 0); + cudaHostRegister(const_cast(h_indices), (size_t)nnz * sizeof(int), 0); + cudaMemcpyAsync(d_data, h_data, (size_t)nnz * sizeof(float), + cudaMemcpyHostToDevice, streams[0]); + cudaMemcpyAsync(d_indices, h_indices, (size_t)nnz * sizeof(int), + cudaMemcpyHostToDevice, streams[0]); + cudaMemcpy(d_indptr, h_indptr, (n_rows + 1) * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_ref_row_ids, h_ref_row_ids, n_ref * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_grp_row_ids, h_grp_row_ids, n_all_grp * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_grp_offsets, h_grp_offsets, (n_groups + 1) * sizeof(int), + cudaMemcpyHostToDevice); + cudaStreamSynchronize(streams[0]); + + // Extract ref for ALL columns, sort once + float* ref_dense = pool.alloc(ref_total); + float* ref_sorted = pool.alloc(ref_total); + cudaMemset(ref_dense, 0, ref_total * sizeof(float)); + { + int tpb = round_up_to_warp(n_ref); + int blk = (n_ref + tpb - 1) / tpb; + 
csr_extract_dense_kernel<<>>(d_data, d_indices, d_indptr, + d_ref_row_ids, ref_dense, n_ref, + 0, n_cols); + } + { + int* ref_seg = pool.alloc(n_cols + 1); + upload_linear_offsets(ref_seg, n_cols, n_ref, nullptr); + uint8_t* cub_tmp = pool.alloc(cub_ref_bytes); + size_t temp = cub_ref_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + cub_tmp, temp, ref_dense, ref_sorted, (int)ref_total, n_cols, + ref_seg, ref_seg + 1, BEGIN_BIT, END_BIT); + } + cudaDeviceSynchronize(); + + // ---- Phase 2: Stream grp sub-batches, rank against pre-sorted ref ---- + struct StreamBuf { + float* grp_dense; + float* grp_sorted; + int* grp_seg_offsets; + uint8_t* cub_temp; + double* d_rank_sums; + double* d_tie_corr; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].grp_dense = pool.alloc(sub_grp_items); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].d_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].d_tie_corr = + pool.alloc((size_t)n_groups * sub_batch_cols); + if (!use_tier1) { + bufs[s].grp_sorted = pool.alloc(sub_grp_items); + int max_grp_seg = n_groups * sub_batch_cols; + bufs[s].grp_seg_offsets = pool.alloc(max_grp_seg + 1); + } else { + bufs[s].grp_sorted = nullptr; + bufs[s].grp_seg_offsets = nullptr; + } + } + + int tpb_extract = round_up_to_warp(n_all_grp); + int tpb_rank = + round_up_to_warp(std::min(max_grp_size, MAX_THREADS_PER_BLOCK)); + + cudaHostRegister(h_rank_sums, (size_t)n_groups * n_cols * sizeof(double), + 0); + cudaHostRegister(h_tie_corr, (size_t)n_groups * n_cols * sizeof(double), 0); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int sb_grp_actual = n_all_grp * sb_cols; + int s = batch_idx % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + + // Extract grp only (ref already sorted) + cudaMemsetAsync(buf.grp_dense, 0, sb_grp_actual * sizeof(float), + stream); + { + int blk = (n_all_grp + tpb_extract - 1) 
/ tpb_extract; + csr_extract_dense_kernel<<>>( + d_data, d_indices, d_indptr, d_grp_row_ids, buf.grp_dense, + n_all_grp, col, col + sb_cols); + } + + // Rank against pre-sorted ref (just slice into ref_sorted) + const float* ref_sub = ref_sorted + (long long)col * n_ref; + if (use_tier1) { + dim3 grid(sb_cols, n_groups); + ovo_fused_sort_rank_kernel<<>>( + ref_sub, buf.grp_dense, d_grp_offsets, buf.d_rank_sums, + buf.d_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, + compute_tie_corr, padded_grp_size); + } else { + int sb_grp_seg = n_groups * sb_cols; + { + int total = sb_grp_seg + 1; + int blk = (total + 255) / 256; + build_seg_offsets_kernel<<>>( + d_grp_offsets, buf.grp_seg_offsets, n_all_grp, n_groups, + sb_cols); + } + { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.cub_temp, temp, buf.grp_dense, buf.grp_sorted, + sb_grp_actual, sb_grp_seg, buf.grp_seg_offsets, + buf.grp_seg_offsets + 1, BEGIN_BIT, END_BIT, stream); + } + { + dim3 grid(sb_cols, n_groups); + batched_rank_sums_presorted_kernel<<>>( + ref_sub, buf.grp_sorted, d_grp_offsets, buf.d_rank_sums, + buf.d_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, + compute_tie_corr); + } + } + + // D2H results + cudaMemcpy2DAsync(h_rank_sums + col, n_cols * sizeof(double), + buf.d_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToHost, stream); + if (compute_tie_corr) { + cudaMemcpy2DAsync(h_tie_corr + col, n_cols * sizeof(double), + buf.d_tie_corr, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToHost, stream); + } + + col += sb_cols; + batch_idx++; + } + + for (int s = 0; s < n_streams; s++) cudaStreamSynchronize(streams[s]); + + cudaHostUnregister(const_cast(h_data)); + cudaHostUnregister(const_cast(h_indices)); + cudaHostUnregister(h_rank_sums); + cudaHostUnregister(h_tie_corr); + + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} + +/** + * Gather specific rows from a dense 
F-order block into a smaller dense block. + * Grid: (n_cols,), Block: 256. + * row_ids[i] = original row index → output row i. + */ +__global__ void dense_gather_rows_kernel(const float* __restrict__ in, + const int* __restrict__ row_ids, + float* __restrict__ out, int n_rows_in, + int n_target, int n_cols) { + int col = blockIdx.x; + if (col >= n_cols) return; + const float* in_col = in + (long long)col * n_rows_in; + float* out_col = out + (long long)col * n_target; + for (int i = threadIdx.x; i < n_target; i += blockDim.x) { + out_col[i] = in_col[row_ids[i]]; + } +} + +/** + * Host-streaming dense OVO pipeline. + * + * Dense F-order float32 lives on host. Sub-batches of columns are H2D + * transferred, then ref/grp rows are gathered, sorted, and ranked. + * Results D2H per sub-batch. + */ +static void ovo_streaming_dense_host_impl( + const float* h_block, const int* h_ref_row_ids, const int* h_grp_row_ids, + const int* h_grp_offsets, double* h_rank_sums, double* h_tie_corr, + int n_ref, int n_all_grp, int n_rows, int n_cols, int n_groups, + bool compute_tie_corr, int sub_batch_cols) { + if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; + + // ---- Tier dispatch from host offsets ---- + constexpr int TIER1_THRESHOLD = 2500; + int max_grp_size = 0; + for (int g = 0; g < n_groups; g++) { + int sz = h_grp_offsets[g + 1] - h_grp_offsets[g]; + if (sz > max_grp_size) max_grp_size = sz; + } + bool use_tier1 = (max_grp_size <= TIER1_THRESHOLD); + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) + n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + + size_t sub_dense = (size_t)n_rows * sub_batch_cols; + size_t sub_ref_items = (size_t)n_ref * sub_batch_cols; + size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols; + + // CUB temp + size_t cub_ref_bytes = 0; + { + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, cub_ref_bytes, fk, fk, (int)sub_ref_items, 
sub_batch_cols, + doff, doff + 1, BEGIN_BIT, END_BIT); + } + size_t cub_temp_bytes = cub_ref_bytes; + if (!use_tier1) { + size_t cub_grp_bytes = 0; + int max_grp_seg = n_groups * sub_batch_cols; + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, cub_grp_bytes, fk, fk, (int)sub_grp_items, max_grp_seg, + doff, doff + 1, BEGIN_BIT, END_BIT); + cub_temp_bytes = std::max(cub_ref_bytes, cub_grp_bytes); + } + + int padded_grp_size = 0; + int tier1_tpb = 0; + size_t tier1_smem = 0; + if (use_tier1) { + padded_grp_size = 1; + while (padded_grp_size < max_grp_size) padded_grp_size <<= 1; + tier1_tpb = std::min(padded_grp_size, MAX_THREADS_PER_BLOCK); + tier1_smem = padded_grp_size * sizeof(float) + 32 * sizeof(double); + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + RmmPool pool; + + // GPU copies of row_ids + group offsets (uploaded once) + int* d_ref_row_ids = pool.alloc(n_ref); + int* d_grp_row_ids = pool.alloc(n_all_grp); + int* d_grp_offsets = pool.alloc(n_groups + 1); + cudaMemcpy(d_ref_row_ids, h_ref_row_ids, n_ref * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_grp_row_ids, h_grp_row_ids, n_all_grp * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_grp_offsets, h_grp_offsets, (n_groups + 1) * sizeof(int), + cudaMemcpyHostToDevice); + + struct StreamBuf { + float* d_block; // H2D sub-batch (all rows) + float* ref_dense; + float* ref_sorted; + float* grp_dense; + float* grp_sorted; + int* ref_seg_offsets; + int* grp_seg_offsets; + uint8_t* cub_temp; + double* d_rank_sums; + double* d_tie_corr; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].d_block = pool.alloc(sub_dense); + bufs[s].ref_dense = pool.alloc(sub_ref_items); + bufs[s].ref_sorted = pool.alloc(sub_ref_items); + bufs[s].grp_dense = pool.alloc(sub_grp_items); + bufs[s].ref_seg_offsets = pool.alloc(sub_batch_cols + 1); + 
bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].d_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].d_tie_corr = + pool.alloc((size_t)n_groups * sub_batch_cols); + if (!use_tier1) { + bufs[s].grp_sorted = pool.alloc(sub_grp_items); + int max_grp_seg = n_groups * sub_batch_cols; + bufs[s].grp_seg_offsets = pool.alloc(max_grp_seg + 1); + } else { + bufs[s].grp_sorted = nullptr; + bufs[s].grp_seg_offsets = nullptr; + } + } + + int tpb_rank = + round_up_to_warp(std::min(max_grp_size, MAX_THREADS_PER_BLOCK)); + + // Pin host memory + cudaHostRegister(const_cast(h_block), + (size_t)n_rows * n_cols * sizeof(float), 0); + cudaHostRegister(h_rank_sums, (size_t)n_groups * n_cols * sizeof(double), + 0); + cudaHostRegister(h_tie_corr, (size_t)n_groups * n_cols * sizeof(double), 0); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int sb_dense = n_rows * sb_cols; + int sb_ref_actual = n_ref * sb_cols; + int sb_grp_actual = n_all_grp * sb_cols; + int s = batch_idx % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + + // ---- H2D: dense column sub-batch (F-order, contiguous) ---- + cudaMemcpyAsync(buf.d_block, h_block + (long long)col * n_rows, + sb_dense * sizeof(float), cudaMemcpyHostToDevice, + stream); + + // ---- Gather ref rows, sort ---- + dense_gather_rows_kernel<<>>( + buf.d_block, d_ref_row_ids, buf.ref_dense, n_rows, n_ref, sb_cols); + upload_linear_offsets(buf.ref_seg_offsets, sb_cols, n_ref, stream); + { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.cub_temp, temp, buf.ref_dense, buf.ref_sorted, + sb_ref_actual, sb_cols, buf.ref_seg_offsets, + buf.ref_seg_offsets + 1, BEGIN_BIT, END_BIT, stream); + } + + // ---- Gather grp rows ---- + dense_gather_rows_kernel<<>>( + buf.d_block, d_grp_row_ids, buf.grp_dense, n_rows, n_all_grp, + sb_cols); + + // ---- Tier dispatch: sort grp + rank ---- + if (use_tier1) { + dim3 
grid(sb_cols, n_groups); + ovo_fused_sort_rank_kernel<<>>( + buf.ref_sorted, buf.grp_dense, d_grp_offsets, buf.d_rank_sums, + buf.d_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, + compute_tie_corr, padded_grp_size); + } else { + int sb_grp_seg = n_groups * sb_cols; + { + int total = sb_grp_seg + 1; + int blk = (total + 255) / 256; + build_seg_offsets_kernel<<>>( + d_grp_offsets, buf.grp_seg_offsets, n_all_grp, n_groups, + sb_cols); + } + { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.cub_temp, temp, buf.grp_dense, buf.grp_sorted, + sb_grp_actual, sb_grp_seg, buf.grp_seg_offsets, + buf.grp_seg_offsets + 1, BEGIN_BIT, END_BIT, stream); + } + { + dim3 grid(sb_cols, n_groups); + batched_rank_sums_presorted_kernel<<>>( + buf.ref_sorted, buf.grp_sorted, d_grp_offsets, + buf.d_rank_sums, buf.d_tie_corr, n_ref, n_all_grp, sb_cols, + n_groups, compute_tie_corr); + } + } + + // ---- D2H: scatter results ---- + cudaMemcpy2DAsync(h_rank_sums + col, n_cols * sizeof(double), + buf.d_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToHost, stream); + if (compute_tie_corr) { + cudaMemcpy2DAsync(h_tie_corr + col, n_cols * sizeof(double), + buf.d_tie_corr, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToHost, stream); + } + + col += sb_cols; + batch_idx++; + } + + for (int s = 0; s < n_streams; s++) cudaStreamSynchronize(streams[s]); + + cudaHostUnregister(const_cast(h_block)); + cudaHostUnregister(h_rank_sums); + cudaHostUnregister(h_tie_corr); + + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} + +// ============================================================================ +// Nanobind module +// ============================================================================ + +template +void register_bindings(nb::module_& m) { + m.doc() = "CUDA kernels for Wilcoxon rank-sum test (OVO)"; + + // ---- Utility bindings (CUB sort, CSR extraction) 
---- + + m.def("get_seg_sort_temp_bytes", &get_seg_sort_temp_bytes, "n_items"_a, + "n_segments"_a); + + m.def( + "segmented_sort", + [](gpu_array_c keys_in, + gpu_array_c keys_out, + gpu_array_c offsets, + gpu_array_c cub_temp, int n_items, int n_segments, + std::uintptr_t stream) { + size_t temp_bytes = cub_temp.size(); + cub::DeviceSegmentedRadixSort::SortKeys( + cub_temp.data(), temp_bytes, keys_in.data(), keys_out.data(), + n_items, n_segments, offsets.data(), offsets.data() + 1, 0, 32, + (cudaStream_t)stream); + CUDA_CHECK_LAST_ERROR(DeviceSegmentedRadixSort); + }, + "keys_in"_a, "keys_out"_a, "offsets"_a, "cub_temp"_a, nb::kw_only(), + "n_items"_a, "n_segments"_a, "stream"_a = 0); + + m.def( + "csr_extract_dense", + [](gpu_array_c data, + gpu_array_c indices, + gpu_array_c indptr, + gpu_array_c row_ids, + gpu_array_f out, int n_target, int col_start, + int col_stop, std::uintptr_t stream) { + int tpb = round_up_to_warp(n_target); + int blocks = (n_target + tpb - 1) / tpb; + csr_extract_dense_kernel<<>>( + data.data(), indices.data(), indptr.data(), row_ids.data(), + out.data(), n_target, col_start, col_stop); + CUDA_CHECK_LAST_ERROR(csr_extract_dense_kernel); + }, + "data"_a, "indices"_a, "indptr"_a, "row_ids"_a, "out"_a, nb::kw_only(), + "n_target"_a, "col_start"_a, "col_stop"_a, "stream"_a = 0); + + m.def( + "csr_extract_dense_f32", + [](gpu_array_c data, + gpu_array_c indices, + gpu_array_c indptr, + gpu_array_c row_ids, + gpu_array_f out, int n_target, int col_start, + int col_stop, std::uintptr_t stream) { + int tpb = round_up_to_warp(n_target); + int blocks = (n_target + tpb - 1) / tpb; + csr_extract_dense_kernel<<>>( + data.data(), indices.data(), indptr.data(), row_ids.data(), + out.data(), n_target, col_start, col_stop); + CUDA_CHECK_LAST_ERROR(csr_extract_dense_kernel); + }, + "data"_a, "indices"_a, "indptr"_a, "row_ids"_a, "out"_a, nb::kw_only(), + "n_target"_a, "col_start"_a, "col_stop"_a, "stream"_a = 0); + + // ---- Streaming pipelines ---- + + 
m.def( + "ovo_streaming_csr", + [](gpu_array_c csr_data, + gpu_array_c csr_indices, + gpu_array_c csr_indptr, + gpu_array_c ref_row_ids, + gpu_array_c grp_row_ids, + gpu_array_c grp_offsets, + gpu_array_c rank_sums, + gpu_array_c tie_corr, int n_ref, int n_all_grp, + int n_cols, int n_groups, bool compute_tie_corr, + int sub_batch_cols) { + ovo_streaming_csr_impl( + csr_data.data(), csr_indices.data(), csr_indptr.data(), + ref_row_ids.data(), grp_row_ids.data(), grp_offsets.data(), + rank_sums.data(), tie_corr.data(), n_ref, n_all_grp, n_cols, + n_groups, compute_tie_corr, sub_batch_cols); + }, + "csr_data"_a, "csr_indices"_a, "csr_indptr"_a, "ref_row_ids"_a, + "grp_row_ids"_a, "grp_offsets"_a, "rank_sums"_a, "tie_corr"_a, + nb::kw_only(), "n_ref"_a, "n_all_grp"_a, "n_cols"_a, "n_groups"_a, + "compute_tie_corr"_a, "sub_batch_cols"_a = SUB_BATCH_COLS); + + m.def( + "ovo_streaming_csc", + [](gpu_array_c csc_data, + gpu_array_c csc_indices, + gpu_array_c csc_indptr, + gpu_array_c ref_row_map, + gpu_array_c grp_row_map, + gpu_array_c grp_offsets, + gpu_array_c rank_sums, + gpu_array_c tie_corr, int n_ref, int n_all_grp, + int n_cols, int n_groups, bool compute_tie_corr, + int sub_batch_cols) { + ovo_streaming_csc_impl( + csc_data.data(), csc_indices.data(), csc_indptr.data(), + ref_row_map.data(), grp_row_map.data(), grp_offsets.data(), + rank_sums.data(), tie_corr.data(), n_ref, n_all_grp, n_cols, + n_groups, compute_tie_corr, sub_batch_cols); + }, + "csc_data"_a, "csc_indices"_a, "csc_indptr"_a, "ref_row_map"_a, + "grp_row_map"_a, "grp_offsets"_a, "rank_sums"_a, "tie_corr"_a, + nb::kw_only(), "n_ref"_a, "n_all_grp"_a, "n_cols"_a, "n_groups"_a, + "compute_tie_corr"_a, "sub_batch_cols"_a = SUB_BATCH_COLS); + + m.def( + "ovo_streaming", + [](gpu_array_f ref_sorted, + gpu_array_f grp_data, + gpu_array_c grp_offsets, + gpu_array_c rank_sums, + gpu_array_c tie_corr, int n_ref, int n_all_grp, + int n_cols, int n_groups, bool compute_tie_corr, + int sub_batch_cols) { + 
ovo_streaming_impl(ref_sorted.data(), grp_data.data(), + grp_offsets.data(), rank_sums.data(), + tie_corr.data(), n_ref, n_all_grp, n_cols, + n_groups, compute_tie_corr, sub_batch_cols); + }, + "ref_sorted"_a, "grp_data"_a, "grp_offsets"_a, "rank_sums"_a, + "tie_corr"_a, nb::kw_only(), "n_ref"_a, "n_all_grp"_a, "n_cols"_a, + "n_groups"_a, "compute_tie_corr"_a, + "sub_batch_cols"_a = SUB_BATCH_COLS); +} + +NB_MODULE(_wilcoxon_ovo_cuda, m) { + REGISTER_GPU_BINDINGS(register_bindings, m); + + m.def( + "ovo_streaming_csc_host", + [](host_array h_data, host_array h_indices, + host_array h_indptr, host_array h_ref_row_map, + host_array h_grp_row_map, + host_array h_grp_offsets, + host_array_2d h_rank_sums, host_array_2d h_tie_corr, + int n_ref, int n_all_grp, int n_rows, int n_cols, int n_groups, + bool compute_tie_corr, int sub_batch_cols) { + ovo_streaming_csc_host_impl( + h_data.data(), h_indices.data(), h_indptr.data(), + h_ref_row_map.data(), h_grp_row_map.data(), + h_grp_offsets.data(), h_rank_sums.data(), h_tie_corr.data(), + n_ref, n_all_grp, n_rows, n_cols, n_groups, compute_tie_corr, + sub_batch_cols); + }, + "h_data"_a, "h_indices"_a, "h_indptr"_a, "h_ref_row_map"_a, + "h_grp_row_map"_a, "h_grp_offsets"_a, "h_rank_sums"_a, "h_tie_corr"_a, + nb::kw_only(), "n_ref"_a, "n_all_grp"_a, "n_rows"_a, "n_cols"_a, + "n_groups"_a, "compute_tie_corr"_a, + "sub_batch_cols"_a = SUB_BATCH_COLS); + + m.def( + "ovo_streaming_csr_host", + [](host_array h_data, host_array h_indices, + host_array h_indptr, host_array h_ref_row_ids, + host_array h_grp_row_ids, + host_array h_grp_offsets, + host_array_2d h_rank_sums, host_array_2d h_tie_corr, + int n_ref, int n_all_grp, int n_rows, int n_cols, int n_groups, + int nnz, bool compute_tie_corr, int sub_batch_cols) { + ovo_streaming_csr_host_impl( + h_data.data(), h_indices.data(), h_indptr.data(), + h_ref_row_ids.data(), h_grp_row_ids.data(), + h_grp_offsets.data(), h_rank_sums.data(), h_tie_corr.data(), + n_ref, n_all_grp, n_rows, 
n_cols, n_groups, nnz, + compute_tie_corr, sub_batch_cols); + }, + "h_data"_a, "h_indices"_a, "h_indptr"_a, "h_ref_row_ids"_a, + "h_grp_row_ids"_a, "h_grp_offsets"_a, "h_rank_sums"_a, "h_tie_corr"_a, + nb::kw_only(), "n_ref"_a, "n_all_grp"_a, "n_rows"_a, "n_cols"_a, + "n_groups"_a, "nnz"_a, "compute_tie_corr"_a, + "sub_batch_cols"_a = SUB_BATCH_COLS); + + m.def( + "ovo_streaming_dense_host", + [](host_array_2d h_block, + host_array h_ref_row_ids, + host_array h_grp_row_ids, + host_array h_grp_offsets, + host_array_2d h_rank_sums, host_array_2d h_tie_corr, + int n_ref, int n_all_grp, int n_rows, int n_cols, int n_groups, + bool compute_tie_corr, int sub_batch_cols) { + ovo_streaming_dense_host_impl( + h_block.data(), h_ref_row_ids.data(), h_grp_row_ids.data(), + h_grp_offsets.data(), h_rank_sums.data(), h_tie_corr.data(), + n_ref, n_all_grp, n_rows, n_cols, n_groups, compute_tie_corr, + sub_batch_cols); + }, + "h_block"_a, "h_ref_row_ids"_a, "h_grp_row_ids"_a, "h_grp_offsets"_a, + "h_rank_sums"_a, "h_tie_corr"_a, nb::kw_only(), "n_ref"_a, + "n_all_grp"_a, "n_rows"_a, "n_cols"_a, "n_groups"_a, + "compute_tie_corr"_a, "sub_batch_cols"_a = SUB_BATCH_COLS); +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon_streaming/wilcoxon_streaming.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr.cu similarity index 50% rename from src/rapids_singlecell/_cuda/wilcoxon_streaming/wilcoxon_streaming.cu rename to src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr.cu index 2b0a1798..7caa29bf 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon_streaming/wilcoxon_streaming.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr.cu @@ -4,58 +4,11 @@ #include #include "../nb_types.h" -#include "../wilcoxon/kernels_wilcoxon.cuh" -#include "../wilcoxon/kernels_wilcoxon_ovo.cuh" +#include "wilcoxon_common.cuh" +#include "kernels_wilcoxon.cuh" using namespace nb::literals; -constexpr int WARP_SIZE = 32; -constexpr int MAX_THREADS_PER_BLOCK = 512; -constexpr int N_STREAMS = 4; -constexpr 
int SUB_BATCH_COLS = 32; -constexpr int BEGIN_BIT = 0; -constexpr int END_BIT = 32; - -static inline int round_up_to_warp(int n) { - int rounded = ((n + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; - return (rounded < MAX_THREADS_PER_BLOCK) ? rounded : MAX_THREADS_PER_BLOCK; -} - -/** - * Extract dense F-order float32 block from CSR. - * All rows, column range [col_start, col_stop). - * One thread per row, binary search for col_start. - * Output must be pre-zeroed. - */ -__global__ void csr_extract_f32_kernel(const float* __restrict__ data, - const int* __restrict__ indices, - const int* __restrict__ indptr, - float* __restrict__ out, int n_rows, - int col_start, int col_stop) { - int row = blockIdx.x * blockDim.x + threadIdx.x; - if (row >= n_rows) return; - - int rs = indptr[row]; - int re = indptr[row + 1]; - - // Binary search for col_start - int lo = rs, hi = re; - while (lo < hi) { - int m = (lo + hi) >> 1; - if (indices[m] < col_start) - lo = m + 1; - else - hi = m; - } - - int n_cols = col_stop - col_start; - for (int p = lo; p < re; ++p) { - int c = indices[p]; - if (c >= col_stop) break; - out[(long long)(c - col_start) * n_rows + row] = data[p]; - } -} - /** * Extract dense F-order float32 block from CSC. * Column range [col_start, col_stop). @@ -93,6 +46,22 @@ __global__ void fill_row_indices_kernel(int* __restrict__ vals, int n_rows, } } +/** + * Launch csr_extract_dense_kernel for ALL rows of a CSR matrix. + * Creates a temporary identity row_ids array [0,1,...,n_rows-1]. 
+ */ +static void csr_extract_all_rows(const float* data, const int* indices, + const int* indptr, float* out, int n_rows, + int col_start, int col_stop, RmmPool& pool, + cudaStream_t stream) { + int* row_ids = pool.alloc(n_rows); + fill_row_indices_kernel<<<1, 256, 0, stream>>>(row_ids, n_rows, 1); + int tpb = round_up_to_warp(n_rows); + int blk = (n_rows + tpb - 1) / tpb; + csr_extract_dense_kernel<<>>( + data, indices, indptr, row_ids, out, n_rows, col_start, col_stop); +} + /** * Streaming OVR pipeline. * @@ -127,36 +96,32 @@ static void ovr_streaming_impl(const float* block, const int* group_codes, std::vector streams(n_streams); for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); - // Per-stream buffers (allocated once, reused across sub-batches) + // Allocate per-stream buffers via RMM pool + RmmPool pool; struct StreamBuf { float* keys_out; int* vals_in; int* vals_out; int* seg_offsets; uint8_t* cub_temp; + double* sub_rank_sums; + double* sub_tie_corr; }; - std::vector bufs(n_streams); for (int s = 0; s < n_streams; s++) { - cudaMalloc(&bufs[s].keys_out, sub_items * sizeof(float)); - cudaMalloc(&bufs[s].vals_in, sub_items * sizeof(int)); - cudaMalloc(&bufs[s].vals_out, sub_items * sizeof(int)); - cudaMalloc(&bufs[s].seg_offsets, (sub_batch_cols + 1) * sizeof(int)); - cudaMalloc(&bufs[s].cub_temp, cub_temp_bytes); + bufs[s].keys_out = pool.alloc(sub_items); + bufs[s].vals_in = pool.alloc(sub_items); + bufs[s].vals_out = pool.alloc(sub_items); + bufs[s].seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = pool.alloc(sub_batch_cols); } int tpb_rank = round_up_to_warp(n_rows); int smem_rank = (4 * n_groups + 32) * sizeof(double); - // Allocate sub-batch output buffers per stream - std::vector sub_rank_sums(n_streams); - std::vector sub_tie_corr(n_streams); - for (int s = 0; s < n_streams; s++) { - 
cudaMalloc(&sub_rank_sums[s], - (size_t)n_groups * sub_batch_cols * sizeof(double)); - cudaMalloc(&sub_tie_corr[s], sub_batch_cols * sizeof(double)); - } - // Process sub-batches round-robin across streams int col = 0; int batch_idx = 0; @@ -167,16 +132,8 @@ static void ovr_streaming_impl(const float* block, const int* group_codes, auto stream = streams[s]; auto& buf = bufs[s]; - // Fill segment offsets: [0, n_rows, 2*n_rows, ...] - { - int* h_off = new int[sb_cols + 1]; - for (int i = 0; i <= sb_cols; i++) h_off[i] = i * n_rows; - cudaMemcpyAsync(buf.seg_offsets, h_off, (sb_cols + 1) * sizeof(int), - cudaMemcpyHostToDevice, stream); - delete[] h_off; - } - - // Fill row indices + // Fill segment offsets + row indices + upload_linear_offsets(buf.seg_offsets, sb_cols, n_rows, stream); fill_row_indices_kernel<<>>(buf.vals_in, n_rows, sb_cols); @@ -190,8 +147,8 @@ static void ovr_streaming_impl(const float* block, const int* group_codes, // Fused rank sums into sub-batch buffer rank_sums_from_sorted_kernel<<>>( - buf.keys_out, buf.vals_out, group_codes, sub_rank_sums[s], - sub_tie_corr[s], nullptr, nullptr, nullptr, n_rows, sb_cols, + buf.keys_out, buf.vals_out, group_codes, buf.sub_rank_sums, + buf.sub_tie_corr, nullptr, nullptr, nullptr, n_rows, sb_cols, n_groups, compute_tie_corr, false); // Copy sub-batch results to global output (row-major scatter) @@ -199,11 +156,11 @@ static void ovr_streaming_impl(const float* block, const int* group_codes, // [g*n_cols+c] sub output is (n_groups, sb_cols): group g, local col lc // → [g*sb_cols+lc] cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), - sub_rank_sums[s], sb_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), sb_cols * sizeof(double), n_groups, cudaMemcpyDeviceToDevice, stream); if (compute_tie_corr) { - cudaMemcpyAsync(tie_corr + col, sub_tie_corr[s], + cudaMemcpyAsync(tie_corr + col, buf.sub_tie_corr, sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, stream); } @@ -217,17 +174,7 @@ 
static void ovr_streaming_impl(const float* block, const int* group_codes, cudaStreamSynchronize(streams[s]); } - // Cleanup - for (int s = 0; s < n_streams; s++) { - cudaFree(bufs[s].keys_out); - cudaFree(bufs[s].vals_in); - cudaFree(bufs[s].vals_out); - cudaFree(bufs[s].seg_offsets); - cudaFree(bufs[s].cub_temp); - cudaFree(sub_rank_sums[s]); - cudaFree(sub_tie_corr[s]); - cudaStreamDestroy(streams[s]); - } + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); } /** @@ -260,36 +207,33 @@ static void ovr_streaming_csr_impl( std::vector streams(n_streams); for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + // Allocate per-stream buffers via RMM pool + RmmPool pool; struct StreamBuf { - float* dense; // extracted dense sub-batch + float* dense; float* keys_out; int* vals_in; int* vals_out; int* seg_offsets; uint8_t* cub_temp; + double* sub_rank_sums; + double* sub_tie_corr; }; - std::vector bufs(n_streams); for (int s = 0; s < n_streams; s++) { - cudaMalloc(&bufs[s].dense, sub_items * sizeof(float)); - cudaMalloc(&bufs[s].keys_out, sub_items * sizeof(float)); - cudaMalloc(&bufs[s].vals_in, sub_items * sizeof(int)); - cudaMalloc(&bufs[s].vals_out, sub_items * sizeof(int)); - cudaMalloc(&bufs[s].seg_offsets, (sub_batch_cols + 1) * sizeof(int)); - cudaMalloc(&bufs[s].cub_temp, cub_temp_bytes); - } - - std::vector sub_rank_sums(n_streams); - std::vector sub_tie_corr(n_streams); - for (int s = 0; s < n_streams; s++) { - cudaMalloc(&sub_rank_sums[s], - (size_t)n_groups * sub_batch_cols * sizeof(double)); - cudaMalloc(&sub_tie_corr[s], sub_batch_cols * sizeof(double)); + bufs[s].dense = pool.alloc(sub_items); + bufs[s].keys_out = pool.alloc(sub_items); + bufs[s].vals_in = pool.alloc(sub_items); + bufs[s].vals_out = pool.alloc(sub_items); + bufs[s].seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = 
pool.alloc(sub_batch_cols); } int tpb_rank = round_up_to_warp(n_rows); int smem_rank = (4 * n_groups + 32) * sizeof(double); - int tpb_extract = round_up_to_warp(n_rows); int col = 0; int batch_idx = 0; @@ -303,20 +247,12 @@ static void ovr_streaming_csr_impl( // Zero dense buffer cudaMemsetAsync(buf.dense, 0, sb_items * sizeof(float), stream); - // Extract dense columns from CSR - int extract_blocks = (n_rows + tpb_extract - 1) / tpb_extract; - csr_extract_f32_kernel<<>>( - csr_data, csr_indices, csr_indptr, buf.dense, n_rows, col, - col + sb_cols); + // Extract dense columns from CSR (all rows) + csr_extract_all_rows(csr_data, csr_indices, csr_indptr, buf.dense, + n_rows, col, col + sb_cols, pool, stream); // Fill segment offsets + row indices - { - int* h_off = new int[sb_cols + 1]; - for (int i = 0; i <= sb_cols; i++) h_off[i] = i * n_rows; - cudaMemcpyAsync(buf.seg_offsets, h_off, (sb_cols + 1) * sizeof(int), - cudaMemcpyHostToDevice, stream); - delete[] h_off; - } + upload_linear_offsets(buf.seg_offsets, sb_cols, n_rows, stream); fill_row_indices_kernel<<>>(buf.vals_in, n_rows, sb_cols); @@ -329,17 +265,17 @@ static void ovr_streaming_csr_impl( // Fused rank sums rank_sums_from_sorted_kernel<<>>( - buf.keys_out, buf.vals_out, group_codes, sub_rank_sums[s], - sub_tie_corr[s], nullptr, nullptr, nullptr, n_rows, sb_cols, + buf.keys_out, buf.vals_out, group_codes, buf.sub_rank_sums, + buf.sub_tie_corr, nullptr, nullptr, nullptr, n_rows, sb_cols, n_groups, compute_tie_corr, false); // Scatter to global output cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), - sub_rank_sums[s], sb_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), sb_cols * sizeof(double), n_groups, cudaMemcpyDeviceToDevice, stream); if (compute_tie_corr) { - cudaMemcpyAsync(tie_corr + col, sub_tie_corr[s], + cudaMemcpyAsync(tie_corr + col, buf.sub_tie_corr, sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, stream); } @@ -350,17 +286,7 @@ static void 
ovr_streaming_csr_impl( for (int s = 0; s < n_streams; s++) cudaStreamSynchronize(streams[s]); - for (int s = 0; s < n_streams; s++) { - cudaFree(bufs[s].dense); - cudaFree(bufs[s].keys_out); - cudaFree(bufs[s].vals_in); - cudaFree(bufs[s].vals_out); - cudaFree(bufs[s].seg_offsets); - cudaFree(bufs[s].cub_temp); - cudaFree(sub_rank_sums[s]); - cudaFree(sub_tie_corr[s]); - cudaStreamDestroy(streams[s]); - } + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); } /** @@ -393,6 +319,8 @@ static void ovr_streaming_csc_impl( std::vector streams(n_streams); for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + // Allocate per-stream buffers via RMM pool + RmmPool pool; struct StreamBuf { float* dense; float* keys_out; @@ -400,24 +328,20 @@ static void ovr_streaming_csc_impl( int* vals_out; int* seg_offsets; uint8_t* cub_temp; + double* sub_rank_sums; + double* sub_tie_corr; }; - std::vector bufs(n_streams); for (int s = 0; s < n_streams; s++) { - cudaMalloc(&bufs[s].dense, sub_items * sizeof(float)); - cudaMalloc(&bufs[s].keys_out, sub_items * sizeof(float)); - cudaMalloc(&bufs[s].vals_in, sub_items * sizeof(int)); - cudaMalloc(&bufs[s].vals_out, sub_items * sizeof(int)); - cudaMalloc(&bufs[s].seg_offsets, (sub_batch_cols + 1) * sizeof(int)); - cudaMalloc(&bufs[s].cub_temp, cub_temp_bytes); - } - - std::vector sub_rank_sums(n_streams); - std::vector sub_tie_corr(n_streams); - for (int s = 0; s < n_streams; s++) { - cudaMalloc(&sub_rank_sums[s], - (size_t)n_groups * sub_batch_cols * sizeof(double)); - cudaMalloc(&sub_tie_corr[s], sub_batch_cols * sizeof(double)); + bufs[s].dense = pool.alloc(sub_items); + bufs[s].keys_out = pool.alloc(sub_items); + bufs[s].vals_in = pool.alloc(sub_items); + bufs[s].vals_out = pool.alloc(sub_items); + bufs[s].seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = 
pool.alloc(sub_batch_cols); } int tpb_rank = round_up_to_warp(n_rows); @@ -440,13 +364,7 @@ static void ovr_streaming_csc_impl( csc_data, csc_indices, csc_indptr, buf.dense, n_rows, col); // Fill segment offsets + row indices - { - int* h_off = new int[sb_cols + 1]; - for (int i = 0; i <= sb_cols; i++) h_off[i] = i * n_rows; - cudaMemcpyAsync(buf.seg_offsets, h_off, (sb_cols + 1) * sizeof(int), - cudaMemcpyHostToDevice, stream); - delete[] h_off; - } + upload_linear_offsets(buf.seg_offsets, sb_cols, n_rows, stream); fill_row_indices_kernel<<>>(buf.vals_in, n_rows, sb_cols); @@ -459,17 +377,17 @@ static void ovr_streaming_csc_impl( // Fused rank sums rank_sums_from_sorted_kernel<<>>( - buf.keys_out, buf.vals_out, group_codes, sub_rank_sums[s], - sub_tie_corr[s], nullptr, nullptr, nullptr, n_rows, sb_cols, + buf.keys_out, buf.vals_out, group_codes, buf.sub_rank_sums, + buf.sub_tie_corr, nullptr, nullptr, nullptr, n_rows, sb_cols, n_groups, compute_tie_corr, false); // Scatter to global output cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), - sub_rank_sums[s], sb_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), sb_cols * sizeof(double), n_groups, cudaMemcpyDeviceToDevice, stream); if (compute_tie_corr) { - cudaMemcpyAsync(tie_corr + col, sub_tie_corr[s], + cudaMemcpyAsync(tie_corr + col, buf.sub_tie_corr, sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, stream); } @@ -480,17 +398,7 @@ static void ovr_streaming_csc_impl( for (int s = 0; s < n_streams; s++) cudaStreamSynchronize(streams[s]); - for (int s = 0; s < n_streams; s++) { - cudaFree(bufs[s].dense); - cudaFree(bufs[s].keys_out); - cudaFree(bufs[s].vals_in); - cudaFree(bufs[s].vals_out); - cudaFree(bufs[s].seg_offsets); - cudaFree(bufs[s].cub_temp); - cudaFree(sub_rank_sums[s]); - cudaFree(sub_tie_corr[s]); - cudaStreamDestroy(streams[s]); - } + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); } /** @@ -532,17 +440,14 @@ static void 
ovr_streaming_csc_host_impl( std::vector streams(n_streams); for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); - // Group codes on GPU (transferred once) - int* d_group_codes; - cudaMalloc(&d_group_codes, n_rows * sizeof(int)); - cudaMemcpy(d_group_codes, h_group_codes, n_rows * sizeof(int), - cudaMemcpyHostToDevice); - + // Allocate per-stream buffers via RMM pool + RmmPool pool; + int* d_group_codes = pool.alloc(n_rows); struct StreamBuf { - float* d_sparse_data; // H2D sparse values - int* d_sparse_indices; // H2D sparse row indices - int* d_indptr; // H2D indptr slice (sb_cols + 1) - float* dense; // extracted dense + float* d_sparse_data; + int* d_sparse_indices; + int* d_indptr; + float* dense; float* keys_out; int* vals_in; int* vals_out; @@ -551,23 +456,26 @@ static void ovr_streaming_csc_host_impl( double* d_rank_sums; double* d_tie_corr; }; - std::vector bufs(n_streams); for (int s = 0; s < n_streams; s++) { - cudaMalloc(&bufs[s].d_sparse_data, max_nnz * sizeof(float)); - cudaMalloc(&bufs[s].d_sparse_indices, max_nnz * sizeof(int)); - cudaMalloc(&bufs[s].d_indptr, (sub_batch_cols + 1) * sizeof(int)); - cudaMalloc(&bufs[s].dense, sub_items * sizeof(float)); - cudaMalloc(&bufs[s].keys_out, sub_items * sizeof(float)); - cudaMalloc(&bufs[s].vals_in, sub_items * sizeof(int)); - cudaMalloc(&bufs[s].vals_out, sub_items * sizeof(int)); - cudaMalloc(&bufs[s].seg_offsets, (sub_batch_cols + 1) * sizeof(int)); - cudaMalloc(&bufs[s].cub_temp, cub_temp_bytes); - cudaMalloc(&bufs[s].d_rank_sums, - (size_t)n_groups * sub_batch_cols * sizeof(double)); - cudaMalloc(&bufs[s].d_tie_corr, sub_batch_cols * sizeof(double)); + bufs[s].d_sparse_data = pool.alloc(max_nnz); + bufs[s].d_sparse_indices = pool.alloc(max_nnz); + bufs[s].d_indptr = pool.alloc(sub_batch_cols + 1); + bufs[s].dense = pool.alloc(sub_items); + bufs[s].keys_out = pool.alloc(sub_items); + bufs[s].vals_in = pool.alloc(sub_items); + bufs[s].vals_out = pool.alloc(sub_items); + bufs[s].seg_offsets 
= pool.alloc(sub_batch_cols + 1); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].d_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].d_tie_corr = pool.alloc(sub_batch_cols); } + // Group codes on GPU (transferred once) + cudaMemcpy(d_group_codes, h_group_codes, n_rows * sizeof(int), + cudaMemcpyHostToDevice); + int tpb_rank = round_up_to_warp(n_rows); int smem_rank = (4 * n_groups + 32) * sizeof(double); @@ -601,12 +509,11 @@ static void ovr_streaming_csc_host_impl( // Transfer adjusted indptr (rebased to 0) // h_indptr[col..col+sb_cols] - h_indptr[col] { - int* h_adj = new int[sb_cols + 1]; + std::vector h_adj(sb_cols + 1); for (int i = 0; i <= sb_cols; i++) h_adj[i] = h_indptr[col + i] - ptr_start; - cudaMemcpyAsync(buf.d_indptr, h_adj, (sb_cols + 1) * sizeof(int), - cudaMemcpyHostToDevice, stream); - delete[] h_adj; + cudaMemcpy(buf.d_indptr, h_adj.data(), (sb_cols + 1) * sizeof(int), + cudaMemcpyHostToDevice); } // Zero dense buffer @@ -619,13 +526,7 @@ static void ovr_streaming_csc_host_impl( n_rows, 0); // Fill segment offsets + row indices - { - int* h_off = new int[sb_cols + 1]; - for (int i = 0; i <= sb_cols; i++) h_off[i] = i * n_rows; - cudaMemcpyAsync(buf.seg_offsets, h_off, (sb_cols + 1) * sizeof(int), - cudaMemcpyHostToDevice, stream); - delete[] h_off; - } + upload_linear_offsets(buf.seg_offsets, sb_cols, n_rows, stream); fill_row_indices_kernel<<>>(buf.vals_in, n_rows, sb_cols); @@ -664,21 +565,7 @@ static void ovr_streaming_csc_host_impl( cudaHostUnregister(h_rank_sums); cudaHostUnregister(h_tie_corr); - cudaFree(d_group_codes); - for (int s = 0; s < n_streams; s++) { - cudaFree(bufs[s].d_sparse_data); - cudaFree(bufs[s].d_sparse_indices); - cudaFree(bufs[s].d_indptr); - cudaFree(bufs[s].dense); - cudaFree(bufs[s].keys_out); - cudaFree(bufs[s].vals_in); - cudaFree(bufs[s].vals_out); - cudaFree(bufs[s].seg_offsets); - cudaFree(bufs[s].cub_temp); - cudaFree(bufs[s].d_rank_sums); - cudaFree(bufs[s].d_tie_corr); - 
cudaStreamDestroy(streams[s]); - } + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); } /** @@ -710,13 +597,11 @@ static void ovr_streaming_dense_host_impl( std::vector streams(n_streams); for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); - int* d_group_codes; - cudaMalloc(&d_group_codes, n_rows * sizeof(int)); - cudaMemcpy(d_group_codes, h_group_codes, n_rows * sizeof(int), - cudaMemcpyHostToDevice); - + // Allocate per-stream buffers via RMM pool + RmmPool pool; + int* d_group_codes = pool.alloc(n_rows); struct StreamBuf { - float* d_block; // H2D dense sub-batch + float* d_block; float* keys_out; int* vals_in; int* vals_out; @@ -725,20 +610,23 @@ static void ovr_streaming_dense_host_impl( double* d_rank_sums; double* d_tie_corr; }; - std::vector bufs(n_streams); for (int s = 0; s < n_streams; s++) { - cudaMalloc(&bufs[s].d_block, sub_items * sizeof(float)); - cudaMalloc(&bufs[s].keys_out, sub_items * sizeof(float)); - cudaMalloc(&bufs[s].vals_in, sub_items * sizeof(int)); - cudaMalloc(&bufs[s].vals_out, sub_items * sizeof(int)); - cudaMalloc(&bufs[s].seg_offsets, (sub_batch_cols + 1) * sizeof(int)); - cudaMalloc(&bufs[s].cub_temp, cub_temp_bytes); - cudaMalloc(&bufs[s].d_rank_sums, - (size_t)n_groups * sub_batch_cols * sizeof(double)); - cudaMalloc(&bufs[s].d_tie_corr, sub_batch_cols * sizeof(double)); + bufs[s].d_block = pool.alloc(sub_items); + bufs[s].keys_out = pool.alloc(sub_items); + bufs[s].vals_in = pool.alloc(sub_items); + bufs[s].vals_out = pool.alloc(sub_items); + bufs[s].seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].d_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].d_tie_corr = pool.alloc(sub_batch_cols); } + // Group codes on GPU (transferred once) + cudaMemcpy(d_group_codes, h_group_codes, n_rows * sizeof(int), + cudaMemcpyHostToDevice); + int tpb_rank = round_up_to_warp(n_rows); int smem_rank = (4 * n_groups + 32) * sizeof(double); 
@@ -764,13 +652,7 @@ static void ovr_streaming_dense_host_impl( stream); // Fill segment offsets + row indices - { - int* h_off = new int[sb_cols + 1]; - for (int i = 0; i <= sb_cols; i++) h_off[i] = i * n_rows; - cudaMemcpyAsync(buf.seg_offsets, h_off, (sb_cols + 1) * sizeof(int), - cudaMemcpyHostToDevice, stream); - delete[] h_off; - } + upload_linear_offsets(buf.seg_offsets, sb_cols, n_rows, stream); fill_row_indices_kernel<<>>(buf.vals_in, n_rows, sb_cols); @@ -808,319 +690,7 @@ static void ovr_streaming_dense_host_impl( cudaHostUnregister(h_rank_sums); cudaHostUnregister(h_tie_corr); - cudaFree(d_group_codes); - for (int s = 0; s < n_streams; s++) { - cudaFree(bufs[s].d_block); - cudaFree(bufs[s].keys_out); - cudaFree(bufs[s].vals_in); - cudaFree(bufs[s].vals_out); - cudaFree(bufs[s].seg_offsets); - cudaFree(bufs[s].cub_temp); - cudaFree(bufs[s].d_rank_sums); - cudaFree(bufs[s].d_tie_corr); - cudaStreamDestroy(streams[s]); - } -} - -/** - * Build segment offsets for CUB segmented sort of group data within a - * sub-batch. offset[c * n_groups + g] = c * n_all_grp + grp_offsets[g]. - * One thread per entry. - */ -__global__ void build_seg_offsets_kernel( - const int* __restrict__ grp_offsets, // (n_groups + 1,) - int* __restrict__ out, // (sb_cols * n_groups + 1,) - int n_all_grp, int n_groups, int sb_cols) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - int total = sb_cols * n_groups + 1; - if (idx >= total) return; - if (idx == sb_cols * n_groups) { - out[idx] = sb_cols * n_all_grp; - } else { - int c = idx / n_groups; - int g = idx % n_groups; - out[idx] = c * n_all_grp + grp_offsets[g]; - } -} - -/** - * Streaming OVO pipeline. - * - * Takes pre-sorted reference (float32 F-order), unsorted group data (float32 - * F-order with group offsets), and produces rank_sums + tie_corr. - * - * For each sub-batch of columns: - * 1. CUB segmented sort-keys of group data (one segment per group per col) - * 2. 
batched_rank_sums_presorted_kernel (binary search in sorted ref) - */ -static void ovo_streaming_impl(const float* ref_sorted, const float* grp_data, - const int* grp_offsets, double* rank_sums, - double* tie_corr, int n_ref, int n_all_grp, - int n_cols, int n_groups, bool compute_tie_corr, - int sub_batch_cols) { - if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; - - int n_streams = N_STREAMS; - if (n_cols < n_streams * sub_batch_cols) - n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; - - size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols; - int max_n_seg = n_groups * sub_batch_cols; - size_t cub_temp_bytes = 0; - { - auto* fk = reinterpret_cast(1); - auto* doff = reinterpret_cast(1); - cub::DeviceSegmentedRadixSort::SortKeys(nullptr, cub_temp_bytes, fk, fk, - (int)sub_grp_items, max_n_seg, - doff, doff + 1, 0, 32); - } - - std::vector streams(n_streams); - for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); - - struct StreamBuf { - float* grp_sorted; - int* seg_offsets; - uint8_t* cub_temp; - double* sub_rank_sums; - double* sub_tie_corr; - }; - - std::vector bufs(n_streams); - for (int s = 0; s < n_streams; s++) { - cudaMalloc(&bufs[s].grp_sorted, sub_grp_items * sizeof(float)); - cudaMalloc(&bufs[s].seg_offsets, (max_n_seg + 1) * sizeof(int)); - cudaMalloc(&bufs[s].cub_temp, cub_temp_bytes); - cudaMalloc(&bufs[s].sub_rank_sums, - (size_t)n_groups * sub_batch_cols * sizeof(double)); - cudaMalloc(&bufs[s].sub_tie_corr, - (size_t)n_groups * sub_batch_cols * sizeof(double)); - } - - // Import the presorted kernel from the OVO header - // (included via kernels_wilcoxon_ovo.cuh) - int tpb_rank = round_up_to_warp(std::min(n_all_grp, MAX_THREADS_PER_BLOCK)); - - int col = 0; - int batch_idx = 0; - while (col < n_cols) { - int sb_cols = std::min(sub_batch_cols, n_cols - col); - int sb_n_seg = n_groups * sb_cols; - int sb_grp_items = n_all_grp * sb_cols; - int s = batch_idx % n_streams; - auto stream = streams[s]; - auto& buf = 
bufs[s]; - - // Build segment offsets on device - { - int total = sb_n_seg + 1; - int blk = (total + 255) / 256; - build_seg_offsets_kernel<<>>( - grp_offsets, buf.seg_offsets, n_all_grp, n_groups, sb_cols); - } - - // Sort group data for this sub-batch - const float* grp_in = grp_data + (long long)col * n_all_grp; - size_t temp = cub_temp_bytes; - cub::DeviceSegmentedRadixSort::SortKeys( - buf.cub_temp, temp, grp_in, buf.grp_sorted, sb_grp_items, sb_n_seg, - buf.seg_offsets, buf.seg_offsets + 1, 0, 32, stream); - - // Rank sums: binary search sorted ref for each group element - const float* ref_sub = ref_sorted + (long long)col * n_ref; - dim3 grid(sb_cols, n_groups); - batched_rank_sums_presorted_kernel<<>>( - ref_sub, buf.grp_sorted, grp_offsets, buf.sub_rank_sums, - buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, - compute_tie_corr); - - // Scatter sub-batch results to global output - cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), - buf.sub_rank_sums, sb_cols * sizeof(double), - sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToDevice, stream); - if (compute_tie_corr) { - cudaMemcpy2DAsync(tie_corr + col, n_cols * sizeof(double), - buf.sub_tie_corr, sb_cols * sizeof(double), - sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToDevice, stream); - } - - col += sb_cols; - batch_idx++; - } - - for (int s = 0; s < n_streams; s++) cudaStreamSynchronize(streams[s]); - - for (int s = 0; s < n_streams; s++) { - cudaFree(bufs[s].grp_sorted); - cudaFree(bufs[s].seg_offsets); - cudaFree(bufs[s].cub_temp); - cudaFree(bufs[s].sub_rank_sums); - cudaFree(bufs[s].sub_tie_corr); - cudaStreamDestroy(streams[s]); - } -} - -/** - * Multi-GPU OVO streaming pipeline with host data. - * - * Ref block is sorted on GPU 0, then P2P copied to other GPUs. - * Group data is streamed from host to each GPU's streams. 
- */ -static void ovo_streaming_multigpu_impl( - const float* h_ref_sorted, const float* h_grp_data, - const int* h_grp_offsets, double* h_rank_sums, double* h_tie_corr, - int n_ref, int n_all_grp, int n_cols, int n_groups, bool compute_tie_corr, - int sub_batch_cols, const int* h_device_ids, int n_devices) { - if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; - - constexpr int STREAMS_PER_GPU = 2; - - // CUB temp for segmented sort of group data - int max_n_seg = n_groups * sub_batch_cols; - size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols; - size_t cub_temp_bytes = 0; - { - auto* fk = reinterpret_cast(1); - auto* doff = reinterpret_cast(1); - cub::DeviceSegmentedRadixSort::SortKeys(nullptr, cub_temp_bytes, fk, fk, - (int)sub_grp_items, max_n_seg, - doff, doff + 1, 0, 32); - } - - int tpb = round_up_to_warp(std::min(n_all_grp, MAX_THREADS_PER_BLOCK)); - - struct GpuCtx { - int device_id; - cudaStream_t streams[2]; - float* d_ref_sorted; // full ref, copied once - int* d_grp_offsets; - struct { - float* d_grp_data; - float* d_grp_sorted; - int* d_seg_offsets; - uint8_t* d_cub_temp; - double* d_rank_sums; - double* d_tie_corr; - } buf[2]; - }; - - std::vector gpus(n_devices); - - // Phase 1: allocate + upload ref + offsets to each GPU - for (int d = 0; d < n_devices; d++) { - auto& g = gpus[d]; - g.device_id = h_device_ids[d]; - cudaSetDevice(g.device_id); - - for (int s = 0; s < STREAMS_PER_GPU; s++) - cudaStreamCreate(&g.streams[s]); - - size_t ref_size = (size_t)n_ref * n_cols; - cudaMalloc(&g.d_ref_sorted, ref_size * sizeof(float)); - cudaMemcpyAsync(g.d_ref_sorted, h_ref_sorted, ref_size * sizeof(float), - cudaMemcpyHostToDevice, g.streams[0]); - - cudaMalloc(&g.d_grp_offsets, (n_groups + 1) * sizeof(int)); - cudaMemcpyAsync(g.d_grp_offsets, h_grp_offsets, - (n_groups + 1) * sizeof(int), cudaMemcpyHostToDevice, - g.streams[0]); - - for (int s = 0; s < STREAMS_PER_GPU; s++) { - cudaMalloc(&g.buf[s].d_grp_data, sub_grp_items * sizeof(float)); - 
cudaMalloc(&g.buf[s].d_grp_sorted, sub_grp_items * sizeof(float)); - cudaMalloc(&g.buf[s].d_seg_offsets, (max_n_seg + 1) * sizeof(int)); - cudaMalloc(&g.buf[s].d_cub_temp, cub_temp_bytes); - cudaMalloc(&g.buf[s].d_rank_sums, - (size_t)n_groups * sub_batch_cols * sizeof(double)); - cudaMalloc(&g.buf[s].d_tie_corr, - (size_t)n_groups * sub_batch_cols * sizeof(double)); - } - } - - // Phase 2: process sub-batches - int col = 0; - int batch_idx = 0; - while (col < n_cols) { - int sb_cols = std::min(sub_batch_cols, n_cols - col); - int sb_n_seg = n_groups * sb_cols; - int sb_grp_items = n_all_grp * sb_cols; - - int d = (batch_idx / STREAMS_PER_GPU) % n_devices; - int s = batch_idx % STREAMS_PER_GPU; - auto& g = gpus[d]; - auto stream = g.streams[s]; - auto& buf = g.buf[s]; - - cudaSetDevice(g.device_id); - - // H2D: group data sub-batch - cudaMemcpyAsync(buf.d_grp_data, h_grp_data + (long long)col * n_all_grp, - sb_grp_items * sizeof(float), cudaMemcpyHostToDevice, - stream); - - // Build segment offsets on device - { - int total = sb_n_seg + 1; - int blk = (total + 255) / 256; - build_seg_offsets_kernel<<>>( - g.d_grp_offsets, buf.d_seg_offsets, n_all_grp, n_groups, - sb_cols); - } - - // Sort group data - size_t temp = cub_temp_bytes; - cub::DeviceSegmentedRadixSort::SortKeys( - buf.d_cub_temp, temp, buf.d_grp_data, buf.d_grp_sorted, - sb_grp_items, sb_n_seg, buf.d_seg_offsets, buf.d_seg_offsets + 1, 0, - 32, stream); - - // Rank sums - const float* ref_sub = g.d_ref_sorted + (long long)col * n_ref; - dim3 grid(sb_cols, n_groups); - batched_rank_sums_presorted_kernel<<>>( - ref_sub, buf.d_grp_sorted, g.d_grp_offsets, buf.d_rank_sums, - buf.d_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, - compute_tie_corr); - - // D2H: scatter results - cudaMemcpy2DAsync(h_rank_sums + col, n_cols * sizeof(double), - buf.d_rank_sums, sb_cols * sizeof(double), - sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToHost, stream); - if (compute_tie_corr) { - 
cudaMemcpy2DAsync(h_tie_corr + col, n_cols * sizeof(double), - buf.d_tie_corr, sb_cols * sizeof(double), - sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToHost, stream); - } - - col += sb_cols; - batch_idx++; - } - - // Phase 3: sync + cleanup - for (int d = 0; d < n_devices; d++) { - cudaSetDevice(gpus[d].device_id); - for (int s = 0; s < STREAMS_PER_GPU; s++) - cudaStreamSynchronize(gpus[d].streams[s]); - } - for (int d = 0; d < n_devices; d++) { - cudaSetDevice(gpus[d].device_id); - cudaFree(gpus[d].d_ref_sorted); - cudaFree(gpus[d].d_grp_offsets); - for (int s = 0; s < STREAMS_PER_GPU; s++) { - cudaFree(gpus[d].buf[s].d_grp_data); - cudaFree(gpus[d].buf[s].d_grp_sorted); - cudaFree(gpus[d].buf[s].d_seg_offsets); - cudaFree(gpus[d].buf[s].d_cub_temp); - cudaFree(gpus[d].buf[s].d_rank_sums); - cudaFree(gpus[d].buf[s].d_tie_corr); - cudaStreamDestroy(gpus[d].streams[s]); - } - } - cudaSetDevice(h_device_ids[0]); + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); } // ============================================================================ @@ -1129,7 +699,9 @@ static void ovo_streaming_multigpu_impl( template void register_bindings(nb::module_& m) { - m.doc() = "Streaming Wilcoxon pipeline with multi-stream overlap"; + m.doc() = "CUDA kernels for Wilcoxon rank-sum test (OVR)"; + + // ---- Streaming pipelines ---- m.def( "ovr_streaming", @@ -1184,28 +756,9 @@ void register_bindings(nb::module_& m) { "rank_sums"_a, "tie_corr"_a, nb::kw_only(), "n_rows"_a, "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, "sub_batch_cols"_a = SUB_BATCH_COLS); - - m.def( - "ovo_streaming", - [](gpu_array_f ref_sorted, - gpu_array_f grp_data, - gpu_array_c grp_offsets, - gpu_array_c rank_sums, - gpu_array_c tie_corr, int n_ref, int n_all_grp, - int n_cols, int n_groups, bool compute_tie_corr, - int sub_batch_cols) { - ovo_streaming_impl(ref_sorted.data(), grp_data.data(), - grp_offsets.data(), rank_sums.data(), - tie_corr.data(), n_ref, n_all_grp, n_cols, - 
n_groups, compute_tie_corr, sub_batch_cols); - }, - "ref_sorted"_a, "grp_data"_a, "grp_offsets"_a, "rank_sums"_a, - "tie_corr"_a, nb::kw_only(), "n_ref"_a, "n_all_grp"_a, "n_cols"_a, - "n_groups"_a, "compute_tie_corr"_a, - "sub_batch_cols"_a = SUB_BATCH_COLS); } -NB_MODULE(_wilcoxon_streaming_cuda, m) { +NB_MODULE(_wilcoxon_ovr_cuda, m) { REGISTER_GPU_BINDINGS(register_bindings, m); m.def( @@ -1240,44 +793,4 @@ NB_MODULE(_wilcoxon_streaming_cuda, m) { "h_block"_a, "h_group_codes"_a, "h_rank_sums"_a, "h_tie_corr"_a, nb::kw_only(), "n_rows"_a, "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, "sub_batch_cols"_a = SUB_BATCH_COLS); - - m.def( - "ovo_streaming_multigpu", - [](host_array_2d h_ref_sorted, - host_array_2d h_grp_data, - host_array h_grp_offsets, - host_array_2d h_rank_sums, host_array_2d h_tie_corr, - int n_ref, int n_all_grp, int n_cols, int n_groups, - bool compute_tie_corr, int sub_batch_cols, - host_array device_ids) { - // Pin host arrays - size_t ref_bytes = (size_t)n_ref * n_cols * sizeof(float); - size_t grp_bytes = (size_t)n_all_grp * n_cols * sizeof(float); - size_t rs_bytes = (size_t)n_groups * n_cols * sizeof(double); - cudaHostRegister(const_cast(h_ref_sorted.data()), ref_bytes, - 0); - cudaHostRegister(const_cast(h_grp_data.data()), grp_bytes, - 0); - cudaHostRegister(const_cast(h_grp_offsets.data()), - (n_groups + 1) * sizeof(int), 0); - cudaHostRegister(h_rank_sums.data(), rs_bytes, 0); - cudaHostRegister(h_tie_corr.data(), - (size_t)n_groups * n_cols * sizeof(double), 0); - - ovo_streaming_multigpu_impl( - h_ref_sorted.data(), h_grp_data.data(), h_grp_offsets.data(), - h_rank_sums.data(), h_tie_corr.data(), n_ref, n_all_grp, n_cols, - n_groups, compute_tie_corr, sub_batch_cols, device_ids.data(), - static_cast(device_ids.size())); - - cudaHostUnregister(const_cast(h_ref_sorted.data())); - cudaHostUnregister(const_cast(h_grp_data.data())); - cudaHostUnregister(const_cast(h_grp_offsets.data())); - cudaHostUnregister(h_rank_sums.data()); - 
cudaHostUnregister(h_tie_corr.data()); - }, - "h_ref_sorted"_a, "h_grp_data"_a, "h_grp_offsets"_a, "h_rank_sums"_a, - "h_tie_corr"_a, nb::kw_only(), "n_ref"_a, "n_all_grp"_a, "n_cols"_a, - "n_groups"_a, "compute_tie_corr"_a, "sub_batch_cols"_a = SUB_BATCH_COLS, - "device_ids"_a); } diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py index 5f46f110..5c36585e 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py @@ -9,7 +9,7 @@ import numpy as np import scipy.sparse as sp -from rapids_singlecell._cuda import _wilcoxon_cuda as _wc +from rapids_singlecell._cuda import _wilcoxon_ovo_cuda as _wc if TYPE_CHECKING: from numpy.typing import NDArray @@ -33,17 +33,27 @@ def _to_gpu_native(X, n_rows: int, n_cols: int): if cpsp.issparse(X): return X - # Host sparse → GPU sparse, same format + # Host sparse → GPU sparse, same format. + # Downcast indices to int32 on host before transfer (column indices + # always fit in int32; scipy may use int64 when nnz > 2^31). 
if isinstance(X, sp.spmatrix | sp.sparray): if sp.issparse(X) and X.format == "csc": csc = X if X.format == "csc" else X.tocsc() return cpsp.csc_matrix( - (cp.asarray(csc.data), cp.asarray(csc.indices), cp.asarray(csc.indptr)), + ( + cp.asarray(csc.data), + cp.asarray(csc.indices.astype(np.int32, copy=False)), + cp.asarray(csc.indptr), + ), shape=(n_rows, n_cols), ) csr = X.tocsr() if X.format != "csr" else X return cpsp.csr_matrix( - (cp.asarray(csr.data), cp.asarray(csr.indices), cp.asarray(csr.indptr)), + ( + cp.asarray(csr.data), + cp.asarray(csr.indices.astype(np.int32, copy=False)), + cp.asarray(csr.indptr), + ), shape=(n_rows, n_cols), ) @@ -62,10 +72,10 @@ def _extract_dense_block( *, csr_arrays: tuple[cp.ndarray, cp.ndarray, cp.ndarray] | None = None, ) -> cp.ndarray: - """Extract ``X[row_ids, start:stop]`` as dense F-order on GPU (native dtype). + """Extract ``X[row_ids, start:stop]`` as dense F-order on GPU. - The CSR kernel path outputs float64 (kernel writes double*). - All other paths preserve the input dtype. + CSR kernel path: outputs same dtype as CSR data (float32 or float64). + Other paths: preserve input dtype. 
""" if csr_arrays is not None: data, indices, indptr = csr_arrays @@ -74,19 +84,33 @@ def _extract_dense_block( row_ids = cp.arange(n_target, dtype=cp.int32) n_target = row_ids.shape[0] n_cols = stop - start - out = cp.zeros((n_target, n_cols), dtype=cp.float64, order="F") + out = cp.zeros((n_target, n_cols), dtype=data.dtype, order="F") if n_target > 0 and n_cols > 0: - _wc.csr_extract_dense( - data, - indices, - indptr, - row_ids, - out, - n_target=n_target, - col_start=start, - col_stop=stop, - stream=cp.cuda.get_current_stream().ptr, - ) + stream = cp.cuda.get_current_stream().ptr + if data.dtype == cp.float32: + _wc.csr_extract_dense_f32( + data, + indices, + indptr, + row_ids, + out, + n_target=n_target, + col_start=start, + col_stop=stop, + stream=stream, + ) + else: + _wc.csr_extract_dense( + data, + indices, + indptr, + row_ids, + out, + n_target=n_target, + col_start=start, + col_stop=stop, + stream=stream, + ) return out if isinstance(X, np.ndarray): @@ -175,8 +199,11 @@ def wilcoxon( n_cells, n_total_genes = X.shape group_sizes = rg.group_sizes + # Stats via Aggregate for both OVR and OVO — decoupled from sort. + # Aggregate reads original dtype for precision, accumulates in float64. + rg._basic_stats() + if rg.ireference is not None: - rg._init_stats_arrays(n_total_genes) return _wilcoxon_with_reference( rg, X, @@ -186,7 +213,6 @@ def wilcoxon( use_continuity=use_continuity, chunk_size=chunk_size, ) - rg._basic_stats() return _wilcoxon_vs_rest( rg, X, @@ -220,7 +246,7 @@ def _wilcoxon_vs_rest( Dispatches to CSR, CSC, or dense streaming kernel based on input format. No unnecessary format conversions. 
""" - from rapids_singlecell._cuda import _wilcoxon_streaming_cuda as _ws + from rapids_singlecell._cuda import _wilcoxon_ovr_cuda as _ovr n_groups = len(rg.groups_order) @@ -259,7 +285,7 @@ def _wilcoxon_vs_rest( tie_corr_np = np.ones(n_total_genes, dtype=np.float64) if host_csc: - _ws.ovr_streaming_csc_host( + _ovr.ovr_streaming_csc_host( X.data.astype(np.float32, copy=False), X.indices.astype(np.int32, copy=False), X.indptr.astype(np.int32, copy=False), @@ -273,7 +299,7 @@ def _wilcoxon_vs_rest( sub_batch_cols=STREAMING_SUB_BATCH, ) else: - _ws.ovr_streaming_dense_host( + _ovr.ovr_streaming_dense_host( np.asfortranarray(X.astype(np.float32, copy=False)), group_codes, rank_sums_np, @@ -301,7 +327,7 @@ def _wilcoxon_vs_rest( tie_corr = cp.ones(n_total_genes, dtype=cp.float64) if cpsp.isspmatrix_csr(X_gpu): - _ws.ovr_streaming_csr( + _ovr.ovr_streaming_csr( X_gpu.data.astype(cp.float32, copy=False), X_gpu.indices.astype(cp.int32, copy=False), X_gpu.indptr.astype(cp.int32, copy=False), @@ -315,7 +341,7 @@ def _wilcoxon_vs_rest( sub_batch_cols=STREAMING_SUB_BATCH, ) elif cpsp.isspmatrix_csc(X_gpu): - _ws.ovr_streaming_csc( + _ovr.ovr_streaming_csc( X_gpu.data.astype(cp.float32, copy=False), X_gpu.indices.astype(cp.int32, copy=False), X_gpu.indptr.astype(cp.int32, copy=False), @@ -330,7 +356,7 @@ def _wilcoxon_vs_rest( ) else: dense_f32 = cp.asfortranarray(X_gpu.astype(cp.float32, copy=False)) - _ws.ovr_streaming( + _ovr.ovr_streaming( dense_f32, group_codes_gpu, rank_sums, @@ -380,7 +406,6 @@ def _wilcoxon_with_reference( All test groups are processed in a single batched streaming kernel, eliminating per-group kernel launch overhead. 
""" - from rapids_singlecell._cuda import _wilcoxon_streaming_cuda as _ws n_cells = X.shape[0] n_groups = len(rg.groups_order) @@ -389,8 +414,6 @@ def _wilcoxon_with_reference( codes = rg.group_codes # ---- build row-index arrays ---- - ref_row_ids = cp.asarray(np.where(codes == ireference)[0], dtype=cp.int32) - test_group_indices: list[int] = [] all_grp_rows: list[np.ndarray] = [] offsets = [0] @@ -402,22 +425,11 @@ def _wilcoxon_with_reference( all_grp_rows.append(rows) offsets.append(offsets[-1] + len(rows)) - all_grp_row_ids = cp.asarray(np.concatenate(all_grp_rows), dtype=cp.int32) + all_grp_row_ids_np = np.concatenate(all_grp_rows) grp_offsets_gpu = cp.asarray(offsets, dtype=cp.int32) n_test = len(test_group_indices) - - # ---- move data to GPU ---- - X_gpu = _to_gpu_native(X, n_cells, n_total_genes) - - # For row extraction, CSR kernel is optimal. Dense uses cupy indexing. - csr_arrays = None - if cpsp.issparse(X_gpu): - csr_gpu = X_gpu.tocsr() if not cpsp.isspmatrix_csr(X_gpu) else X_gpu - csr_arrays = ( - csr_gpu.data.astype(cp.float64, copy=False), - csr_gpu.indices.astype(cp.int32, copy=False), - csr_gpu.indptr.astype(cp.int32, copy=False), - ) + n_all_grp = len(all_grp_row_ids_np) + ref_row_ids_np = np.where(codes == ireference)[0] # ---- warn for small groups ---- for gi in test_group_indices: @@ -435,52 +447,164 @@ def _wilcoxon_with_reference( [group_sizes[gi] for gi in test_group_indices], dtype=cp.float64 ) - # ---- extract ref + grp blocks (one-shot, all genes) ---- - ref_block = _extract_dense_block( - X_gpu, ref_row_ids, 0, n_total_genes, csr_arrays=csr_arrays - ) - grp_block = _extract_dense_block( - X_gpu, all_grp_row_ids, 0, n_total_genes, csr_arrays=csr_arrays - ) - n_all_grp = grp_block.shape[0] + # ---- build row maps (numpy, for both host and GPU CSC paths) ---- + ref_row_map_np = np.full(n_cells, -1, dtype=np.int32) + ref_row_map_np[ref_row_ids_np] = np.arange(n_ref, dtype=np.int32) + grp_row_map_np = np.full(n_cells, -1, dtype=np.int32) 
+ grp_row_map_np[all_grp_row_ids_np] = np.arange(n_all_grp, dtype=np.int32) + offsets_np = np.asarray(offsets, dtype=np.int32) - # ---- stats via fused kernel ---- - _compute_grouped_stats( - rg, - ireference, - ref_block, - n_ref, - test_group_indices=test_group_indices, - grp_block=grp_block, - grp_offsets_gpu=grp_offsets_gpu, - n_test=n_test, - n_cols=n_total_genes, - start=0, - stop=n_total_genes, - ) + # ---- host-streaming paths: skip bulk transfer ---- + host_sparse = isinstance(X, sp.spmatrix | sp.sparray) + host_dense = isinstance(X, np.ndarray) + if host_sparse or host_dense: + if rg._compute_stats_in_chunks: + X_gpu_tmp = _to_gpu_native(X, n_cells, n_total_genes) + rg.X = X_gpu_tmp + rg._compute_stats_in_chunks = False + rg._basic_stats() + del X_gpu_tmp - # ---- sort reference once ---- - ref_sorted = _segmented_sort_columns( - ref_block, np.array([0, n_ref], dtype=np.int32), n_ref, n_total_genes, 1 - ) + rank_sums_np = np.empty((n_test, n_total_genes), dtype=np.float64) + tie_corr_np = np.ones((n_test, n_total_genes), dtype=np.float64) - # ---- streaming OVO: sort groups + binary search rank sums ---- - grp_f32 = cp.asfortranarray(grp_block.astype(cp.float32)) - rank_sums = cp.empty((n_test, n_total_genes), dtype=cp.float64) - tie_corr_arr = cp.empty((n_test, n_total_genes), dtype=cp.float64) - - _ws.ovo_streaming( - ref_sorted, - grp_f32, - grp_offsets_gpu, - rank_sums, - tie_corr_arr, - n_ref=n_ref, - n_all_grp=n_all_grp, - n_cols=n_total_genes, - n_groups=n_test, - compute_tie_corr=tie_correct, - ) + if host_sparse and X.format == "csc": + _wc.ovo_streaming_csc_host( + X.data.astype(np.float32, copy=False), + X.indices.astype(np.int32, copy=False), + X.indptr.astype(np.int32, copy=False), + ref_row_map_np, + grp_row_map_np, + offsets_np, + rank_sums_np, + tie_corr_np, + n_ref=n_ref, + n_all_grp=n_all_grp, + n_rows=n_cells, + n_cols=n_total_genes, + n_groups=n_test, + compute_tie_corr=tie_correct, + sub_batch_cols=STREAMING_SUB_BATCH, + ) + elif 
host_sparse: + csr = X.tocsr() if X.format != "csr" else X + _wc.ovo_streaming_csr_host( + csr.data.astype(np.float32, copy=False), + csr.indices.astype(np.int32, copy=False), + csr.indptr.astype(np.int32, copy=False), + ref_row_ids_np.astype(np.int32, copy=False), + all_grp_row_ids_np.astype(np.int32, copy=False), + offsets_np, + rank_sums_np, + tie_corr_np, + n_ref=n_ref, + n_all_grp=n_all_grp, + n_rows=n_cells, + n_cols=n_total_genes, + n_groups=n_test, + nnz=csr.nnz, + compute_tie_corr=tie_correct, + sub_batch_cols=STREAMING_SUB_BATCH, + ) + else: + _wc.ovo_streaming_dense_host( + np.asfortranarray(X.astype(np.float32, copy=False)), + ref_row_ids_np.astype(np.int32, copy=False), + all_grp_row_ids_np.astype(np.int32, copy=False), + offsets_np, + rank_sums_np, + tie_corr_np, + n_ref=n_ref, + n_all_grp=n_all_grp, + n_rows=n_cells, + n_cols=n_total_genes, + n_groups=n_test, + compute_tie_corr=tie_correct, + sub_batch_cols=STREAMING_SUB_BATCH, + ) + + rank_sums = cp.asarray(rank_sums_np) + tie_corr_arr = cp.asarray(tie_corr_np) + + else: + # ---- GPU path: transfer once, then dispatch ---- + X_gpu = _to_gpu_native(X, n_cells, n_total_genes) + + if rg._compute_stats_in_chunks: + rg.X = X_gpu + rg._compute_stats_in_chunks = False + rg._basic_stats() + + ref_row_ids_gpu = cp.asarray(ref_row_ids_np, dtype=cp.int32) + all_grp_row_ids_gpu = cp.asarray(all_grp_row_ids_np, dtype=cp.int32) + + rank_sums = cp.empty((n_test, n_total_genes), dtype=cp.float64) + tie_corr_arr = cp.empty((n_test, n_total_genes), dtype=cp.float64) + + if cpsp.isspmatrix_csc(X_gpu): + ref_row_map = cp.asarray(ref_row_map_np) + grp_row_map = cp.asarray(grp_row_map_np) + _wc.ovo_streaming_csc( + X_gpu.data.astype(cp.float32, copy=False), + X_gpu.indices.astype(cp.int32, copy=False), + X_gpu.indptr.astype(cp.int32, copy=False), + ref_row_map, + grp_row_map, + grp_offsets_gpu, + rank_sums, + tie_corr_arr, + n_ref=n_ref, + n_all_grp=n_all_grp, + n_cols=n_total_genes, + n_groups=n_test, + 
compute_tie_corr=tie_correct, + sub_batch_cols=STREAMING_SUB_BATCH, + ) + elif cpsp.issparse(X_gpu): + # CSR-native: extract ref/grp rows directly + csr_gpu = X_gpu.tocsr() if not cpsp.isspmatrix_csr(X_gpu) else X_gpu + _wc.ovo_streaming_csr( + csr_gpu.data.astype(cp.float32, copy=False), + csr_gpu.indices.astype(cp.int32, copy=False), + csr_gpu.indptr.astype(cp.int32, copy=False), + ref_row_ids_gpu, + all_grp_row_ids_gpu, + grp_offsets_gpu, + rank_sums, + tie_corr_arr, + n_ref=n_ref, + n_all_grp=n_all_grp, + n_cols=n_total_genes, + n_groups=n_test, + compute_tie_corr=tie_correct, + sub_batch_cols=STREAMING_SUB_BATCH, + ) + else: + # Dense: extract blocks, sort, stream + ref_block = _extract_dense_block(X_gpu, ref_row_ids_gpu, 0, n_total_genes) + grp_block = _extract_dense_block( + X_gpu, all_grp_row_ids_gpu, 0, n_total_genes + ) + ref_sorted = _segmented_sort_columns( + ref_block, + np.array([0, n_ref], dtype=np.int32), + n_ref, + n_total_genes, + 1, + ) + grp_f32 = cp.asfortranarray(grp_block.astype(cp.float32, copy=False)) + _wc.ovo_streaming( + ref_sorted, + grp_f32, + grp_offsets_gpu, + rank_sums, + tie_corr_arr, + n_ref=n_ref, + n_all_grp=n_all_grp, + n_cols=n_total_genes, + n_groups=n_test, + compute_tie_corr=tie_correct, + ) # ---- z-scores & p-values (vectorised) ---- n_combined = test_sizes + n_ref @@ -502,83 +626,3 @@ def _wilcoxon_with_reference( all_p = p_values.get() return [(gi, all_z[ti], all_p[ti]) for ti, gi in enumerate(test_group_indices)] - - -def _compute_grouped_stats( - rg: _RankGenes, - ireference: int, - ref_block: cp.ndarray, - n_ref: int, - *, - test_group_indices: list[int], - grp_block: cp.ndarray, - grp_offsets_gpu: cp.ndarray, - n_test: int, - n_cols: int, - start: int, - stop: int, -) -> None: - """Compute mean/var/pts for ref + all test groups via fused C++ kernel.""" - s = slice(start, stop) - stream = cp.cuda.get_current_stream().ptr - - # Reference stats (single "group") - ref_offsets = cp.array([0, n_ref], dtype=cp.int32) - 
ref_sums = cp.empty((1, n_cols), dtype=cp.float64) - ref_sq = cp.empty((1, n_cols), dtype=cp.float64) - ref_nnz = cp.empty((1, n_cols), dtype=cp.float64) - _wc.grouped_stats( - ref_block, - ref_offsets, - ref_sums, - ref_sq, - ref_nnz, - n_all_rows=n_ref, - n_cols=n_cols, - n_groups=1, - compute_nnz=rg.comp_pts, - stream=stream, - ) - - rg.means[ireference, s] = cp.asnumpy(ref_sums[0] / n_ref) - if n_ref > 1: - var = (ref_sq[0] - ref_sums[0] ** 2 / n_ref) / (n_ref - 1) - rg.vars[ireference, s] = cp.asnumpy(cp.maximum(var, 0)) - if rg.comp_pts: - rg.pts[ireference, s] = cp.asnumpy(ref_nnz[0] / n_ref) - - # All test groups in one kernel launch - n_all_grp = grp_block.shape[0] - grp_sums = cp.empty((n_test, n_cols), dtype=cp.float64) - grp_sq = cp.empty((n_test, n_cols), dtype=cp.float64) - grp_nnz = cp.empty((n_test, n_cols), dtype=cp.float64) - _wc.grouped_stats( - grp_block, - grp_offsets_gpu, - grp_sums, - grp_sq, - grp_nnz, - n_all_rows=n_all_grp, - n_cols=n_cols, - n_groups=n_test, - compute_nnz=rg.comp_pts, - stream=stream, - ) - - # Vectorised mean/var computation on GPU, single D2H transfer - sizes = cp.asarray( - [rg.group_sizes[gi] for gi in test_group_indices], dtype=cp.float64 - )[:, None] - means = grp_sums / sizes - vars_ = cp.maximum((grp_sq - grp_sums**2 / sizes) / cp.maximum(sizes - 1, 1), 0) - - means_host = cp.asnumpy(means) - vars_host = cp.asnumpy(vars_) - for ti, gi in enumerate(test_group_indices): - rg.means[gi, s] = means_host[ti] - rg.vars[gi, s] = vars_host[ti] - - if rg.comp_pts: - pts_host = cp.asnumpy(grp_nnz / sizes) - for ti, gi in enumerate(test_group_indices): - rg.pts[gi, s] = pts_host[ti] From 0335c0b538cbda1c29b13d63430a7cf9bb9a4663 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Mon, 13 Apr 2026 19:33:29 +0200 Subject: [PATCH 04/21] fix publish --- .github/workflows/publish.yml | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 
d7ed8fd9..3ad607c1 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -74,11 +74,16 @@ jobs: 'name = "rapids-singlecell"', f'name = "rapids-singlecell-cu{cuda}"', ) - # Rename matching extra to "rapids", remove the other + # Rename matching extra to "rapids", remove the other CUDA extra text = text.replace(f'rapids-cu{cuda} =', 'rapids =') - # Remove the other CUDA extra line entirely - lines = text.splitlines(keepends=True) - text = "".join(l for l in lines if f'rapids-cu{other}' not in l) + # Remove the other CUDA extra (handles multi-line TOML arrays) + import re + text = re.sub( + rf'^rapids-cu{other}\s*=\s*\[.*?\]\s*\n', + '', + text, + flags=re.MULTILINE | re.DOTALL, + ) # Set CUDA architectures (replace "native" with CI target archs) text = text.replace( From 3fb11d98ef1e80250ec581f8a68604dfbdb4cb10 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Mon, 13 Apr 2026 19:37:32 +0200 Subject: [PATCH 05/21] try fix build wheel --- .github/workflows/publish.yml | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 3ad607c1..a9f2e84b 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -117,12 +117,15 @@ jobs: CIBW_ENVIRONMENT_PASS_LINUX: SETUPTOOLS_SCM_PRETEND_VERSION CIBW_ENVIRONMENT: > CUDA_PATH=/usr/local/cuda - LD_LIBRARY_PATH=/usr/local/cuda/lib64:$(python3 -c "import sysconfig,os;sp=sysconfig.get_path('purelib');print(os.path.join(sp,'librmm','lib64')+':'+os.path.join(sp,'rapids_logger','lib64'))" 2>/dev/null || echo ""):$LD_LIBRARY_PATH PATH=/usr/local/cuda/bin:$PATH CIBW_BEFORE_BUILD: > python -m pip install -U pip scikit-build-core cmake ninja nanobind - librmm-cu${{ matrix.cuda_major }} + librmm-cu${{ matrix.cuda_major }} && + SITE=$(python -c "import sysconfig;print(sysconfig.get_path('purelib'))") && + ln -sf "$SITE/librmm/lib64/librmm.so" /usr/local/lib/librmm.so && + ln -sf 
"$SITE/rapids_logger/lib64/librapids_logger.so" /usr/local/lib/librapids_logger.so && + ldconfig CIBW_TEST_SKIP: "*" CIBW_TEST_COMMAND: "" CIBW_REPAIR_WHEEL_COMMAND: "auditwheel repair --exclude libcublas.so.${{ matrix.cuda_major }} --exclude libcublasLt.so.${{ matrix.cuda_major }} --exclude libcudart.so.${{ matrix.cuda_major }} --exclude librmm.so --exclude librapids_logger.so -w {dest_dir} {wheel}" From abddd30fdc5be7ca182c5da72ccefae4f4c0d900 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Mon, 13 Apr 2026 19:47:47 +0200 Subject: [PATCH 06/21] fix ld path --- .github/workflows/publish.yml | 3 ++- CMakeLists.txt | 6 ++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index a9f2e84b..f0d660a4 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -125,7 +125,8 @@ jobs: SITE=$(python -c "import sysconfig;print(sysconfig.get_path('purelib'))") && ln -sf "$SITE/librmm/lib64/librmm.so" /usr/local/lib/librmm.so && ln -sf "$SITE/rapids_logger/lib64/librapids_logger.so" /usr/local/lib/librapids_logger.so && - ldconfig + ldconfig && + python -c "import librmm;print(librmm.__path__[0])" > /tmp/.librmm_dir CIBW_TEST_SKIP: "*" CIBW_TEST_COMMAND: "" CIBW_REPAIR_WHEEL_COMMAND: "auditwheel repair --exclude libcublas.so.${{ matrix.cuda_major }} --exclude libcublasLt.so.${{ matrix.cuda_major }} --exclude libcudart.so.${{ matrix.cuda_major }} --exclude librmm.so --exclude librapids_logger.so -w {dest_dir} {wheel}" diff --git a/CMakeLists.txt b/CMakeLists.txt index 3bf8fe14..d41ea611 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -34,6 +34,12 @@ if (RSC_BUILD_EXTENSIONS) if(_py_prefix) list(APPEND _env_roots "${_py_prefix}") endif() + # CI/cibuildwheel: CIBW_BEFORE_BUILD writes the librmm path to a marker file + if(EXISTS "/tmp/.librmm_dir") + file(READ "/tmp/.librmm_dir" _rmm_marker) + string(STRIP "${_rmm_marker}" _rmm_marker) + list(APPEND _env_roots "${_rmm_marker}/..") + 
endif() foreach(_root ${_env_roots}) file(GLOB _hints "${_root}/lib/cmake/rmm" "${_root}/lib/python*/site-packages/librmm/lib*/cmake/rmm" From 8a59ee16dc01d9f977df65a3b2278c9d356557e2 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Mon, 13 Apr 2026 20:02:39 +0200 Subject: [PATCH 07/21] update files --- CMakeLists.txt | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index d41ea611..3b62a6c5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -49,10 +49,7 @@ if (RSC_BUILD_EXTENSIONS) list(APPEND CMAKE_PREFIX_PATH "${_dir}") endforeach() endforeach() - find_package(rmm CONFIG) - if(NOT rmm_FOUND) - message(WARNING "librmm not found — wilcoxon will use cudaMalloc fallback") - endif() + find_package(rmm CONFIG REQUIRED) message(STATUS "Building for CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}") else() message(STATUS "RSC_BUILD_EXTENSIONS=OFF -> skipping compiled extensions for docs") From a7f8852be5f9cb530df4eac659cd621debdd557b Mon Sep 17 00:00:00 2001 From: Intron7 Date: Mon, 13 Apr 2026 20:28:58 +0200 Subject: [PATCH 08/21] add debugging --- .github/workflows/publish.yml | 6 +++++- CMakeLists.txt | 7 ++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index f0d660a4..676ec118 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -123,10 +123,14 @@ jobs: scikit-build-core cmake ninja nanobind librmm-cu${{ matrix.cuda_major }} && SITE=$(python -c "import sysconfig;print(sysconfig.get_path('purelib'))") && + echo "[rsc-build] site-packages=$SITE" && + echo "[rsc-build] librmm=$(ls $SITE/librmm/lib64/*.so 2>/dev/null)" && + echo "[rsc-build] rapids_logger=$(ls $SITE/rapids_logger/lib64/*.so 2>/dev/null)" && ln -sf "$SITE/librmm/lib64/librmm.so" /usr/local/lib/librmm.so && ln -sf "$SITE/rapids_logger/lib64/librapids_logger.so" /usr/local/lib/librapids_logger.so && ldconfig && - python -c "import 
librmm;print(librmm.__path__[0])" > /tmp/.librmm_dir + python -c "import librmm;print(librmm.__path__[0])" > /tmp/.librmm_dir && + echo "[rsc-build] marker=$(cat /tmp/.librmm_dir)" CIBW_TEST_SKIP: "*" CIBW_TEST_COMMAND: "" CIBW_REPAIR_WHEEL_COMMAND: "auditwheel repair --exclude libcublas.so.${{ matrix.cuda_major }} --exclude libcublasLt.so.${{ matrix.cuda_major }} --exclude libcudart.so.${{ matrix.cuda_major }} --exclude librmm.so --exclude librapids_logger.so -w {dest_dir} {wheel}" diff --git a/CMakeLists.txt b/CMakeLists.txt index 3b62a6c5..4eb48600 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -38,7 +38,10 @@ if (RSC_BUILD_EXTENSIONS) if(EXISTS "/tmp/.librmm_dir") file(READ "/tmp/.librmm_dir" _rmm_marker) string(STRIP "${_rmm_marker}" _rmm_marker) - list(APPEND _env_roots "${_rmm_marker}/..") + # Marker contains e.g. /opt/.../site-packages/librmm — find cmake dir + deps + file(GLOB _marker_hints "${_rmm_marker}/lib*/cmake" + "${_rmm_marker}/../rapids_logger/lib*/cmake") + list(APPEND CMAKE_PREFIX_PATH ${_marker_hints}) endif() foreach(_root ${_env_roots}) file(GLOB _hints "${_root}/lib/cmake/rmm" @@ -49,6 +52,8 @@ if (RSC_BUILD_EXTENSIONS) list(APPEND CMAKE_PREFIX_PATH "${_dir}") endforeach() endforeach() + message(STATUS "rmm search roots: ${_env_roots}") + message(STATUS "rmm CMAKE_PREFIX_PATH: ${CMAKE_PREFIX_PATH}") find_package(rmm CONFIG REQUIRED) message(STATUS "Building for CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}") else() From 1b926c0fbecf9c49193432cc965e2905f7bbb41c Mon Sep 17 00:00:00 2001 From: Intron7 Date: Tue, 14 Apr 2026 10:48:02 +0200 Subject: [PATCH 09/21] fix rmm discovery --- CMakeLists.txt | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 4eb48600..f2af10ab 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -16,33 +16,40 @@ if (RSC_BUILD_EXTENSIONS) find_package(CUDAToolkit REQUIRED) # Find librmm cmake config. 
- # Works with conda, pixi, uv, venv — uses env root to find site-packages. - # Priority: LIBRMM_DIR env var > CONDA_PREFIX > VIRTUAL_ENV > Python prefix. + # Searches all plausible locations: env vars, Python prefix, base_prefix + # (survives build isolation), and CI marker files. Works with conda, pip, + # pixi, uv, hatch, and cibuildwheel. set(_env_roots "") + # Explicit override if(DEFINED ENV{LIBRMM_DIR}) list(APPEND _env_roots "$ENV{LIBRMM_DIR}/..") endif() + # Environment managers foreach(_var CONDA_PREFIX VIRTUAL_ENV PIXI_PROJECT_ROOT) if(DEFINED ENV{${_var}}) list(APPEND _env_roots "$ENV{${_var}}") endif() endforeach() - # Fallback: Python prefix (works for any env manager) - execute_process( - COMMAND "${Python_EXECUTABLE}" -c "import sys; print(sys.prefix)" - OUTPUT_VARIABLE _py_prefix OUTPUT_STRIP_TRAILING_WHITESPACE ERROR_QUIET) - if(_py_prefix) - list(APPEND _env_roots "${_py_prefix}") - endif() + # Python prefix + base_prefix (base_prefix survives build isolation — + # it points to the outer env even when pip/uv creates a temp venv) + foreach(_attr prefix base_prefix) + execute_process( + COMMAND "${Python_EXECUTABLE}" -c "import sys; print(sys.${_attr})" + OUTPUT_VARIABLE _pp OUTPUT_STRIP_TRAILING_WHITESPACE ERROR_QUIET) + if(_pp) + list(APPEND _env_roots "${_pp}") + endif() + endforeach() # CI/cibuildwheel: CIBW_BEFORE_BUILD writes the librmm path to a marker file if(EXISTS "/tmp/.librmm_dir") file(READ "/tmp/.librmm_dir" _rmm_marker) string(STRIP "${_rmm_marker}" _rmm_marker) - # Marker contains e.g. 
/opt/.../site-packages/librmm — find cmake dir + deps file(GLOB _marker_hints "${_rmm_marker}/lib*/cmake" "${_rmm_marker}/../rapids_logger/lib*/cmake") list(APPEND CMAKE_PREFIX_PATH ${_marker_hints}) endif() + # Search each root for rmm + rapids_logger cmake configs + list(REMOVE_DUPLICATES _env_roots) foreach(_root ${_env_roots}) file(GLOB _hints "${_root}/lib/cmake/rmm" "${_root}/lib/python*/site-packages/librmm/lib*/cmake/rmm" From 327ffd85fc4b4801896e13c811b2f9bd62b93a51 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Tue, 14 Apr 2026 11:12:17 +0200 Subject: [PATCH 10/21] try rebuild --- CMakeLists.txt | 27 +++++++++++++++++++++++++-- pyproject.toml | 7 +++++++ 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index f2af10ab..6b2d9da1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -30,8 +30,7 @@ if (RSC_BUILD_EXTENSIONS) list(APPEND _env_roots "$ENV{${_var}}") endif() endforeach() - # Python prefix + base_prefix (base_prefix survives build isolation — - # it points to the outer env even when pip/uv creates a temp venv) + # Python prefix, base_prefix, and real executable's env root foreach(_attr prefix base_prefix) execute_process( COMMAND "${Python_EXECUTABLE}" -c "import sys; print(sys.${_attr})" @@ -40,6 +39,30 @@ if (RSC_BUILD_EXTENSIONS) list(APPEND _env_roots "${_pp}") endif() endforeach() + # Resolve symlinks to find the real Python env (works through build isolation) + execute_process( + COMMAND "${Python_EXECUTABLE}" -c + "import sys,pathlib; print(pathlib.Path(sys.executable).resolve().parents[1])" + OUTPUT_VARIABLE _real_prefix OUTPUT_STRIP_TRAILING_WHITESPACE ERROR_QUIET) + if(_real_prefix) + list(APPEND _env_roots "${_real_prefix}") + endif() + # Direct site-packages search — the most reliable way to find pip-installed + # librmm regardless of venv nesting depth + execute_process( + COMMAND "${Python_EXECUTABLE}" -c + "import site; print(';'.join(site.getsitepackages()))" + OUTPUT_VARIABLE _site_paths 
OUTPUT_STRIP_TRAILING_WHITESPACE ERROR_QUIET) + if(_site_paths) + foreach(_sp ${_site_paths}) + file(GLOB _sp_hints "${_sp}/librmm/lib*/cmake" + "${_sp}/rapids_logger/lib*/cmake") + foreach(_h ${_sp_hints}) + get_filename_component(_dir "${_h}" DIRECTORY) + list(APPEND CMAKE_PREFIX_PATH "${_dir}") + endforeach() + endforeach() + endif() # CI/cibuildwheel: CIBW_BEFORE_BUILD writes the librmm path to a marker file if(EXISTS "/tmp/.librmm_dir") file(READ "/tmp/.librmm_dir" _rmm_marker) diff --git a/pyproject.toml b/pyproject.toml index 7449bacf..ba16f430 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -197,3 +197,10 @@ source = [ "./src", "**/site-packages" ] exclude_also = [ "if TYPE_CHECKING:", ] + +[tool.uv] +# librmm headers are needed at build time for the wilcoxon CUDA kernels. +# The headers are identical across cu12/cu13 — only the .so differs (loaded +# at runtime via librmm.load_library()). cu12 is used here as the build-time +# provider; cu13 envs get the same headers. +extra-build-dependencies = [ "librmm-cu12>=25.10" ] From 68ad73cbebe75118fc99117c458ef076f2e1ee98 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Tue, 14 Apr 2026 11:54:14 +0200 Subject: [PATCH 11/21] address issues --- .github/workflows/publish.yml | 12 +- CMakeLists.txt | 4 +- pyproject.toml | 4 +- src/rapids_singlecell/_cuda/nb_types.h | 2 +- .../_cuda/wilcoxon/kernels_wilcoxon.cuh | 98 ++++++++++------ .../_cuda/wilcoxon/wilcoxon_ovo.cu | 48 +++++++- .../_cuda/wilcoxon/wilcoxon_ovr.cu | 110 +++++++++++++++--- .../tools/_rank_genes_groups/_wilcoxon.py | 6 +- 8 files changed, 219 insertions(+), 65 deletions(-) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 676ec118..12a86743 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -122,12 +122,12 @@ jobs: python -m pip install -U pip scikit-build-core cmake ninja nanobind librmm-cu${{ matrix.cuda_major }} && - SITE=$(python -c "import 
sysconfig;print(sysconfig.get_path('purelib'))") && - echo "[rsc-build] site-packages=$SITE" && - echo "[rsc-build] librmm=$(ls $SITE/librmm/lib64/*.so 2>/dev/null)" && - echo "[rsc-build] rapids_logger=$(ls $SITE/rapids_logger/lib64/*.so 2>/dev/null)" && - ln -sf "$SITE/librmm/lib64/librmm.so" /usr/local/lib/librmm.so && - ln -sf "$SITE/rapids_logger/lib64/librapids_logger.so" /usr/local/lib/librapids_logger.so && + RMM_ROOT=$(python -c "import librmm;print(librmm.__path__[0])") && + LOG_ROOT=$(python -c "import rapids_logger;print(rapids_logger.__path__[0])") && + echo "[rsc-build] librmm=$RMM_ROOT" && + echo "[rsc-build] rapids_logger=$LOG_ROOT" && + ln -sf "$RMM_ROOT/lib64/librmm.so" /usr/local/lib/librmm.so && + ln -sf "$LOG_ROOT/lib64/librapids_logger.so" /usr/local/lib/librapids_logger.so && ldconfig && python -c "import librmm;print(librmm.__path__[0])" > /tmp/.librmm_dir && echo "[rsc-build] marker=$(cat /tmp/.librmm_dir)" diff --git a/CMakeLists.txt b/CMakeLists.txt index 6b2d9da1..8e987b11 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -22,7 +22,9 @@ if (RSC_BUILD_EXTENSIONS) set(_env_roots "") # Explicit override if(DEFINED ENV{LIBRMM_DIR}) - list(APPEND _env_roots "$ENV{LIBRMM_DIR}/..") + file(GLOB _librmm_hints "$ENV{LIBRMM_DIR}/lib*/cmake" + "$ENV{LIBRMM_DIR}/../rapids_logger/lib*/cmake") + list(APPEND CMAKE_PREFIX_PATH ${_librmm_hints}) endif() # Environment managers foreach(_var CONDA_PREFIX VIRTUAL_ENV PIXI_PROJECT_ROOT) diff --git a/pyproject.toml b/pyproject.toml index ba16f430..f45a5005 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -203,4 +203,6 @@ exclude_also = [ # The headers are identical across cu12/cu13 — only the .so differs (loaded # at runtime via librmm.load_library()). cu12 is used here as the build-time # provider; cu13 envs get the same headers. 
-extra-build-dependencies = [ "librmm-cu12>=25.10" ] + +[tool.uv.extra-build-dependencies] +rapids-singlecell = [ "librmm-cu12>=25.10" ] diff --git a/src/rapids_singlecell/_cuda/nb_types.h b/src/rapids_singlecell/_cuda/nb_types.h index 4cb10e44..eb343815 100644 --- a/src/rapids_singlecell/_cuda/nb_types.h +++ b/src/rapids_singlecell/_cuda/nb_types.h @@ -47,7 +47,7 @@ template using host_array = nb::ndarray>; template -using host_array_2d = nb::ndarray; +using host_array_2d = nb::ndarray>; // Register bindings for both regular CUDA and managed-memory arrays. // Usage: diff --git a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh index 46ff14f0..44a614af 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh @@ -10,37 +10,60 @@ * tie groups by adjacent comparison (sequential access, no binary search). * Cross-boundary ties are resolved via binary search at chunk boundaries. * - * Used by the OVR streaming pipeline in wilcoxon_streaming.cu. + * When use_gmem is false (default), per-group accumulators live in shared + * memory (fast atomics, limited to ~750 groups on 48 KB devices). + * When use_gmem is true, accumulators write directly to the output arrays + * in global memory, supporting an arbitrary number of groups. The caller + * must pre-zero rank_sums (and group_sums/group_sq_sums/group_nnz if + * compute_stats) before launching. 
+ * + * Shared memory layout: + * use_gmem=false: (4 * n_groups + 32) doubles (accumulators + warp buf) + * use_gmem=true: 32 doubles (warp buf only) */ __global__ void rank_sums_from_sorted_kernel( - const float* __restrict__ sorted_vals, // F-order (n_rows, n_cols) - const int* __restrict__ sorted_row_idx, // F-order (n_rows, n_cols) - const int* __restrict__ group_codes, // (n_rows_total,) - double* __restrict__ rank_sums, // (n_groups, n_cols) row-major - double* __restrict__ tie_corr, // (n_cols,) - double* __restrict__ group_sums, // (n_groups, n_cols) or NULL - double* __restrict__ group_sq_sums, // (n_groups, n_cols) or NULL - double* __restrict__ group_nnz, // (n_groups, n_cols) or NULL - int n_rows, int n_cols, int n_groups, bool compute_tie_corr, - bool compute_stats) { + const float* __restrict__ sorted_vals, + const int* __restrict__ sorted_row_idx, const int* __restrict__ group_codes, + double* __restrict__ rank_sums, double* __restrict__ tie_corr, + double* __restrict__ group_sums, double* __restrict__ group_sq_sums, + double* __restrict__ group_nnz, int n_rows, int n_cols, int n_groups, + bool compute_tie_corr, bool compute_stats, bool use_gmem) { int col = blockIdx.x; if (col >= n_cols) return; extern __shared__ double smem[]; - double* grp_sums = smem; - double* s_sum = smem + n_groups; - double* s_sq = smem + 2 * n_groups; - double* s_nnz = smem + 3 * n_groups; - - for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { - grp_sums[g] = 0.0; - if (compute_stats) { - s_sum[g] = 0.0; - s_sq[g] = 0.0; - s_nnz[g] = 0.0; + + // Accumulator pointers: shared memory (fast) or global memory (large + // groups) + double* grp_sums; + double* s_sum; + double* s_sq; + double* s_nnz; + + if (use_gmem) { + // Global memory path: write directly to output arrays (must be + // pre-zeroed) + grp_sums = rank_sums + (size_t)col; // stride: n_cols + s_sum = group_sums ? group_sums + (size_t)col : nullptr; + s_sq = group_sq_sums ? 
group_sq_sums + (size_t)col : nullptr; + s_nnz = group_nnz ? group_nnz + (size_t)col : nullptr; + } else { + // Shared memory path: per-block accumulators + grp_sums = smem; + s_sum = smem + n_groups; + s_sq = smem + 2 * n_groups; + s_nnz = smem + 3 * n_groups; + + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + grp_sums[g] = 0.0; + if (compute_stats) { + s_sum[g] = 0.0; + s_sq[g] = 0.0; + s_nnz[g] = 0.0; + } } + __syncthreads(); } - __syncthreads(); const float* sv = sorted_vals + (size_t)col * n_rows; const int* si = sorted_row_idx + (size_t)col * n_rows; @@ -52,6 +75,9 @@ __global__ void rank_sums_from_sorted_kernel( double local_tie_sum = 0.0; + // Stride for accumulator indexing: 1 for shared mem, n_cols for global mem + int acc_stride = use_gmem ? n_cols : 1; + int i = my_start; while (i < my_end) { double val = sv[i]; @@ -93,12 +119,12 @@ __global__ void rank_sums_from_sorted_kernel( for (int j = i; j < tie_local_end; ++j) { int grp = group_codes[si[j]]; if (grp < n_groups) { - atomicAdd(&grp_sums[grp], avg_rank); + atomicAdd(&grp_sums[grp * acc_stride], avg_rank); if (compute_stats) { double v = (double)sv[j]; - atomicAdd(&s_sum[grp], v); - atomicAdd(&s_sq[grp], v * v); - if (v != 0.0) atomicAdd(&s_nnz[grp], 1.0); + atomicAdd(&s_sum[grp * acc_stride], v); + atomicAdd(&s_sq[grp * acc_stride], v * v); + if (v != 0.0) atomicAdd(&s_nnz[grp * acc_stride], 1.0); } } } @@ -113,17 +139,21 @@ __global__ void rank_sums_from_sorted_kernel( __syncthreads(); - for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { - rank_sums[(size_t)g * n_cols + col] = grp_sums[g]; - if (compute_stats) { - group_sums[(size_t)g * n_cols + col] = s_sum[g]; - group_sq_sums[(size_t)g * n_cols + col] = s_sq[g]; - group_nnz[(size_t)g * n_cols + col] = s_nnz[g]; + // Copy shared memory accumulators to global output (smem path only) + if (!use_gmem) { + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + rank_sums[(size_t)g * n_cols + col] = grp_sums[g]; + if 
(compute_stats) { + group_sums[(size_t)g * n_cols + col] = s_sum[g]; + group_sq_sums[(size_t)g * n_cols + col] = s_sq[g]; + group_nnz[(size_t)g * n_cols + col] = s_nnz[g]; + } } } if (compute_tie_corr) { - double* warp_buf = smem + n_groups; + // Warp buf always in shared memory (32 doubles — always fits) + double* warp_buf = use_gmem ? smem : smem + n_groups; #pragma unroll for (int off = 16; off > 0; off >>= 1) local_tie_sum += __shfl_down_sync(0xffffffff, local_tie_sum, off); diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo.cu index 1f64ed53..bc0d1119 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo.cu @@ -181,7 +181,13 @@ static void ovo_streaming_impl(const float* ref_sorted, const float* grp_data, batch_idx++; } - for (int s = 0; s < n_streams; s++) cudaStreamSynchronize(streams[s]); + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in wilcoxon streaming: ") + + cudaGetErrorString(err)); + } for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); } @@ -383,7 +389,13 @@ static void ovo_streaming_csr_impl( batch_idx++; } - for (int s = 0; s < n_streams; s++) cudaStreamSynchronize(streams[s]); + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in wilcoxon streaming: ") + + cudaGetErrorString(err)); + } for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); } @@ -571,7 +583,13 @@ static void ovo_streaming_csc_impl( batch_idx++; } - for (int s = 0; s < n_streams; s++) cudaStreamSynchronize(streams[s]); + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + 
std::string("CUDA error in wilcoxon streaming: ") + + cudaGetErrorString(err)); + } for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); } @@ -808,7 +826,13 @@ static void ovo_streaming_csc_host_impl( batch_idx++; } - for (int s = 0; s < n_streams; s++) cudaStreamSynchronize(streams[s]); + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in wilcoxon streaming: ") + + cudaGetErrorString(err)); + } cudaHostUnregister(const_cast(h_data)); cudaHostUnregister(const_cast(h_indices)); @@ -1040,7 +1064,13 @@ static void ovo_streaming_csr_host_impl( batch_idx++; } - for (int s = 0; s < n_streams; s++) cudaStreamSynchronize(streams[s]); + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in wilcoxon streaming: ") + + cudaGetErrorString(err)); + } cudaHostUnregister(const_cast(h_data)); cudaHostUnregister(const_cast(h_indices)); @@ -1272,7 +1302,13 @@ static void ovo_streaming_dense_host_impl( batch_idx++; } - for (int s = 0; s < n_streams; s++) cudaStreamSynchronize(streams[s]); + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in wilcoxon streaming: ") + + cudaGetErrorString(err)); + } cudaHostUnregister(const_cast(h_block)); cudaHostUnregister(h_rank_sums); diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr.cu index 7caa29bf..be10a487 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr.cu @@ -9,6 +9,28 @@ using namespace nb::literals; +/** + * Decide whether to use shared or global memory for OVR rank accumulators. 
+ * Returns the smem size to request and sets use_gmem accordingly. + */ +static size_t ovr_smem_config(int n_groups, bool& use_gmem) { + size_t need = (size_t)(4 * n_groups + 32) * sizeof(double); + static int max_smem = -1; + if (max_smem < 0) { + int device; + cudaGetDevice(&device); + cudaDeviceGetAttribute(&max_smem, cudaDevAttrMaxSharedMemoryPerBlock, + device); + } + if ((int)need <= max_smem) { + use_gmem = false; + return need; + } + // Fall back to global memory accumulators; only need warp buf in smem + use_gmem = true; + return 32 * sizeof(double); +} + /** * Extract dense F-order float32 block from CSC. * Column range [col_start, col_stop). @@ -120,7 +142,8 @@ static void ovr_streaming_impl(const float* block, const int* group_codes, } int tpb_rank = round_up_to_warp(n_rows); - int smem_rank = (4 * n_groups + 32) * sizeof(double); + bool use_gmem = false; + size_t smem_rank = ovr_smem_config(n_groups, use_gmem); // Process sub-batches round-robin across streams int col = 0; @@ -146,10 +169,15 @@ static void ovr_streaming_impl(const float* block, const int* group_codes, buf.seg_offsets + 1, BEGIN_BIT, END_BIT, stream); // Fused rank sums into sub-batch buffer + if (use_gmem) { + cudaMemsetAsync(buf.sub_rank_sums, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + } rank_sums_from_sorted_kernel<<>>( buf.keys_out, buf.vals_out, group_codes, buf.sub_rank_sums, buf.sub_tie_corr, nullptr, nullptr, nullptr, n_rows, sb_cols, - n_groups, compute_tie_corr, false); + n_groups, compute_tie_corr, false, use_gmem); // Copy sub-batch results to global output (row-major scatter) // rank_sums is (n_groups, n_cols) row-major: group g, col c → @@ -171,7 +199,11 @@ static void ovr_streaming_impl(const float* block, const int* group_codes, // Sync all streams for (int s = 0; s < n_streams; s++) { - cudaStreamSynchronize(streams[s]); + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA 
error in wilcoxon streaming: ") + + cudaGetErrorString(err)); } for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); @@ -233,7 +265,8 @@ static void ovr_streaming_csr_impl( } int tpb_rank = round_up_to_warp(n_rows); - int smem_rank = (4 * n_groups + 32) * sizeof(double); + bool use_gmem = false; + size_t smem_rank = ovr_smem_config(n_groups, use_gmem); int col = 0; int batch_idx = 0; @@ -264,10 +297,15 @@ static void ovr_streaming_csr_impl( buf.seg_offsets + 1, BEGIN_BIT, END_BIT, stream); // Fused rank sums + if (use_gmem) { + cudaMemsetAsync(buf.sub_rank_sums, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + } rank_sums_from_sorted_kernel<<>>( buf.keys_out, buf.vals_out, group_codes, buf.sub_rank_sums, buf.sub_tie_corr, nullptr, nullptr, nullptr, n_rows, sb_cols, - n_groups, compute_tie_corr, false); + n_groups, compute_tie_corr, false, use_gmem); // Scatter to global output cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), @@ -284,7 +322,13 @@ static void ovr_streaming_csr_impl( batch_idx++; } - for (int s = 0; s < n_streams; s++) cudaStreamSynchronize(streams[s]); + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in wilcoxon streaming: ") + + cudaGetErrorString(err)); + } for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); } @@ -345,7 +389,8 @@ static void ovr_streaming_csc_impl( } int tpb_rank = round_up_to_warp(n_rows); - int smem_rank = (4 * n_groups + 32) * sizeof(double); + bool use_gmem = false; + size_t smem_rank = ovr_smem_config(n_groups, use_gmem); int col = 0; int batch_idx = 0; @@ -376,10 +421,15 @@ static void ovr_streaming_csc_impl( buf.seg_offsets + 1, BEGIN_BIT, END_BIT, stream); // Fused rank sums + if (use_gmem) { + cudaMemsetAsync(buf.sub_rank_sums, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + } rank_sums_from_sorted_kernel<<>>( buf.keys_out, 
buf.vals_out, group_codes, buf.sub_rank_sums, buf.sub_tie_corr, nullptr, nullptr, nullptr, n_rows, sb_cols, - n_groups, compute_tie_corr, false); + n_groups, compute_tie_corr, false, use_gmem); // Scatter to global output cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), @@ -396,7 +446,13 @@ static void ovr_streaming_csc_impl( batch_idx++; } - for (int s = 0; s < n_streams; s++) cudaStreamSynchronize(streams[s]); + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in wilcoxon streaming: ") + + cudaGetErrorString(err)); + } for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); } @@ -477,7 +533,8 @@ static void ovr_streaming_csc_host_impl( cudaMemcpyHostToDevice); int tpb_rank = round_up_to_warp(n_rows); - int smem_rank = (4 * n_groups + 32) * sizeof(double); + bool use_gmem = false; + size_t smem_rank = ovr_smem_config(n_groups, use_gmem); // Pin host memory for async transfers cudaHostRegister(const_cast(h_data), @@ -538,10 +595,15 @@ static void ovr_streaming_csc_host_impl( buf.seg_offsets + 1, BEGIN_BIT, END_BIT, stream); // Fused rank sums + if (use_gmem) { + cudaMemsetAsync(buf.d_rank_sums, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + } rank_sums_from_sorted_kernel<<>>( buf.keys_out, buf.vals_out, d_group_codes, buf.d_rank_sums, buf.d_tie_corr, nullptr, nullptr, nullptr, n_rows, sb_cols, - n_groups, compute_tie_corr, false); + n_groups, compute_tie_corr, false, use_gmem); // D2H: scatter results to host output cudaMemcpy2DAsync(h_rank_sums + col, n_cols * sizeof(double), @@ -558,7 +620,13 @@ static void ovr_streaming_csc_host_impl( batch_idx++; } - for (int s = 0; s < n_streams; s++) cudaStreamSynchronize(streams[s]); + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in 
wilcoxon streaming: ") + + cudaGetErrorString(err)); + } cudaHostUnregister(const_cast(h_data)); cudaHostUnregister(const_cast(h_indices)); @@ -628,7 +696,8 @@ static void ovr_streaming_dense_host_impl( cudaMemcpyHostToDevice); int tpb_rank = round_up_to_warp(n_rows); - int smem_rank = (4 * n_groups + 32) * sizeof(double); + bool use_gmem = false; + size_t smem_rank = ovr_smem_config(n_groups, use_gmem); // Pin host memory cudaHostRegister(const_cast(h_block), @@ -664,10 +733,15 @@ static void ovr_streaming_dense_host_impl( buf.seg_offsets + 1, BEGIN_BIT, END_BIT, stream); // Fused rank sums + if (use_gmem) { + cudaMemsetAsync(buf.d_rank_sums, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + } rank_sums_from_sorted_kernel<<>>( buf.keys_out, buf.vals_out, d_group_codes, buf.d_rank_sums, buf.d_tie_corr, nullptr, nullptr, nullptr, n_rows, sb_cols, - n_groups, compute_tie_corr, false); + n_groups, compute_tie_corr, false, use_gmem); // D2H: scatter results cudaMemcpy2DAsync(h_rank_sums + col, n_cols * sizeof(double), @@ -684,7 +758,13 @@ static void ovr_streaming_dense_host_impl( batch_idx++; } - for (int s = 0; s < n_streams; s++) cudaStreamSynchronize(streams[s]); + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in wilcoxon streaming: ") + + cudaGetErrorString(err)); + } cudaHostUnregister(const_cast(h_block)); cudaHostUnregister(h_rank_sums); diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py index 5c36585e..cdec0d7b 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py @@ -164,7 +164,8 @@ def _segmented_sort_columns( temp_bytes = _wc.get_seg_sort_temp_bytes(n_items=n_items, n_segments=n_segments) cub_temp = cp.empty(temp_bytes, dtype=cp.uint8) - keys_in = 
cp.ascontiguousarray(data.astype(cp.float32).ravel(order="F")) + # data is F-order; ravel("F") gives a flat C-contiguous view (no copy) + keys_in = data.astype(cp.float32, copy=False).ravel(order="F") keys_out = cp.empty_like(keys_in) _wc.segmented_sort( @@ -425,6 +426,9 @@ def _wilcoxon_with_reference( all_grp_rows.append(rows) offsets.append(offsets[-1] + len(rows)) + if not test_group_indices: + return [] + all_grp_row_ids_np = np.concatenate(all_grp_rows) grp_offsets_gpu = cp.asarray(offsets, dtype=cp.int32) n_test = len(test_group_indices) From 3d2e563ee71fe99e43ca8ccb95a6e21dce1dd750 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Tue, 14 Apr 2026 17:55:33 +0200 Subject: [PATCH 12/21] update fixes --- .github/workflows/publish.yml | 10 +++++++++- pyproject.toml | 2 ++ .../_cuda/wilcoxon/kernels_wilcoxon.cuh | 7 +++++-- 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 12a86743..c506c972 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -133,7 +133,15 @@ jobs: echo "[rsc-build] marker=$(cat /tmp/.librmm_dir)" CIBW_TEST_SKIP: "*" CIBW_TEST_COMMAND: "" - CIBW_REPAIR_WHEEL_COMMAND: "auditwheel repair --exclude libcublas.so.${{ matrix.cuda_major }} --exclude libcublasLt.so.${{ matrix.cuda_major }} --exclude libcudart.so.${{ matrix.cuda_major }} --exclude librmm.so --exclude librapids_logger.so -w {dest_dir} {wheel}" + CIBW_REPAIR_WHEEL_COMMAND: > + auditwheel repair + --exclude libcublas.so.${{ matrix.cuda_major }} + --exclude libcublasLt.so.${{ matrix.cuda_major }} + --exclude libcudart.so.${{ matrix.cuda_major }} + --exclude librmm.so + --exclude librapids_logger.so + -w {dest_dir} {wheel} + && pipx run abi3audit --strict --report {wheel} CIBW_BUILD_VERBOSITY: "1" - uses: actions/upload-artifact@v4 diff --git a/pyproject.toml b/pyproject.toml index f45a5005..7903bcf1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -203,6 +203,8 @@ exclude_also = [ # 
The headers are identical across cu12/cu13 — only the .so differs (loaded # at runtime via librmm.load_library()). cu12 is used here as the build-time # provider; cu13 envs get the same headers. +# NOTE: This is uv-specific. pip users building from source need: +# pip install librmm-cu12 && pip install --no-build-isolation -e . [tool.uv.extra-build-dependencies] rapids-singlecell = [ "librmm-cu12>=25.10" ] diff --git a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh index 44a614af..18cd03cf 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh @@ -152,8 +152,11 @@ __global__ void rank_sums_from_sorted_kernel( } if (compute_tie_corr) { - // Warp buf always in shared memory (32 doubles — always fits) - double* warp_buf = use_gmem ? smem : smem + n_groups; + // Warp buf sits after all accumulator arrays in shared memory. + // gmem path: accumulators are in global mem, warp buf starts at + // smem[0]. smem path: 4 arrays of n_groups doubles, then warp buf. + int warp_buf_off = use_gmem ? 0 : (compute_stats ? 
4 : 1) * n_groups; + double* warp_buf = smem + warp_buf_off; #pragma unroll for (int off = 16; off > 0; off >>= 1) local_tie_sum += __shfl_down_sync(0xffffffff, local_tie_sum, off); From 12e3bf44fabea40f2beb4802908ef878993e6bf3 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Tue, 14 Apr 2026 18:45:59 +0200 Subject: [PATCH 13/21] add more tests --- tests/test_rank_genes_groups_wilcoxon.py | 521 ++++++++++++++++++ .../test_rank_genes_groups_wilcoxon_binned.py | 217 ++++++++ 2 files changed, 738 insertions(+) diff --git a/tests/test_rank_genes_groups_wilcoxon.py b/tests/test_rank_genes_groups_wilcoxon.py index 455441e1..3977fb35 100644 --- a/tests/test_rank_genes_groups_wilcoxon.py +++ b/tests/test_rank_genes_groups_wilcoxon.py @@ -1,5 +1,7 @@ from __future__ import annotations +import cupy as cp +import cupyx.scipy.sparse as cpsp import numpy as np import pandas as pd import pytest @@ -10,6 +12,19 @@ import rapids_singlecell as rsc +def _to_format(X_dense, fmt): + """Convert dense numpy array to the specified format.""" + if fmt == "scipy_csc": + return sp.csc_matrix(X_dense) + if fmt == "cupy_dense": + return cp.asarray(X_dense) + if fmt == "cupy_csr": + return cpsp.csr_matrix(cp.asarray(X_dense)) + if fmt == "cupy_csc": + return cpsp.csc_matrix(cp.asarray(X_dense)) + raise ValueError(f"Unknown format: {fmt}") + + @pytest.mark.parametrize("reference", ["rest", "1"]) @pytest.mark.parametrize("tie_correct", [True, False]) @pytest.mark.parametrize("sparse", [True, False]) @@ -440,3 +455,509 @@ def test_sparse_matches_dense(self, perturbation_adata, sparse): np.testing.assert_array_equal( dense_df["pvals"].values, sparse_df["pvals"].values ) + + +# ============================================================================ +# Matrix format coverage: all dispatch paths must agree +# ============================================================================ + + +@pytest.mark.parametrize("reference", ["rest", "1"]) +@pytest.mark.parametrize( + "fmt", + [ + 
pytest.param("scipy_csc", id="scipy_csc"), + pytest.param("cupy_dense", id="cupy_dense"), + pytest.param("cupy_csr", id="cupy_csr"), + pytest.param("cupy_csc", id="cupy_csc"), + ], +) +def test_format_matches_scanpy(reference, fmt): + """Every matrix format matches scanpy output.""" + np.random.seed(42) + adata_gpu = sc.datasets.blobs(n_variables=6, n_centers=3, n_observations=200) + adata_gpu.obs["blobs"] = adata_gpu.obs["blobs"].astype("category") + adata_cpu = adata_gpu.copy() + + adata_gpu.X = _to_format(adata_gpu.X, fmt) + + kw = { + "groupby": "blobs", + "method": "wilcoxon", + "use_raw": False, + "reference": reference, + "tie_correct": True, + } + rsc.tl.rank_genes_groups(adata_gpu, **kw) + sc.tl.rank_genes_groups(adata_cpu, **kw) + + gpu_result = adata_gpu.uns["rank_genes_groups"] + cpu_result = adata_cpu.uns["rank_genes_groups"] + + assert gpu_result["names"].dtype.names == cpu_result["names"].dtype.names + for group in gpu_result["names"].dtype.names: + assert list(gpu_result["names"][group]) == list(cpu_result["names"][group]) + + for field in ("scores", "pvals", "logfoldchanges", "pvals_adj"): + for group in gpu_result[field].dtype.names: + np.testing.assert_allclose( + np.asarray(gpu_result[field][group], dtype=float), + np.asarray(cpu_result[field][group], dtype=float), + rtol=1e-13, + atol=1e-15, + ) + + +# ============================================================================ +# pre_load: GPU transfer before wilcoxon must match default (lazy transfer) +# ============================================================================ + + +@pytest.mark.parametrize("reference", ["rest", "1"]) +def test_pre_load_matches_scanpy(reference): + """pre_load=True matches scanpy output.""" + np.random.seed(42) + adata_gpu = sc.datasets.blobs(n_variables=6, n_centers=3, n_observations=200) + adata_gpu.obs["blobs"] = adata_gpu.obs["blobs"].astype("category") + adata_cpu = adata_gpu.copy() + + kw = { + "groupby": "blobs", + "method": "wilcoxon", + "use_raw": 
False, + "reference": reference, + "tie_correct": True, + } + rsc.tl.rank_genes_groups(adata_gpu, **kw, pre_load=True) + sc.tl.rank_genes_groups(adata_cpu, **kw) + + gpu_result = adata_gpu.uns["rank_genes_groups"] + cpu_result = adata_cpu.uns["rank_genes_groups"] + + assert gpu_result["names"].dtype.names == cpu_result["names"].dtype.names + for group in gpu_result["names"].dtype.names: + assert list(gpu_result["names"][group]) == list(cpu_result["names"][group]) + + for field in ("scores", "pvals", "logfoldchanges", "pvals_adj"): + for group in gpu_result[field].dtype.names: + np.testing.assert_allclose( + np.asarray(gpu_result[field][group], dtype=float), + np.asarray(cpu_result[field][group], dtype=float), + rtol=1e-13, + atol=1e-15, + ) + + +# ============================================================================ +# use_continuity with reference="rest" (OVR mode) +# ============================================================================ + + +def test_use_continuity_vs_rest_changes_scores(): + """use_continuity with reference='rest' adjusts z-scores toward zero.""" + np.random.seed(42) + adata_no = sc.datasets.blobs(n_variables=6, n_centers=3, n_observations=200) + adata_no.obs["blobs"] = adata_no.obs["blobs"].astype("category") + adata_yes = adata_no.copy() + + kw = { + "groupby": "blobs", + "method": "wilcoxon", + "use_raw": False, + "reference": "rest", + "tie_correct": True, + } + rsc.tl.rank_genes_groups(adata_no, **kw, use_continuity=False) + rsc.tl.rank_genes_groups(adata_yes, **kw, use_continuity=True) + + for group in adata_no.uns["rank_genes_groups"]["scores"].dtype.names: + scores_no = np.asarray( + adata_no.uns["rank_genes_groups"]["scores"][group], dtype=float + ) + scores_yes = np.asarray( + adata_yes.uns["rank_genes_groups"]["scores"][group], dtype=float + ) + # Continuity correction moves z-scores toward zero + assert np.all(np.abs(scores_yes) <= np.abs(scores_no) + 1e-15), ( + f"Group {group}: continuity-corrected |z| should be <= 
uncorrected |z|" + ) + # p-values should be valid + pvals = np.asarray( + adata_yes.uns["rank_genes_groups"]["pvals"][group], dtype=float + ) + assert np.all(pvals >= 0) + assert np.all(pvals <= 1) + + +# ============================================================================ +# mask_var: gene subsetting +# ============================================================================ + + +@pytest.mark.parametrize("reference", ["rest", "1"]) +def test_mask_var_matches_scanpy(reference): + """mask_var restricts genes and matches scanpy output.""" + np.random.seed(42) + adata_gpu = sc.datasets.blobs(n_variables=10, n_centers=3, n_observations=200) + adata_gpu.obs["blobs"] = adata_gpu.obs["blobs"].astype("category") + adata_cpu = adata_gpu.copy() + + mask = np.zeros(adata_gpu.n_vars, dtype=bool) + mask[:5] = True + + kw = { + "groupby": "blobs", + "method": "wilcoxon", + "use_raw": False, + "mask_var": mask, + "reference": reference, + } + rsc.tl.rank_genes_groups(adata_gpu, **kw) + sc.tl.rank_genes_groups(adata_cpu, **kw) + + gpu_result = adata_gpu.uns["rank_genes_groups"] + cpu_result = adata_cpu.uns["rank_genes_groups"] + + for group in gpu_result["names"].dtype.names: + assert list(gpu_result["names"][group]) == list(cpu_result["names"][group]) + names = set(gpu_result["names"][group]) + expected_genes = set(adata_gpu.var_names[:5]) + assert names <= expected_genes + + for field in ("scores", "pvals", "logfoldchanges", "pvals_adj"): + for group in gpu_result[field].dtype.names: + np.testing.assert_allclose( + np.asarray(gpu_result[field][group], dtype=float), + np.asarray(cpu_result[field][group], dtype=float), + rtol=1e-13, + atol=1e-15, + ) + + +def test_mask_var_string_key(): + """mask_var accepts a string key from adata.var.""" + np.random.seed(42) + adata_gpu = sc.datasets.blobs(n_variables=10, n_centers=3, n_observations=200) + adata_gpu.obs["blobs"] = adata_gpu.obs["blobs"].astype("category") + adata_cpu = adata_gpu.copy() + + 
adata_gpu.var["highly_variable"] = [True] * 5 + [False] * 5 + adata_cpu.var["highly_variable"] = [True] * 5 + [False] * 5 + + kw = { + "groupby": "blobs", + "method": "wilcoxon", + "use_raw": False, + "mask_var": "highly_variable", + } + rsc.tl.rank_genes_groups(adata_gpu, **kw) + sc.tl.rank_genes_groups(adata_cpu, **kw) + + gpu_result = adata_gpu.uns["rank_genes_groups"] + cpu_result = adata_cpu.uns["rank_genes_groups"] + + for group in gpu_result["names"].dtype.names: + assert len(gpu_result["names"][group]) == 5 + assert list(gpu_result["names"][group]) == list(cpu_result["names"][group]) + + for field in ("scores", "pvals"): + for group in gpu_result[field].dtype.names: + np.testing.assert_allclose( + np.asarray(gpu_result[field][group], dtype=float), + np.asarray(cpu_result[field][group], dtype=float), + rtol=1e-13, + atol=1e-15, + ) + + +# ============================================================================ +# key_added: custom output key +# ============================================================================ + + +def test_key_added_matches_scanpy(): + """key_added stores results under a custom key, matching scanpy.""" + np.random.seed(42) + adata_gpu = sc.datasets.blobs(n_variables=6, n_centers=3, n_observations=200) + adata_gpu.obs["blobs"] = adata_gpu.obs["blobs"].astype("category") + adata_cpu = adata_gpu.copy() + + rsc.tl.rank_genes_groups( + adata_gpu, "blobs", method="wilcoxon", use_raw=False, key_added="my_de" + ) + sc.tl.rank_genes_groups( + adata_cpu, "blobs", method="wilcoxon", use_raw=False, key_added="my_de" + ) + + assert "my_de" in adata_gpu.uns + assert "rank_genes_groups" not in adata_gpu.uns + + gpu_result = adata_gpu.uns["my_de"] + cpu_result = adata_cpu.uns["my_de"] + + for field in ("scores", "pvals", "logfoldchanges", "pvals_adj"): + for group in gpu_result[field].dtype.names: + np.testing.assert_allclose( + np.asarray(gpu_result[field][group], dtype=float), + np.asarray(cpu_result[field][group], dtype=float), + 
rtol=1e-13, + atol=1e-15, + ) + + +# ============================================================================ +# rankby_abs: ranking by absolute score +# ============================================================================ + + +def test_rankby_abs_matches_scanpy(): + """rankby_abs ranks genes by |score| and matches scanpy.""" + np.random.seed(42) + adata_gpu = sc.datasets.blobs(n_variables=6, n_centers=3, n_observations=200) + adata_gpu.obs["blobs"] = adata_gpu.obs["blobs"].astype("category") + adata_cpu = adata_gpu.copy() + + kw = { + "groupby": "blobs", + "method": "wilcoxon", + "use_raw": False, + "n_genes": 3, + "rankby_abs": True, + } + rsc.tl.rank_genes_groups(adata_gpu, **kw) + sc.tl.rank_genes_groups(adata_cpu, **kw) + + gpu_result = adata_gpu.uns["rank_genes_groups"] + cpu_result = adata_cpu.uns["rank_genes_groups"] + + for group in gpu_result["names"].dtype.names: + assert list(gpu_result["names"][group]) == list(cpu_result["names"][group]) + + # Scores should be sorted by absolute value (descending) + abs_scores = np.abs(np.asarray(gpu_result["scores"][group], dtype=float)) + assert np.all(abs_scores[:-1] >= abs_scores[1:]), ( + f"Group {group}: abs scores not sorted descending" + ) + + for field in ("scores", "pvals", "logfoldchanges", "pvals_adj"): + for group in gpu_result[field].dtype.names: + np.testing.assert_allclose( + np.asarray(gpu_result[field][group], dtype=float), + np.asarray(cpu_result[field][group], dtype=float), + rtol=1e-13, + atol=1e-15, + ) + + +# ============================================================================ +# Small group warning +# ============================================================================ + + +@pytest.mark.parametrize("reference", ["rest", "1"]) +def test_small_group_warning(reference): + """Groups with <=25 cells trigger a RuntimeWarning.""" + np.random.seed(42) + n_large = 100 + n_small = 20 # below MIN_GROUP_SIZE_WARNING = 25 + n_cells = n_large + n_small + n_large + + adata = 
sc.AnnData( + X=np.random.randn(n_cells, 5).astype(np.float32), + obs=pd.DataFrame( + { + "group": pd.Categorical( + ["0"] * n_large + ["1"] * n_small + ["2"] * n_large, + categories=["0", "1", "2"], + ), + } + ), + ) + + with pytest.warns(RuntimeWarning, match="normal approximation"): + rsc.tl.rank_genes_groups( + adata, + "group", + method="wilcoxon", + use_raw=False, + reference=reference, + ) + + +# ============================================================================ +# Singlet group rejection +# ============================================================================ + + +def test_singlet_group_raises(): + """A group with only 1 cell raises ValueError.""" + np.random.seed(42) + adata = sc.AnnData( + X=np.random.randn(101, 5).astype(np.float32), + obs=pd.DataFrame( + { + "group": pd.Categorical( + ["big"] * 100 + ["tiny"], + categories=["big", "tiny"], + ), + } + ), + ) + + with pytest.raises(ValueError, match="only contain one sample"): + rsc.tl.rank_genes_groups(adata, "group", method="wilcoxon", use_raw=False) + + +# ============================================================================ +# Invalid reference raises +# ============================================================================ + + +def test_invalid_reference_raises(): + """A reference not in the categories raises ValueError.""" + np.random.seed(42) + adata = sc.datasets.blobs(n_variables=5, n_centers=3, n_observations=100) + adata.obs["blobs"] = adata.obs["blobs"].astype("category") + + with pytest.raises(ValueError, match="reference = nonexistent"): + rsc.tl.rank_genes_groups( + adata, + "blobs", + method="wilcoxon", + use_raw=False, + reference="nonexistent", + ) + + +# ============================================================================ +# String group parameter raises +# ============================================================================ + + +def test_string_groups_raises(): + """Passing a bare string as groups raises ValueError.""" + np.random.seed(42) 
+ adata = sc.datasets.blobs(n_variables=5, n_centers=3, n_observations=100) + adata.obs["blobs"] = adata.obs["blobs"].astype("category") + + with pytest.raises(ValueError, match="Specify a sequence"): + rsc.tl.rank_genes_groups( + adata, + "blobs", + method="wilcoxon", + use_raw=False, + groups="0", + ) + + +# ============================================================================ +# Many groups with reference: 5+ groups, one vs reference +# ============================================================================ + + +def test_many_groups_with_reference(): + """Wilcoxon with many test groups vs a reference produces correct output.""" + np.random.seed(42) + adata = sc.datasets.blobs(n_variables=6, n_centers=5, n_observations=300) + adata.obs["blobs"] = adata.obs["blobs"].astype("category") + adata_cpu = adata.copy() + + kw = { + "groupby": "blobs", + "method": "wilcoxon", + "use_raw": False, + "reference": "0", + "tie_correct": True, + } + rsc.tl.rank_genes_groups(adata, **kw) + sc.tl.rank_genes_groups(adata_cpu, **kw) + + result = adata.uns["rank_genes_groups"] + assert "0" not in result["names"].dtype.names + expected = {"1", "2", "3", "4"} + assert set(result["names"].dtype.names) == expected + + for field in ("scores", "pvals"): + for group in result[field].dtype.names: + np.testing.assert_allclose( + np.asarray(result[field][group], dtype=float), + np.asarray( + adata_cpu.uns["rank_genes_groups"][field][group], dtype=float + ), + rtol=1e-13, + atol=1e-15, + ) + + +# ============================================================================ +# Group subsetting with unselected cells (OVR): unselected cells in "rest" +# ============================================================================ + + +def test_group_subset_vs_rest_unselected_cells(): + """With groups subset and reference='rest', unselected cells go into rest.""" + np.random.seed(42) + adata = sc.datasets.blobs(n_variables=6, n_centers=4, n_observations=200) + adata.obs["blobs"] = 
adata.obs["blobs"].astype("category") + adata_cpu = adata.copy() + + kw = { + "groupby": "blobs", + "method": "wilcoxon", + "use_raw": False, + "groups": ["0", "2"], + "reference": "rest", + "tie_correct": True, + } + rsc.tl.rank_genes_groups(adata, **kw) + sc.tl.rank_genes_groups(adata_cpu, **kw) + + for field in ("scores", "pvals"): + for group in adata.uns["rank_genes_groups"][field].dtype.names: + np.testing.assert_allclose( + np.asarray(adata.uns["rank_genes_groups"][field][group], dtype=float), + np.asarray( + adata_cpu.uns["rank_genes_groups"][field][group], dtype=float + ), + rtol=1e-13, + atol=1e-15, + ) + + +# ============================================================================ +# Group subsetting with reference: unselected cells excluded from pairwise +# ============================================================================ + + +def test_group_subset_with_reference_unselected_cells(): + """With groups subset and a reference, unselected cells are excluded.""" + np.random.seed(42) + adata = sc.datasets.blobs(n_variables=6, n_centers=4, n_observations=200) + adata.obs["blobs"] = adata.obs["blobs"].astype("category") + adata_cpu = adata.copy() + + kw = { + "groupby": "blobs", + "method": "wilcoxon", + "use_raw": False, + "groups": ["0", "1", "2"], + "reference": "1", + "tie_correct": True, + } + rsc.tl.rank_genes_groups(adata, **kw) + sc.tl.rank_genes_groups(adata_cpu, **kw) + + result = adata.uns["rank_genes_groups"] + assert "1" not in result["names"].dtype.names + assert set(result["names"].dtype.names) == {"0", "2"} + + for field in ("scores", "pvals"): + for group in result[field].dtype.names: + np.testing.assert_allclose( + np.asarray(result[field][group], dtype=float), + np.asarray( + adata_cpu.uns["rank_genes_groups"][field][group], dtype=float + ), + rtol=1e-13, + atol=1e-15, + ) diff --git a/tests/test_rank_genes_groups_wilcoxon_binned.py b/tests/test_rank_genes_groups_wilcoxon_binned.py index 96af0711..1d2e469b 100644 --- 
a/tests/test_rank_genes_groups_wilcoxon_binned.py +++ b/tests/test_rank_genes_groups_wilcoxon_binned.py @@ -524,3 +524,220 @@ def test_top_genes_match_scipy(adata_blobs): scipy_top = set(adata_blobs.var_names[np.argsort(pvals)[:n_top]]) overlap = len(binned_top & scipy_top) assert overlap >= n_top - 1, f"Group {group}: {overlap}/{n_top} overlap" + + +# ============================================================================ +# tie_correct and use_continuity coverage +# ============================================================================ + + +class TestWilcoxonBinnedCorrections: + """Tests for tie_correct and use_continuity branches.""" + + def test_tie_correct_changes_scores(self, adata_blobs): + """tie_correct=True should produce different scores than False on tied data.""" + # Create data with heavy ties (integer counts) + rng = np.random.default_rng(42) + adata = adata_blobs.copy() + adata.X = rng.poisson(lam=3.0, size=adata.X.shape).astype(np.float32) + rsc.get.anndata_to_GPU(adata) + adata_tc = adata.copy() + + rsc.tl.rank_genes_groups( + adata, + "blobs", + method="wilcoxon_binned", + use_raw=False, + tie_correct=False, + ) + rsc.tl.rank_genes_groups( + adata_tc, + "blobs", + method="wilcoxon_binned", + use_raw=False, + tie_correct=True, + ) + + # Scores should differ (tie correction adjusts variance) + for group in adata.uns["rank_genes_groups"]["scores"].dtype.names: + scores_no = np.asarray( + adata.uns["rank_genes_groups"]["scores"][group], dtype=float + ) + scores_tc = np.asarray( + adata_tc.uns["rank_genes_groups"]["scores"][group], dtype=float + ) + assert not np.allclose(scores_no, scores_tc, rtol=1e-10), ( + f"Group {group}: tie_correct had no effect on scores" + ) + + def test_use_continuity_changes_scores(self, adata_blobs): + """use_continuity=True should produce different scores than False.""" + adata = adata_blobs.copy() + rsc.get.anndata_to_GPU(adata) + adata_cont = adata.copy() + + rsc.tl.rank_genes_groups( + adata, + "blobs", + 
method="wilcoxon_binned", + use_raw=False, + use_continuity=False, + ) + rsc.tl.rank_genes_groups( + adata_cont, + "blobs", + method="wilcoxon_binned", + use_raw=False, + use_continuity=True, + ) + + for group in adata.uns["rank_genes_groups"]["scores"].dtype.names: + scores_no = np.asarray( + adata.uns["rank_genes_groups"]["scores"][group], dtype=float + ) + scores_cont = np.asarray( + adata_cont.uns["rank_genes_groups"]["scores"][group], dtype=float + ) + assert not np.allclose(scores_no, scores_cont, rtol=1e-10), ( + f"Group {group}: use_continuity had no effect on scores" + ) + + @pytest.mark.parametrize("reference", ["rest", "1"]) + def test_tie_correct_with_reference(self, adata_blobs, reference): + """tie_correct works for both OVR and OVO in binned wilcoxon.""" + adata = adata_blobs.copy() + rsc.get.anndata_to_GPU(adata) + + rsc.tl.rank_genes_groups( + adata, + "blobs", + method="wilcoxon_binned", + use_raw=False, + reference=reference, + tie_correct=True, + ) + + result = adata.uns["rank_genes_groups"] + for group in result["pvals"].dtype.names: + pvals = np.asarray(result["pvals"][group], dtype=float) + assert np.all(pvals >= 0) + assert np.all(pvals <= 1) + assert np.all(np.isfinite(pvals)) + + @pytest.mark.parametrize("reference", ["rest", "1"]) + def test_use_continuity_with_reference(self, adata_blobs, reference): + """use_continuity works for both OVR and OVO in binned wilcoxon.""" + adata = adata_blobs.copy() + rsc.get.anndata_to_GPU(adata) + + rsc.tl.rank_genes_groups( + adata, + "blobs", + method="wilcoxon_binned", + use_raw=False, + reference=reference, + use_continuity=True, + ) + + result = adata.uns["rank_genes_groups"] + for group in result["pvals"].dtype.names: + pvals = np.asarray(result["pvals"][group], dtype=float) + assert np.all(pvals >= 0) + assert np.all(pvals <= 1) + assert np.all(np.isfinite(pvals)) + + def test_both_corrections_combined(self, adata_blobs): + """tie_correct=True + use_continuity=True together.""" + adata = 
adata_blobs.copy() + rsc.get.anndata_to_GPU(adata) + + rsc.tl.rank_genes_groups( + adata, + "blobs", + method="wilcoxon_binned", + use_raw=False, + tie_correct=True, + use_continuity=True, + ) + + result = adata.uns["rank_genes_groups"] + for group in result["pvals"].dtype.names: + pvals = np.asarray(result["pvals"][group], dtype=float) + assert np.all(pvals >= 0) + assert np.all(pvals <= 1) + + def test_both_corrections_with_reference(self, adata_blobs): + """tie_correct + use_continuity with reference mode.""" + adata = adata_blobs.copy() + rsc.get.anndata_to_GPU(adata) + + rsc.tl.rank_genes_groups( + adata, + "blobs", + method="wilcoxon_binned", + use_raw=False, + reference="1", + tie_correct=True, + use_continuity=True, + ) + + result = adata.uns["rank_genes_groups"] + assert "1" not in result["names"].dtype.names + for group in result["pvals"].dtype.names: + pvals = np.asarray(result["pvals"][group], dtype=float) + assert np.all(pvals >= 0) + assert np.all(pvals <= 1) + + +# ============================================================================ +# pts (percent expressing) for binned wilcoxon +# ============================================================================ + + +@pytest.mark.parametrize("reference", ["rest", "1"]) +def test_binned_pts(adata_blobs, reference): + """pts computation works with wilcoxon_binned.""" + adata = adata_blobs.copy() + rsc.get.anndata_to_GPU(adata) + + rsc.tl.rank_genes_groups( + adata, + "blobs", + method="wilcoxon_binned", + use_raw=False, + pts=True, + reference=reference, + ) + + result = adata.uns["rank_genes_groups"] + assert "pts" in result + pts = result["pts"] + assert isinstance(pts, __import__("pandas").DataFrame) + assert all(0 <= v <= 1 for col in pts.columns for v in pts[col]) + + if reference == "rest": + assert "pts_rest" in result + + +# ============================================================================ +# mask_var with string key for binned +# 
============================================================================ + + +def test_binned_mask_var_string_key(adata_blobs): + """mask_var accepts a string key from adata.var for binned.""" + adata = adata_blobs.copy() + adata.var["selected"] = [True] * 5 + [False] * 5 + rsc.get.anndata_to_GPU(adata) + + rsc.tl.rank_genes_groups( + adata, + "blobs", + method="wilcoxon_binned", + use_raw=False, + mask_var="selected", + ) + + result = adata.uns["rank_genes_groups"] + for group in result["names"].dtype.names: + assert len(result["names"][group]) == 5 From cb1a6054d92a1d9a428b0df683f721db557a9c26 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Wed, 15 Apr 2026 14:16:46 +0200 Subject: [PATCH 14/21] address small issues --- src/rapids_singlecell/_cuda/__init__.py | 2 -- .../_cuda/wilcoxon/wilcoxon_common.cuh | 2 +- .../tools/_rank_genes_groups/__init__.py | 8 ++++---- .../tools/_rank_genes_groups/_core.py | 3 --- .../tools/_rank_genes_groups/_utils.py | 10 ---------- .../tools/_rank_genes_groups/_wilcoxon.py | 5 ----- 6 files changed, 5 insertions(+), 25 deletions(-) diff --git a/src/rapids_singlecell/_cuda/__init__.py b/src/rapids_singlecell/_cuda/__init__.py index b11f342a..741178d6 100644 --- a/src/rapids_singlecell/_cuda/__init__.py +++ b/src/rapids_singlecell/_cuda/__init__.py @@ -23,8 +23,6 @@ except (ImportError, OSError): pass -_RMM_MODULES = {"_wilcoxon_ovo_cuda", "_wilcoxon_ovr_cuda"} - __all__ = [ "_aggr_cuda", "_aucell_cuda", diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_common.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_common.cuh index 98d26971..497e98ae 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_common.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_common.cuh @@ -14,7 +14,7 @@ constexpr int WARP_SIZE = 32; constexpr int MAX_THREADS_PER_BLOCK = 512; constexpr int N_STREAMS = 4; -constexpr int SUB_BATCH_COLS = 32; +constexpr int SUB_BATCH_COLS = 64; constexpr int BEGIN_BIT = 0; constexpr int END_BIT = 32; diff 
--git a/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py b/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py index 0b9753a3..6a9aa016 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py @@ -104,10 +104,10 @@ def rank_genes_groups( layer Key from `adata.layers` whose value will be used to perform tests on. chunk_size - Number of genes to process at once for `'wilcoxon'` and - `'wilcoxon_binned'`. Default is 128 for `'wilcoxon'`. For - `'wilcoxon_binned'` the default is sized dynamically based on - ``n_groups`` and ``n_bins`` to keep histogram memory stable. + Number of genes to process at once for `'wilcoxon_binned'`. + The default is sized dynamically based on ``n_groups`` and + ``n_bins`` to keep histogram memory stable. + Ignored for other methods. pre_load Pre-load the data into GPU memory. Used only for `'wilcoxon'`. n_bins diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_core.py b/src/rapids_singlecell/tools/_rank_genes_groups/_core.py index c65bbf7c..d89a079a 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_core.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_core.py @@ -324,7 +324,6 @@ def wilcoxon( *, tie_correct: bool, use_continuity: bool = False, - chunk_size: int | None = None, ) -> list[tuple[int, NDArray, NDArray]]: """Compute Wilcoxon rank-sum test statistics.""" from ._wilcoxon import wilcoxon @@ -333,7 +332,6 @@ def wilcoxon( self, tie_correct=tie_correct, use_continuity=use_continuity, - chunk_size=chunk_size, ) def wilcoxon_binned( @@ -394,7 +392,6 @@ def compute_statistics( test_results = self.wilcoxon( tie_correct=tie_correct, use_continuity=use_continuity, - chunk_size=chunk_size, ) elif method == "wilcoxon_binned": test_results = self.wilcoxon_binned( diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py b/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py index 9dea8f11..890ff6e7 
100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py @@ -102,16 +102,6 @@ def _select_top_n(scores: NDArray, n_top: int) -> NDArray: return global_indices -DEFAULT_CHUNK_SIZE = 512 - - -def _choose_chunk_size(requested: int | None) -> int: - """Choose chunk size for gene processing.""" - if requested is not None: - return int(requested) - return DEFAULT_CHUNK_SIZE - - def _csc_columns_to_gpu(X_csc, start: int, stop: int, n_rows: int) -> cp.ndarray: """ Extract columns from a CSC matrix via direct indptr pointer slicing. diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py index cdec0d7b..ac141734 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py @@ -193,7 +193,6 @@ def wilcoxon( *, tie_correct: bool, use_continuity: bool = False, - chunk_size: int | None = None, ) -> list[tuple[int, NDArray, NDArray]]: """Compute Wilcoxon rank-sum test statistics.""" X = rg.X @@ -212,7 +211,6 @@ def wilcoxon( group_sizes, tie_correct=tie_correct, use_continuity=use_continuity, - chunk_size=chunk_size, ) return _wilcoxon_vs_rest( rg, @@ -222,7 +220,6 @@ def wilcoxon( group_sizes, tie_correct=tie_correct, use_continuity=use_continuity, - chunk_size=chunk_size, ) @@ -240,7 +237,6 @@ def _wilcoxon_vs_rest( *, tie_correct: bool, use_continuity: bool, - chunk_size: int | None, ) -> list[tuple[int, NDArray, NDArray]]: """Wilcoxon test: each group vs rest of cells. @@ -400,7 +396,6 @@ def _wilcoxon_with_reference( *, tie_correct: bool, use_continuity: bool, - chunk_size: int | None, ) -> list[tuple[int, NDArray, NDArray]]: """Wilcoxon test: each group vs a specific reference group. 
From 7559778e7859073daf29ed2c5affad98be9e08d4 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Wed, 15 Apr 2026 17:51:04 +0200 Subject: [PATCH 15/21] speed up ovr --- .../_cuda/wilcoxon/kernels_wilcoxon.cuh | 241 ++++++ .../_cuda/wilcoxon/wilcoxon_ovr.cu | 766 ++++++++++-------- .../tools/_rank_genes_groups/_wilcoxon.py | 16 +- tests/test_rank_genes_groups_wilcoxon.py | 97 +++ 4 files changed, 764 insertions(+), 356 deletions(-) diff --git a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh index 18cd03cf..b7aa4a40 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh @@ -179,3 +179,244 @@ __global__ void rank_sums_from_sorted_kernel( } } } + +/** + * Sparse-aware OVR rank-sum kernel for sorted stored values. + * + * After CUB sort the stored values are in ascending order: + * [negatives..., stored_zeros..., positives...] + * Implicit zeros (n_rows − nnz_stored) are inserted analytically + * between negatives and positives to form the full ranking. + * + * Full sorted array (conceptual): + * [negatives..., ALL_zeros (stored+implicit)..., positives...] 
+ * |<- neg_end ->|<------- total_zero -------->| + * + * Rank offsets: + * negative at stored pos i : full pos = i (no shift) + * positive at stored pos i : full pos = i + n_impl_zero (shift right) + * zeros : avg rank = neg_end + (total_zero+1)/2 + * + * Shared-memory layout (doubles): + * grp_sums[n_groups] rank-sum accumulators + * grp_nz_count[n_groups] nonzero-per-group counters + * warp_buf[32] tie-correction reduction scratch + * + * Grid: (sb_cols,) Block: (tpb,) + */ +__global__ void rank_sums_sparse_ovr_kernel( + const float* __restrict__ sorted_vals, + const int* __restrict__ sorted_row_idx, + const int* __restrict__ col_seg_offsets, + const int* __restrict__ group_codes, const double* __restrict__ group_sizes, + double* __restrict__ rank_sums, double* __restrict__ tie_corr, int n_rows, + int sb_cols, int n_groups, bool compute_tie_corr) { + int col = blockIdx.x; + if (col >= sb_cols) return; + + int seg_start = col_seg_offsets[col]; + int seg_end = col_seg_offsets[col + 1]; + int nnz_stored = seg_end - seg_start; + + const float* sv = sorted_vals + seg_start; + const int* si = sorted_row_idx + seg_start; + + extern __shared__ double smem[]; + double* grp_sums = smem; + double* grp_nz_count = smem + n_groups; + + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + grp_sums[g] = 0.0; + grp_nz_count[g] = 0.0; + } + __syncthreads(); + + // --- Find zero range: neg_end = first val >= 0, pos_start = first val > 0 + // --- + __shared__ int sh_neg_end; + __shared__ int sh_pos_start; + if (threadIdx.x == 0) { + // Binary search: first index where sv[i] >= 0.0 + int lo = 0, hi = nnz_stored; + while (lo < hi) { + int mid = (lo + hi) >> 1; + if (sv[mid] < 0.0f) + lo = mid + 1; + else + hi = mid; + } + sh_neg_end = lo; + // Binary search: first index where sv[i] > 0.0 + hi = nnz_stored; + while (lo < hi) { + int mid = (lo + hi) >> 1; + if (sv[mid] <= 0.0f) + lo = mid + 1; + else + hi = mid; + } + sh_pos_start = lo; + } + __syncthreads(); + + int neg_end = 
sh_neg_end; + int pos_start = sh_pos_start; + int n_stored_zero = pos_start - neg_end; + int n_implicit_zero = n_rows - nnz_stored; + int total_zero = n_implicit_zero + n_stored_zero; + double zero_avg_rank = + (total_zero > 0) ? (double)neg_end + (total_zero + 1.0) / 2.0 : 0.0; + + // Rank offset for positive stored values: + // full_pos(i) = i + n_implicit_zero for i >= pos_start + // So avg_rank for tie group [a,b) of positives: + // = n_implicit_zero + (a + b + 1) / 2 + int offset_pos = n_implicit_zero; + + // --- Count stored values != 0.0 per group --- + for (int i = threadIdx.x; i < nnz_stored; i += blockDim.x) { + if (i < neg_end || i >= pos_start) { // skip stored zeros + int grp = group_codes[si[i]]; + if (grp < n_groups) { + atomicAdd(&grp_nz_count[grp], 1.0); + } + } + } + __syncthreads(); + + // --- Zero-rank contribution per group --- + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + double n_zero_in_g = group_sizes[g] - grp_nz_count[g]; + grp_sums[g] = n_zero_in_g * zero_avg_rank; + } + __syncthreads(); + + // --- Walk ALL stored values, skip stored zeros, compute ranks --- + // Chunk over [0, nnz_stored), skip [neg_end, pos_start). + int chunk = (nnz_stored + blockDim.x - 1) / blockDim.x; + int my_start = threadIdx.x * chunk; + int my_end = my_start + chunk; + if (my_end > nnz_stored) my_end = nnz_stored; + + double local_tie_sum = 0.0; + + int i = my_start; + while (i < my_end) { + // Skip stored zeros + if (i >= neg_end && i < pos_start) { + i = pos_start; + continue; + } + + float val = sv[i]; + + int tie_local_end = i + 1; + while (tie_local_end < my_end && sv[tie_local_end] == val) + ++tie_local_end; + // Don't let local tie range cross into stored-zero region + if (val < 0.0f && tie_local_end > neg_end) tie_local_end = neg_end; + + int tie_global_start = i; + if (i == my_start && i > 0 && sv[i - 1] == val) { + // Binary search for first occurrence + int search_lo = (val < 0.0f) ? 
0 : pos_start; + int lo = search_lo, hi = i; + while (lo < hi) { + int mid = (lo + hi) >> 1; + if (sv[mid] < val) + lo = mid + 1; + else + hi = mid; + } + tie_global_start = lo; + } + // Handle thread resuming at pos_start after skipping zeros + if (i == pos_start && i > 0 && pos_start > neg_end && + val == sv[pos_start] && i != my_start) { + // Already at pos_start boundary, tie_global_start = pos_start + tie_global_start = pos_start; + } + + int tie_global_end = tie_local_end; + if (tie_local_end == my_end && tie_local_end < nnz_stored && + tie_local_end != neg_end && sv[tie_local_end] == val) { + int search_hi = (val < 0.0f) ? (neg_end - 1) : (nnz_stored - 1); + int lo = tie_local_end, hi = search_hi; + while (lo < hi) { + int mid = (lo + hi + 1) >> 1; + if (sv[mid] > val) + hi = mid - 1; + else + lo = mid; + } + tie_global_end = lo + 1; + } + + int total_tie = tie_global_end - tie_global_start; + + // Rank depends on sign: + // negative (i < neg_end): full pos = stored pos (no shift) + // positive (i >= pos_start): full pos = stored pos + n_implicit_zero + double avg_rank; + if (val < 0.0f) { + avg_rank = (double)(tie_global_start + tie_global_end + 1) / 2.0; + } else { + avg_rank = (double)offset_pos + + (double)(tie_global_start + tie_global_end + 1) / 2.0; + } + + for (int j = i; j < tie_local_end; ++j) { + int grp = group_codes[si[j]]; + if (grp < n_groups) { + atomicAdd(&grp_sums[grp], avg_rank); + } + } + + if (compute_tie_corr && tie_global_start >= my_start && total_tie > 1) { + double t = (double)total_tie; + local_tie_sum += t * t * t - t; + } + + i = tie_local_end; + } + + __syncthreads(); + + // Write rank sums to global output + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + rank_sums[(size_t)g * sb_cols + col] = grp_sums[g]; + } + + // Tie correction: warp + block reduction + if (compute_tie_corr) { + // Zero tie group contribution (one thread only) + if (threadIdx.x == 0 && total_zero > 1) { + double tz = (double)total_zero; + 
local_tie_sum += tz * tz * tz - tz; + } + + int warp_buf_off = 2 * n_groups; + double* warp_buf = smem + warp_buf_off; + +#pragma unroll + for (int off = 16; off > 0; off >>= 1) + local_tie_sum += __shfl_down_sync(0xffffffff, local_tie_sum, off); + int lane = threadIdx.x & 31; + int wid = threadIdx.x >> 5; + if (lane == 0) warp_buf[wid] = local_tie_sum; + __syncthreads(); + if (threadIdx.x < 32) { + double v = (threadIdx.x < ((blockDim.x + 31) >> 5)) + ? warp_buf[threadIdx.x] + : 0.0; +#pragma unroll + for (int off = 16; off > 0; off >>= 1) + v += __shfl_down_sync(0xffffffff, v, off); + if (threadIdx.x == 0) { + double n = (double)n_rows; + double denom = n * n * n - n; + tie_corr[col] = (denom > 0.0) ? (1.0 - v / denom) : 1.0; + } + } + } +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr.cu index be10a487..2f2c2d60 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr.cu @@ -9,6 +9,68 @@ using namespace nb::literals; +/** Rebase a slice of indptr: out[i] = indptr[col + i] - indptr[col]. */ +__global__ void rebase_indptr_kernel(const int* __restrict__ indptr, + int* __restrict__ out, int col, + int count) { + int i = threadIdx.x; + if (i < count) out[i] = indptr[col + i] - indptr[col]; +} + +/** Subtract a constant from an int array in-place. */ +__global__ void subtract_scalar_kernel(int* __restrict__ data, int base, + int count) { + int i = threadIdx.x; + if (i < count) data[i] -= base; +} + +/** Count nonzeros per column from CSR. One thread per row. 
*/ +__global__ void csr_col_histogram_kernel(const int* __restrict__ indices, + const int* __restrict__ indptr, + int* __restrict__ col_counts, + int n_rows, int n_cols) { + int row = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= n_rows) return; + int rs = indptr[row]; + int re = indptr[row + 1]; + for (int p = rs; p < re; ++p) { + int c = indices[p]; + if (c < n_cols) atomicAdd(&col_counts[c], 1); + } +} + +/** + * Scatter CSR nonzeros into CSC layout for columns [col_start, col_stop). + * write_pos[c - col_start] must be initialized to the prefix-sum offset + * for column c. Each thread atomically claims a unique destination slot. + */ +__global__ void csr_scatter_to_csc_kernel( + const float* __restrict__ data, const int* __restrict__ indices, + const int* __restrict__ indptr, int* __restrict__ write_pos, + float* __restrict__ csc_vals, int* __restrict__ csc_row_idx, int n_rows, + int col_start, int col_stop) { + int row = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= n_rows) return; + int rs = indptr[row]; + int re = indptr[row + 1]; + // Binary search for col_start + int lo = rs, hi = re; + while (lo < hi) { + int m = (lo + hi) >> 1; + if (indices[m] < col_start) + lo = m + 1; + else + hi = m; + } + for (int p = lo; p < re; ++p) { + int c = indices[p]; + if (c >= col_stop) break; + int dest = atomicAdd(&write_pos[c - col_start], 1); + csc_vals[dest] = data[p]; + csc_row_idx[dest] = row; + } +} + /** * Decide whether to use shared or global memory for OVR rank accumulators. * Returns the smem size to request and sets use_gmem accordingly. @@ -31,29 +93,6 @@ static size_t ovr_smem_config(int n_groups, bool& use_gmem) { return 32 * sizeof(double); } -/** - * Extract dense F-order float32 block from CSC. - * Column range [col_start, col_stop). - * One block per column, threads scatter nonzeros. - * Output must be pre-zeroed. 
- */ -__global__ void csc_extract_f32_kernel(const float* __restrict__ data, - const int* __restrict__ indices, - const int* __restrict__ indptr, - float* __restrict__ out, int n_rows, - int col_start) { - int col_local = blockIdx.x; - int col = col_start + col_local; - - int start = indptr[col]; - int end = indptr[col + 1]; - - for (int p = start + threadIdx.x; p < end; p += blockDim.x) { - int row = indices[p]; - out[(long long)col_local * n_rows + row] = data[p]; - } -} - /** * Fill sort values with row indices [0,1,...,n_rows-1] per column. * Grid: (n_cols,), block: 256 threads. @@ -68,22 +107,6 @@ __global__ void fill_row_indices_kernel(int* __restrict__ vals, int n_rows, } } -/** - * Launch csr_extract_dense_kernel for ALL rows of a CSR matrix. - * Creates a temporary identity row_ids array [0,1,...,n_rows-1]. - */ -static void csr_extract_all_rows(const float* data, const int* indices, - const int* indptr, float* out, int n_rows, - int col_start, int col_stop, RmmPool& pool, - cudaStream_t stream) { - int* row_ids = pool.alloc(n_rows); - fill_row_indices_kernel<<<1, 256, 0, stream>>>(row_ids, n_rows, 1); - int tpb = round_up_to_warp(n_rows); - int blk = (n_rows + tpb - 1) / tpb; - csr_extract_dense_kernel<<>>( - data, indices, indptr, row_ids, out, n_rows, col_start, col_stop); -} - /** * Streaming OVR pipeline. * @@ -210,111 +233,145 @@ static void ovr_streaming_impl(const float* block, const int* group_codes, } /** - * CSR-direct OVR streaming pipeline. + * Sparse-aware host-streaming CSC OVR pipeline. * - * Takes GPU CSR arrays directly — no CSR→CSC conversion needed. - * For each sub-batch: extract dense columns from CSR → sort → rank. - * Everything on one GPU with multi-stream overlap. + * Like ovr_streaming_csc_host_impl but sorts only stored nonzeros per column + * instead of extracting dense blocks. GPU memory is O(max_batch_nnz) instead + * of O(sub_batch * n_rows), and sort work is proportional to nnz, not n_rows. 
*/ -static void ovr_streaming_csr_impl( - const float* csr_data, const int* csr_indices, const int* csr_indptr, - const int* group_codes, double* rank_sums, double* tie_corr, int n_rows, - int n_cols, int n_groups, bool compute_tie_corr, int sub_batch_cols) { +static void ovr_sparse_csc_host_streaming_impl( + const float* h_data, const int* h_indices, const int* h_indptr, + const int* h_group_codes, const double* h_group_sizes, double* h_rank_sums, + double* h_tie_corr, int n_rows, int n_cols, int n_groups, + bool compute_tie_corr, int sub_batch_cols) { if (n_rows == 0 || n_cols == 0) return; int n_streams = N_STREAMS; if (n_cols < n_streams * sub_batch_cols) n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; - size_t sub_items = (size_t)n_rows * sub_batch_cols; + // Find max nnz across any sub-batch + size_t max_nnz = 0; + for (int col = 0; col < n_cols; col += sub_batch_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + size_t nnz = (size_t)(h_indptr[col + sb_cols] - h_indptr[col]); + if (nnz > max_nnz) max_nnz = nnz; + } + + // CUB temp size for max_nnz items size_t cub_temp_bytes = 0; - { + if (max_nnz > 0) { auto* fk = reinterpret_cast(1); auto* iv = reinterpret_cast(1); cub::DeviceSegmentedRadixSort::SortPairs( - nullptr, cub_temp_bytes, fk, fk, iv, iv, (int)sub_items, + nullptr, cub_temp_bytes, fk, fk, iv, iv, (int)max_nnz, sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); } std::vector streams(n_streams); for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); - // Allocate per-stream buffers via RMM pool RmmPool pool; + int* d_group_codes = pool.alloc(n_rows); + double* d_group_sizes = pool.alloc(n_groups); struct StreamBuf { - float* dense; + float* d_sparse_data; + int* d_sparse_indices; + int* d_seg_offsets; float* keys_out; - int* vals_in; int* vals_out; - int* seg_offsets; uint8_t* cub_temp; - double* sub_rank_sums; - double* sub_tie_corr; + double* d_rank_sums; + double* d_tie_corr; }; std::vector bufs(n_streams); 
for (int s = 0; s < n_streams; s++) { - bufs[s].dense = pool.alloc(sub_items); - bufs[s].keys_out = pool.alloc(sub_items); - bufs[s].vals_in = pool.alloc(sub_items); - bufs[s].vals_out = pool.alloc(sub_items); - bufs[s].seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].d_sparse_data = pool.alloc(max_nnz); + bufs[s].d_sparse_indices = pool.alloc(max_nnz); + bufs[s].d_seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].keys_out = pool.alloc(max_nnz); + bufs[s].vals_out = pool.alloc(max_nnz); bufs[s].cub_temp = pool.alloc(cub_temp_bytes); - bufs[s].sub_rank_sums = + bufs[s].d_rank_sums = pool.alloc((size_t)n_groups * sub_batch_cols); - bufs[s].sub_tie_corr = pool.alloc(sub_batch_cols); + bufs[s].d_tie_corr = pool.alloc(sub_batch_cols); } - int tpb_rank = round_up_to_warp(n_rows); - bool use_gmem = false; - size_t smem_rank = ovr_smem_config(n_groups, use_gmem); + // Transfer group codes + sizes once + cudaMemcpy(d_group_codes, h_group_codes, n_rows * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_group_sizes, h_group_sizes, n_groups * sizeof(double), + cudaMemcpyHostToDevice); + + int tpb = 256; + size_t smem_bytes = (size_t)(2 * n_groups + 32) * sizeof(double); + + // Pin host memory for async transfers + cudaHostRegister(const_cast(h_data), + (size_t)h_indptr[n_cols] * sizeof(float), 0); + cudaHostRegister(const_cast(h_indices), + (size_t)h_indptr[n_cols] * sizeof(int), 0); + cudaHostRegister(const_cast(h_indptr), + (size_t)(n_cols + 1) * sizeof(int), 0); + cudaHostRegister(h_rank_sums, (size_t)n_groups * n_cols * sizeof(double), + 0); + cudaHostRegister(h_tie_corr, n_cols * sizeof(double), 0); + + cudaDeviceSynchronize(); int col = 0; int batch_idx = 0; while (col < n_cols) { int sb_cols = std::min(sub_batch_cols, n_cols - col); - int sb_items = n_rows * sb_cols; int s = batch_idx % n_streams; auto stream = streams[s]; auto& buf = bufs[s]; - // Zero dense buffer - cudaMemsetAsync(buf.dense, 0, sb_items * sizeof(float), stream); - - // Extract 
dense columns from CSR (all rows) - csr_extract_all_rows(csr_data, csr_indices, csr_indptr, buf.dense, - n_rows, col, col + sb_cols, pool, stream); - - // Fill segment offsets + row indices - upload_linear_offsets(buf.seg_offsets, sb_cols, n_rows, stream); - fill_row_indices_kernel<<>>(buf.vals_in, - n_rows, sb_cols); + int ptr_start = h_indptr[col]; + int ptr_end = h_indptr[col + sb_cols]; + int batch_nnz = ptr_end - ptr_start; - // Sort - size_t temp = cub_temp_bytes; - cub::DeviceSegmentedRadixSort::SortPairs( - buf.cub_temp, temp, buf.dense, buf.keys_out, buf.vals_in, - buf.vals_out, sb_items, sb_cols, buf.seg_offsets, - buf.seg_offsets + 1, BEGIN_BIT, END_BIT, stream); + // H2D: transfer sparse data for this column range + if (batch_nnz > 0) { + cudaMemcpyAsync(buf.d_sparse_data, h_data + ptr_start, + (size_t)batch_nnz * sizeof(float), + cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(buf.d_sparse_indices, h_indices + ptr_start, + (size_t)batch_nnz * sizeof(int), + cudaMemcpyHostToDevice, stream); + } - // Fused rank sums - if (use_gmem) { - cudaMemsetAsync(buf.sub_rank_sums, 0, - (size_t)n_groups * sb_cols * sizeof(double), - stream); + // Async transfer indptr slice, then rebase on GPU + cudaMemcpyAsync(buf.d_seg_offsets, h_indptr + col, + (sb_cols + 1) * sizeof(int), cudaMemcpyHostToDevice, + stream); + subtract_scalar_kernel<<<1, sb_cols + 1, 0, stream>>>( + buf.d_seg_offsets, ptr_start, sb_cols + 1); + + // CUB sort only stored nonzeros + if (batch_nnz > 0) { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortPairs( + buf.cub_temp, temp, buf.d_sparse_data, buf.keys_out, + buf.d_sparse_indices, buf.vals_out, batch_nnz, sb_cols, + buf.d_seg_offsets, buf.d_seg_offsets + 1, BEGIN_BIT, END_BIT, + stream); } - rank_sums_from_sorted_kernel<<>>( - buf.keys_out, buf.vals_out, group_codes, buf.sub_rank_sums, - buf.sub_tie_corr, nullptr, nullptr, nullptr, n_rows, sb_cols, - n_groups, compute_tie_corr, false, use_gmem); - // Scatter to global 
output - cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), - buf.sub_rank_sums, sb_cols * sizeof(double), + // Sparse rank kernel + rank_sums_sparse_ovr_kernel<<>>( + buf.keys_out, buf.vals_out, buf.d_seg_offsets, d_group_codes, + d_group_sizes, buf.d_rank_sums, buf.d_tie_corr, n_rows, sb_cols, + n_groups, compute_tie_corr); + + // D2H: scatter results + cudaMemcpy2DAsync(h_rank_sums + col, n_cols * sizeof(double), + buf.d_rank_sums, sb_cols * sizeof(double), sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToDevice, stream); + cudaMemcpyDeviceToHost, stream); if (compute_tie_corr) { - cudaMemcpyAsync(tie_corr + col, buf.sub_tie_corr, - sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, + cudaMemcpyAsync(h_tie_corr + col, buf.d_tie_corr, + sb_cols * sizeof(double), cudaMemcpyDeviceToHost, stream); } @@ -326,24 +383,29 @@ static void ovr_streaming_csr_impl( cudaError_t err = cudaStreamSynchronize(streams[s]); if (err != cudaSuccess) throw std::runtime_error( - std::string("CUDA error in wilcoxon streaming: ") + + std::string("CUDA error in sparse host CSC streaming: ") + cudaGetErrorString(err)); } + cudaHostUnregister(const_cast(h_data)); + cudaHostUnregister(const_cast(h_indices)); + cudaHostUnregister(const_cast(h_indptr)); + cudaHostUnregister(h_rank_sums); + cudaHostUnregister(h_tie_corr); + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); } /** - * CSC-direct OVR streaming pipeline. + * Host-streaming dense OVR pipeline. * - * Takes GPU CSC arrays directly — no format conversion needed. - * For each sub-batch: extract dense columns from CSC → sort → rank. - * CSC extraction is a simple scatter (no binary search), faster than CSR. + * Dense F-order float32 block lives on host. Sub-batches of 64 columns + * are transferred to GPU per stream, so GPU memory is O(sub_batch * n_rows). 
*/ -static void ovr_streaming_csc_impl( - const float* csc_data, const int* csc_indices, const int* csc_indptr, - const int* group_codes, double* rank_sums, double* tie_corr, int n_rows, - int n_cols, int n_groups, bool compute_tie_corr, int sub_batch_cols) { +static void ovr_streaming_dense_host_impl( + const float* h_block, const int* h_group_codes, double* h_rank_sums, + double* h_tie_corr, int n_rows, int n_cols, int n_groups, + bool compute_tie_corr, int sub_batch_cols) { if (n_rows == 0 || n_cols == 0) return; int n_streams = N_STREAMS; @@ -365,33 +427,45 @@ static void ovr_streaming_csc_impl( // Allocate per-stream buffers via RMM pool RmmPool pool; + int* d_group_codes = pool.alloc(n_rows); struct StreamBuf { - float* dense; + float* d_block; float* keys_out; int* vals_in; int* vals_out; int* seg_offsets; uint8_t* cub_temp; - double* sub_rank_sums; - double* sub_tie_corr; + double* d_rank_sums; + double* d_tie_corr; }; std::vector bufs(n_streams); for (int s = 0; s < n_streams; s++) { - bufs[s].dense = pool.alloc(sub_items); + bufs[s].d_block = pool.alloc(sub_items); bufs[s].keys_out = pool.alloc(sub_items); bufs[s].vals_in = pool.alloc(sub_items); bufs[s].vals_out = pool.alloc(sub_items); bufs[s].seg_offsets = pool.alloc(sub_batch_cols + 1); bufs[s].cub_temp = pool.alloc(cub_temp_bytes); - bufs[s].sub_rank_sums = + bufs[s].d_rank_sums = pool.alloc((size_t)n_groups * sub_batch_cols); - bufs[s].sub_tie_corr = pool.alloc(sub_batch_cols); + bufs[s].d_tie_corr = pool.alloc(sub_batch_cols); } + // Group codes on GPU (transferred once) + cudaMemcpy(d_group_codes, h_group_codes, n_rows * sizeof(int), + cudaMemcpyHostToDevice); + int tpb_rank = round_up_to_warp(n_rows); bool use_gmem = false; size_t smem_rank = ovr_smem_config(n_groups, use_gmem); + // Pin host memory + cudaHostRegister(const_cast(h_block), + (size_t)n_rows * n_cols * sizeof(float), 0); + cudaHostRegister(h_rank_sums, (size_t)n_groups * n_cols * sizeof(double), + 0); + cudaHostRegister(h_tie_corr, 
n_cols * sizeof(double), 0); + int col = 0; int batch_idx = 0; while (col < n_cols) { @@ -401,12 +475,10 @@ static void ovr_streaming_csc_impl( auto stream = streams[s]; auto& buf = bufs[s]; - // Zero dense buffer - cudaMemsetAsync(buf.dense, 0, sb_items * sizeof(float), stream); - - // Extract dense columns from CSC — simple scatter, no binary search - csc_extract_f32_kernel<<>>( - csc_data, csc_indices, csc_indptr, buf.dense, n_rows, col); + // H2D: column sub-batch (F-order → contiguous) + cudaMemcpyAsync(buf.d_block, h_block + (long long)col * n_rows, + sb_items * sizeof(float), cudaMemcpyHostToDevice, + stream); // Fill segment offsets + row indices upload_linear_offsets(buf.seg_offsets, sb_cols, n_rows, stream); @@ -416,29 +488,29 @@ static void ovr_streaming_csc_impl( // Sort size_t temp = cub_temp_bytes; cub::DeviceSegmentedRadixSort::SortPairs( - buf.cub_temp, temp, buf.dense, buf.keys_out, buf.vals_in, + buf.cub_temp, temp, buf.d_block, buf.keys_out, buf.vals_in, buf.vals_out, sb_items, sb_cols, buf.seg_offsets, buf.seg_offsets + 1, BEGIN_BIT, END_BIT, stream); // Fused rank sums if (use_gmem) { - cudaMemsetAsync(buf.sub_rank_sums, 0, + cudaMemsetAsync(buf.d_rank_sums, 0, (size_t)n_groups * sb_cols * sizeof(double), stream); } rank_sums_from_sorted_kernel<<>>( - buf.keys_out, buf.vals_out, group_codes, buf.sub_rank_sums, - buf.sub_tie_corr, nullptr, nullptr, nullptr, n_rows, sb_cols, + buf.keys_out, buf.vals_out, d_group_codes, buf.d_rank_sums, + buf.d_tie_corr, nullptr, nullptr, nullptr, n_rows, sb_cols, n_groups, compute_tie_corr, false, use_gmem); - // Scatter to global output - cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), - buf.sub_rank_sums, sb_cols * sizeof(double), + // D2H: scatter results + cudaMemcpy2DAsync(h_rank_sums + col, n_cols * sizeof(double), + buf.d_rank_sums, sb_cols * sizeof(double), sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToDevice, stream); + cudaMemcpyDeviceToHost, stream); if (compute_tie_corr) { - 
cudaMemcpyAsync(tie_corr + col, buf.sub_tie_corr, - sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, + cudaMemcpyAsync(h_tie_corr + col, buf.d_tie_corr, + sb_cols * sizeof(double), cudaMemcpyDeviceToHost, stream); } @@ -454,38 +526,34 @@ static void ovr_streaming_csc_impl( cudaGetErrorString(err)); } + cudaHostUnregister(const_cast(h_block)); + cudaHostUnregister(h_rank_sums); + cudaHostUnregister(h_tie_corr); + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); } -/** - * Host-streaming CSC OVR pipeline. - * - * CSC arrays live on host. Only the sparse data for each sub-batch of - * columns is transferred to GPU, so GPU memory is O(sub_batch * n_rows). - * H2D of sub-batch N+1 overlaps compute of sub-batch N via multi-stream. - */ -static void ovr_streaming_csc_host_impl( - const float* h_data, const int* h_indices, const int* h_indptr, - const int* h_group_codes, double* h_rank_sums, double* h_tie_corr, - int n_rows, int n_cols, int n_groups, bool compute_tie_corr, - int sub_batch_cols) { +// ============================================================================ +// Sparse-aware CSC OVR streaming (sort only stored nonzeros) +// ============================================================================ + +static void ovr_sparse_csc_streaming_impl( + const float* csc_data, const int* csc_indices, const int* csc_indptr, + const int* group_codes, const double* group_sizes, double* rank_sums, + double* tie_corr, int n_rows, int n_cols, int n_groups, + bool compute_tie_corr, int sub_batch_cols) { if (n_rows == 0 || n_cols == 0) return; + // Read indptr to host for batch planning + std::vector h_indptr(n_cols + 1); + cudaMemcpy(h_indptr.data(), csc_indptr, (n_cols + 1) * sizeof(int), + cudaMemcpyDeviceToHost); + int n_streams = N_STREAMS; if (n_cols < n_streams * sub_batch_cols) n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; - size_t sub_items = (size_t)n_rows * sub_batch_cols; - size_t cub_temp_bytes = 0; - { - auto* fk = 
reinterpret_cast(1); - auto* iv = reinterpret_cast(1); - cub::DeviceSegmentedRadixSort::SortPairs( - nullptr, cub_temp_bytes, fk, fk, iv, iv, (int)sub_items, - sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); - } - - // Find max nnz across any sub-batch to size the sparse transfer buffers + // Find max nnz across any sub-batch for buffer sizing size_t max_nnz = 0; for (int col = 0; col < n_cols; col += sub_batch_cols) { int sb_cols = std::min(sub_batch_cols, n_cols - col); @@ -493,126 +561,85 @@ static void ovr_streaming_csc_host_impl( if (nnz > max_nnz) max_nnz = nnz; } + // CUB temp size for max_nnz items + size_t cub_temp_bytes = 0; + if (max_nnz > 0) { + auto* fk = reinterpret_cast(1); + auto* iv = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortPairs( + nullptr, cub_temp_bytes, fk, fk, iv, iv, (int)max_nnz, + sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); + } + std::vector streams(n_streams); for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); - // Allocate per-stream buffers via RMM pool RmmPool pool; - int* d_group_codes = pool.alloc(n_rows); struct StreamBuf { - float* d_sparse_data; - int* d_sparse_indices; - int* d_indptr; - float* dense; float* keys_out; - int* vals_in; int* vals_out; int* seg_offsets; uint8_t* cub_temp; - double* d_rank_sums; - double* d_tie_corr; + double* sub_rank_sums; + double* sub_tie_corr; }; std::vector bufs(n_streams); for (int s = 0; s < n_streams; s++) { - bufs[s].d_sparse_data = pool.alloc(max_nnz); - bufs[s].d_sparse_indices = pool.alloc(max_nnz); - bufs[s].d_indptr = pool.alloc(sub_batch_cols + 1); - bufs[s].dense = pool.alloc(sub_items); - bufs[s].keys_out = pool.alloc(sub_items); - bufs[s].vals_in = pool.alloc(sub_items); - bufs[s].vals_out = pool.alloc(sub_items); + bufs[s].keys_out = pool.alloc(max_nnz); + bufs[s].vals_out = pool.alloc(max_nnz); bufs[s].seg_offsets = pool.alloc(sub_batch_cols + 1); bufs[s].cub_temp = pool.alloc(cub_temp_bytes); - bufs[s].d_rank_sums = + bufs[s].sub_rank_sums = 
pool.alloc((size_t)n_groups * sub_batch_cols); - bufs[s].d_tie_corr = pool.alloc(sub_batch_cols); + bufs[s].sub_tie_corr = pool.alloc(sub_batch_cols); } - // Group codes on GPU (transferred once) - cudaMemcpy(d_group_codes, h_group_codes, n_rows * sizeof(int), - cudaMemcpyHostToDevice); - - int tpb_rank = round_up_to_warp(n_rows); - bool use_gmem = false; - size_t smem_rank = ovr_smem_config(n_groups, use_gmem); + int tpb = 256; + size_t smem_bytes = (size_t)(2 * n_groups + 32) * sizeof(double); - // Pin host memory for async transfers - cudaHostRegister(const_cast(h_data), - (size_t)h_indptr[n_cols] * sizeof(float), 0); - cudaHostRegister(const_cast(h_indices), - (size_t)h_indptr[n_cols] * sizeof(int), 0); - cudaHostRegister(h_rank_sums, (size_t)n_groups * n_cols * sizeof(double), - 0); - cudaHostRegister(h_tie_corr, n_cols * sizeof(double), 0); + cudaDeviceSynchronize(); int col = 0; int batch_idx = 0; while (col < n_cols) { int sb_cols = std::min(sub_batch_cols, n_cols - col); - int sb_items = n_rows * sb_cols; int s = batch_idx % n_streams; auto stream = streams[s]; auto& buf = bufs[s]; - // H2D: transfer sparse data for this column range int ptr_start = h_indptr[col]; int ptr_end = h_indptr[col + sb_cols]; - size_t nnz = (size_t)(ptr_end - ptr_start); - cudaMemcpyAsync(buf.d_sparse_data, h_data + ptr_start, - nnz * sizeof(float), cudaMemcpyHostToDevice, stream); - cudaMemcpyAsync(buf.d_sparse_indices, h_indices + ptr_start, - nnz * sizeof(int), cudaMemcpyHostToDevice, stream); - - // Transfer adjusted indptr (rebased to 0) - // h_indptr[col..col+sb_cols] - h_indptr[col] - { - std::vector h_adj(sb_cols + 1); - for (int i = 0; i <= sb_cols; i++) - h_adj[i] = h_indptr[col + i] - ptr_start; - cudaMemcpy(buf.d_indptr, h_adj.data(), (sb_cols + 1) * sizeof(int), - cudaMemcpyHostToDevice); + int batch_nnz = ptr_end - ptr_start; + + // Compute rebased segment offsets on GPU (avoids host pinned-buffer + // race) + rebase_indptr_kernel<<<1, sb_cols + 1, 0, stream>>>( + 
csc_indptr, buf.seg_offsets, col, sb_cols + 1); + + // Sort only stored values (keys=data, vals=row_indices) + if (batch_nnz > 0) { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortPairs( + buf.cub_temp, temp, csc_data + ptr_start, buf.keys_out, + csc_indices + ptr_start, buf.vals_out, batch_nnz, sb_cols, + buf.seg_offsets, buf.seg_offsets + 1, BEGIN_BIT, END_BIT, + stream); } - // Zero dense buffer - cudaMemsetAsync(buf.dense, 0, sb_items * sizeof(float), stream); - - // CSC extract from transferred sparse data (col_start=0 because - // indptr is rebased and data/indices are for this sub-batch only) - csc_extract_f32_kernel<<>>( - buf.d_sparse_data, buf.d_sparse_indices, buf.d_indptr, buf.dense, - n_rows, 0); - - // Fill segment offsets + row indices - upload_linear_offsets(buf.seg_offsets, sb_cols, n_rows, stream); - fill_row_indices_kernel<<>>(buf.vals_in, - n_rows, sb_cols); + // Sparse rank kernel (handles implicit zeros analytically) + rank_sums_sparse_ovr_kernel<<>>( + buf.keys_out, buf.vals_out, buf.seg_offsets, group_codes, + group_sizes, buf.sub_rank_sums, buf.sub_tie_corr, n_rows, sb_cols, + n_groups, compute_tie_corr); - // Sort - size_t temp = cub_temp_bytes; - cub::DeviceSegmentedRadixSort::SortPairs( - buf.cub_temp, temp, buf.dense, buf.keys_out, buf.vals_in, - buf.vals_out, sb_items, sb_cols, buf.seg_offsets, - buf.seg_offsets + 1, BEGIN_BIT, END_BIT, stream); - - // Fused rank sums - if (use_gmem) { - cudaMemsetAsync(buf.d_rank_sums, 0, - (size_t)n_groups * sb_cols * sizeof(double), - stream); - } - rank_sums_from_sorted_kernel<<>>( - buf.keys_out, buf.vals_out, d_group_codes, buf.d_rank_sums, - buf.d_tie_corr, nullptr, nullptr, nullptr, n_rows, sb_cols, - n_groups, compute_tie_corr, false, use_gmem); - - // D2H: scatter results to host output - cudaMemcpy2DAsync(h_rank_sums + col, n_cols * sizeof(double), - buf.d_rank_sums, sb_cols * sizeof(double), + // Scatter results to global output + cudaMemcpy2DAsync(rank_sums + col, 
n_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToHost, stream); + cudaMemcpyDeviceToDevice, stream); if (compute_tie_corr) { - cudaMemcpyAsync(h_tie_corr + col, buf.d_tie_corr, - sb_cols * sizeof(double), cudaMemcpyDeviceToHost, + cudaMemcpyAsync(tie_corr + col, buf.sub_tie_corr, + sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, stream); } @@ -624,152 +651,184 @@ static void ovr_streaming_csc_host_impl( cudaError_t err = cudaStreamSynchronize(streams[s]); if (err != cudaSuccess) throw std::runtime_error( - std::string("CUDA error in wilcoxon streaming: ") + + std::string("CUDA error in sparse ovr streaming: ") + cudaGetErrorString(err)); } - cudaHostUnregister(const_cast(h_data)); - cudaHostUnregister(const_cast(h_indices)); - cudaHostUnregister(h_rank_sums); - cudaHostUnregister(h_tie_corr); - for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); } +// ============================================================================ +// Sparse-aware CSR OVR streaming (partial CSR→CSC transpose per sub-batch) +// ============================================================================ + /** - * Host-streaming dense OVR pipeline. + * Sparse-aware OVR streaming pipeline for GPU CSR data. * - * Dense F-order float32 block lives on host. Sub-batches of 64 columns - * are transferred to GPU per stream, so GPU memory is O(sub_batch * n_rows). + * Phase 0: One histogram kernel counts nnz per column. D2H + host prefix sums + * give exact per-batch nnz and max_batch_nnz for buffer sizing. + * Phase 1: Allocate per-stream buffers sized to max_batch_nnz. + * Phase 2: For each sub-batch: scatter CSR→CSC (partial transpose via + * atomics) → CUB sort only nonzeros → sparse rank kernel. + * + * Compared to the dense CSR path, sort work drops by ~1/sparsity. 
*/ -static void ovr_streaming_dense_host_impl( - const float* h_block, const int* h_group_codes, double* h_rank_sums, - double* h_tie_corr, int n_rows, int n_cols, int n_groups, +static void ovr_sparse_csr_streaming_impl( + const float* csr_data, const int* csr_indices, const int* csr_indptr, + const int* group_codes, const double* group_sizes, double* rank_sums, + double* tie_corr, int n_rows, int n_cols, int n_groups, bool compute_tie_corr, int sub_batch_cols) { if (n_rows == 0 || n_cols == 0) return; - int n_streams = N_STREAMS; - if (n_cols < n_streams * sub_batch_cols) - n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + // ---- Phase 0: Planning — count nnz per column via histogram ---- + RmmPool pool; + int* d_col_counts = pool.alloc(n_cols); + cudaMemset(d_col_counts, 0, n_cols * sizeof(int)); + { + int tpb = 256; + int blocks = (n_rows + tpb - 1) / tpb; + csr_col_histogram_kernel<<>>(csr_indices, csr_indptr, + d_col_counts, n_rows, n_cols); + } + std::vector h_col_counts(n_cols); + cudaMemcpy(h_col_counts.data(), d_col_counts, n_cols * sizeof(int), + cudaMemcpyDeviceToHost); + + // Per-batch prefix sums on host + int n_batches = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + size_t max_batch_nnz = 0; + size_t total_nnz = 0; + + // Flat array: n_batches × (sub_batch_cols + 1) offsets + std::vector h_all_offsets((size_t)n_batches * (sub_batch_cols + 1), 0); + std::vector h_batch_nnz(n_batches); + + for (int b = 0; b < n_batches; b++) { + int col_start = b * sub_batch_cols; + int sb_cols = std::min(sub_batch_cols, n_cols - col_start); + int* off = &h_all_offsets[(size_t)b * (sub_batch_cols + 1)]; + off[0] = 0; + for (int i = 0; i < sb_cols; i++) + off[i + 1] = off[i] + h_col_counts[col_start + i]; + h_batch_nnz[b] = off[sb_cols]; + total_nnz += h_batch_nnz[b]; + if ((size_t)h_batch_nnz[b] > max_batch_nnz) + max_batch_nnz = h_batch_nnz[b]; + } - size_t sub_items = (size_t)n_rows * sub_batch_cols; + // Upload all batch offsets to GPU in one shot 
(~20 KB) + int* d_all_offsets = + pool.alloc((size_t)n_batches * (sub_batch_cols + 1)); + cudaMemcpy(d_all_offsets, h_all_offsets.data(), + h_all_offsets.size() * sizeof(int), cudaMemcpyHostToDevice); + + // ---- Phase 1: Allocate per-stream buffers ---- size_t cub_temp_bytes = 0; - { + if (max_batch_nnz > 0) { auto* fk = reinterpret_cast(1); auto* iv = reinterpret_cast(1); cub::DeviceSegmentedRadixSort::SortPairs( - nullptr, cub_temp_bytes, fk, fk, iv, iv, (int)sub_items, + nullptr, cub_temp_bytes, fk, fk, iv, iv, (int)max_batch_nnz, sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); } + int n_streams = N_STREAMS; + if (n_batches < n_streams) n_streams = n_batches; + std::vector streams(n_streams); for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); - // Allocate per-stream buffers via RMM pool - RmmPool pool; - int* d_group_codes = pool.alloc(n_rows); struct StreamBuf { - float* d_block; - float* keys_out; - int* vals_in; - int* vals_out; - int* seg_offsets; + int* col_offsets; // [sub_batch_cols + 1] CSC-style offsets + int* write_pos; // [sub_batch_cols] atomic write counters + float* csc_vals; // [max_batch_nnz] transposed values + int* csc_row_idx; // [max_batch_nnz] transposed row indices + float* keys_out; // [max_batch_nnz] CUB sort output + int* vals_out; // [max_batch_nnz] CUB sort output uint8_t* cub_temp; - double* d_rank_sums; - double* d_tie_corr; + double* sub_rank_sums; + double* sub_tie_corr; }; std::vector bufs(n_streams); for (int s = 0; s < n_streams; s++) { - bufs[s].d_block = pool.alloc(sub_items); - bufs[s].keys_out = pool.alloc(sub_items); - bufs[s].vals_in = pool.alloc(sub_items); - bufs[s].vals_out = pool.alloc(sub_items); - bufs[s].seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].col_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].write_pos = pool.alloc(sub_batch_cols); + bufs[s].csc_vals = pool.alloc(max_batch_nnz); + bufs[s].csc_row_idx = pool.alloc(max_batch_nnz); + bufs[s].keys_out = pool.alloc(max_batch_nnz); 
+ bufs[s].vals_out = pool.alloc(max_batch_nnz); bufs[s].cub_temp = pool.alloc(cub_temp_bytes); - bufs[s].d_rank_sums = + bufs[s].sub_rank_sums = pool.alloc((size_t)n_groups * sub_batch_cols); - bufs[s].d_tie_corr = pool.alloc(sub_batch_cols); + bufs[s].sub_tie_corr = pool.alloc(sub_batch_cols); } - // Group codes on GPU (transferred once) - cudaMemcpy(d_group_codes, h_group_codes, n_rows * sizeof(int), - cudaMemcpyHostToDevice); - - int tpb_rank = round_up_to_warp(n_rows); - bool use_gmem = false; - size_t smem_rank = ovr_smem_config(n_groups, use_gmem); + int tpb = 256; + size_t smem_bytes = (size_t)(2 * n_groups + 32) * sizeof(double); + int scatter_blocks = (n_rows + tpb - 1) / tpb; - // Pin host memory - cudaHostRegister(const_cast(h_block), - (size_t)n_rows * n_cols * sizeof(float), 0); - cudaHostRegister(h_rank_sums, (size_t)n_groups * n_cols * sizeof(double), - 0); - cudaHostRegister(h_tie_corr, n_cols * sizeof(double), 0); + cudaDeviceSynchronize(); + // ---- Phase 2: Stream loop ---- int col = 0; - int batch_idx = 0; - while (col < n_cols) { + for (int b = 0; b < n_batches; b++) { int sb_cols = std::min(sub_batch_cols, n_cols - col); - int sb_items = n_rows * sb_cols; - int s = batch_idx % n_streams; + int s = b % n_streams; auto stream = streams[s]; auto& buf = bufs[s]; - - // H2D: column sub-batch (F-order → contiguous) - cudaMemcpyAsync(buf.d_block, h_block + (long long)col * n_rows, - sb_items * sizeof(float), cudaMemcpyHostToDevice, - stream); - - // Fill segment offsets + row indices - upload_linear_offsets(buf.seg_offsets, sb_cols, n_rows, stream); - fill_row_indices_kernel<<>>(buf.vals_in, - n_rows, sb_cols); - - // Sort - size_t temp = cub_temp_bytes; - cub::DeviceSegmentedRadixSort::SortPairs( - buf.cub_temp, temp, buf.d_block, buf.keys_out, buf.vals_in, - buf.vals_out, sb_items, sb_cols, buf.seg_offsets, - buf.seg_offsets + 1, BEGIN_BIT, END_BIT, stream); - - // Fused rank sums - if (use_gmem) { - cudaMemsetAsync(buf.d_rank_sums, 0, - 
(size_t)n_groups * sb_cols * sizeof(double), - stream); + int batch_nnz = h_batch_nnz[b]; + + // D2D copy pre-computed col_offsets for this batch + int* src = d_all_offsets + (size_t)b * (sub_batch_cols + 1); + cudaMemcpyAsync(buf.col_offsets, src, (sb_cols + 1) * sizeof(int), + cudaMemcpyDeviceToDevice, stream); + + // Initialize write_pos = col_offsets[0..sb_cols-1] (same D2D source) + cudaMemcpyAsync(buf.write_pos, src, sb_cols * sizeof(int), + cudaMemcpyDeviceToDevice, stream); + + if (batch_nnz > 0) { + // Scatter CSR → CSC layout for this sub-batch + csr_scatter_to_csc_kernel<<>>( + csr_data, csr_indices, csr_indptr, buf.write_pos, buf.csc_vals, + buf.csc_row_idx, n_rows, col, col + sb_cols); + + // CUB sort only the nonzeros + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortPairs( + buf.cub_temp, temp, buf.csc_vals, buf.keys_out, buf.csc_row_idx, + buf.vals_out, batch_nnz, sb_cols, buf.col_offsets, + buf.col_offsets + 1, BEGIN_BIT, END_BIT, stream); } - rank_sums_from_sorted_kernel<<>>( - buf.keys_out, buf.vals_out, d_group_codes, buf.d_rank_sums, - buf.d_tie_corr, nullptr, nullptr, nullptr, n_rows, sb_cols, - n_groups, compute_tie_corr, false, use_gmem); - // D2H: scatter results - cudaMemcpy2DAsync(h_rank_sums + col, n_cols * sizeof(double), - buf.d_rank_sums, sb_cols * sizeof(double), + // Sparse rank kernel (handles implicit zeros analytically) + rank_sums_sparse_ovr_kernel<<>>( + buf.keys_out, buf.vals_out, buf.col_offsets, group_codes, + group_sizes, buf.sub_rank_sums, buf.sub_tie_corr, n_rows, sb_cols, + n_groups, compute_tie_corr); + + // Scatter results to global output + cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToHost, stream); + cudaMemcpyDeviceToDevice, stream); if (compute_tie_corr) { - cudaMemcpyAsync(h_tie_corr + col, buf.d_tie_corr, - sb_cols * sizeof(double), cudaMemcpyDeviceToHost, + 
cudaMemcpyAsync(tie_corr + col, buf.sub_tie_corr, + sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, stream); } col += sb_cols; - batch_idx++; } for (int s = 0; s < n_streams; s++) { cudaError_t err = cudaStreamSynchronize(streams[s]); if (err != cudaSuccess) throw std::runtime_error( - std::string("CUDA error in wilcoxon streaming: ") + + std::string("CUDA error in sparse CSR ovr streaming: ") + cudaGetErrorString(err)); } - cudaHostUnregister(const_cast(h_block)); - cudaHostUnregister(h_rank_sums); - cudaHostUnregister(h_tie_corr); - for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); } @@ -800,41 +859,45 @@ void register_bindings(nb::module_& m) { "sub_batch_cols"_a = SUB_BATCH_COLS); m.def( - "ovr_streaming_csr", + "ovr_sparse_csr", [](gpu_array_c csr_data, gpu_array_c csr_indices, gpu_array_c csr_indptr, gpu_array_c group_codes, + gpu_array_c group_sizes, gpu_array_c rank_sums, gpu_array_c tie_corr, int n_rows, int n_cols, int n_groups, bool compute_tie_corr, int sub_batch_cols) { - ovr_streaming_csr_impl( + ovr_sparse_csr_streaming_impl( csr_data.data(), csr_indices.data(), csr_indptr.data(), - group_codes.data(), rank_sums.data(), tie_corr.data(), n_rows, - n_cols, n_groups, compute_tie_corr, sub_batch_cols); + group_codes.data(), group_sizes.data(), rank_sums.data(), + tie_corr.data(), n_rows, n_cols, n_groups, compute_tie_corr, + sub_batch_cols); }, "csr_data"_a, "csr_indices"_a, "csr_indptr"_a, "group_codes"_a, - "rank_sums"_a, "tie_corr"_a, nb::kw_only(), "n_rows"_a, "n_cols"_a, - "n_groups"_a, "compute_tie_corr"_a, + "group_sizes"_a, "rank_sums"_a, "tie_corr"_a, nb::kw_only(), "n_rows"_a, + "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, "sub_batch_cols"_a = SUB_BATCH_COLS); m.def( - "ovr_streaming_csc", + "ovr_sparse_csc", [](gpu_array_c csc_data, gpu_array_c csc_indices, gpu_array_c csc_indptr, gpu_array_c group_codes, + gpu_array_c group_sizes, gpu_array_c rank_sums, gpu_array_c tie_corr, int n_rows, int n_cols, int n_groups, bool 
compute_tie_corr, int sub_batch_cols) { - ovr_streaming_csc_impl( + ovr_sparse_csc_streaming_impl( csc_data.data(), csc_indices.data(), csc_indptr.data(), - group_codes.data(), rank_sums.data(), tie_corr.data(), n_rows, - n_cols, n_groups, compute_tie_corr, sub_batch_cols); + group_codes.data(), group_sizes.data(), rank_sums.data(), + tie_corr.data(), n_rows, n_cols, n_groups, compute_tie_corr, + sub_batch_cols); }, "csc_data"_a, "csc_indices"_a, "csc_indptr"_a, "group_codes"_a, - "rank_sums"_a, "tie_corr"_a, nb::kw_only(), "n_rows"_a, "n_cols"_a, - "n_groups"_a, "compute_tie_corr"_a, + "group_sizes"_a, "rank_sums"_a, "tie_corr"_a, nb::kw_only(), "n_rows"_a, + "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, "sub_batch_cols"_a = SUB_BATCH_COLS); } @@ -842,20 +905,21 @@ NB_MODULE(_wilcoxon_ovr_cuda, m) { REGISTER_GPU_BINDINGS(register_bindings, m); m.def( - "ovr_streaming_csc_host", + "ovr_sparse_csc_host", [](host_array h_data, host_array h_indices, host_array h_indptr, host_array h_group_codes, - host_array_2d h_rank_sums, host_array h_tie_corr, - int n_rows, int n_cols, int n_groups, bool compute_tie_corr, - int sub_batch_cols) { - ovr_streaming_csc_host_impl( + host_array h_group_sizes, host_array_2d h_rank_sums, + host_array h_tie_corr, int n_rows, int n_cols, int n_groups, + bool compute_tie_corr, int sub_batch_cols) { + ovr_sparse_csc_host_streaming_impl( h_data.data(), h_indices.data(), h_indptr.data(), - h_group_codes.data(), h_rank_sums.data(), h_tie_corr.data(), - n_rows, n_cols, n_groups, compute_tie_corr, sub_batch_cols); + h_group_codes.data(), h_group_sizes.data(), h_rank_sums.data(), + h_tie_corr.data(), n_rows, n_cols, n_groups, compute_tie_corr, + sub_batch_cols); }, "h_data"_a, "h_indices"_a, "h_indptr"_a, "h_group_codes"_a, - "h_rank_sums"_a, "h_tie_corr"_a, nb::kw_only(), "n_rows"_a, "n_cols"_a, - "n_groups"_a, "compute_tie_corr"_a, + "h_group_sizes"_a, "h_rank_sums"_a, "h_tie_corr"_a, nb::kw_only(), + "n_rows"_a, "n_cols"_a, "n_groups"_a, 
"compute_tie_corr"_a, "sub_batch_cols"_a = SUB_BATCH_COLS); m.def( diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py index ac141734..0141ad5f 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py @@ -282,11 +282,13 @@ def _wilcoxon_vs_rest( tie_corr_np = np.ones(n_total_genes, dtype=np.float64) if host_csc: - _ovr.ovr_streaming_csc_host( + group_sizes_np = group_sizes.astype(np.float64, copy=False) + _ovr.ovr_sparse_csc_host( X.data.astype(np.float32, copy=False), X.indices.astype(np.int32, copy=False), X.indptr.astype(np.int32, copy=False), group_codes, + group_sizes_np, rank_sums_np, tie_corr_np, n_rows=n_cells, @@ -323,12 +325,15 @@ def _wilcoxon_vs_rest( rank_sums = cp.empty((n_groups, n_total_genes), dtype=cp.float64) tie_corr = cp.ones(n_total_genes, dtype=cp.float64) - if cpsp.isspmatrix_csr(X_gpu): - _ovr.ovr_streaming_csr( + if cpsp.isspmatrix_csc(X_gpu): + # Sparse-aware path: sort only stored nonzeros, + # handle zeros analytically. 
+ _ovr.ovr_sparse_csc( X_gpu.data.astype(cp.float32, copy=False), X_gpu.indices.astype(cp.int32, copy=False), X_gpu.indptr.astype(cp.int32, copy=False), group_codes_gpu, + group_sizes_dev, rank_sums, tie_corr, n_rows=n_cells, @@ -337,12 +342,13 @@ def _wilcoxon_vs_rest( compute_tie_corr=tie_correct, sub_batch_cols=STREAMING_SUB_BATCH, ) - elif cpsp.isspmatrix_csc(X_gpu): - _ovr.ovr_streaming_csc( + elif cpsp.isspmatrix_csr(X_gpu): + _ovr.ovr_sparse_csr( X_gpu.data.astype(cp.float32, copy=False), X_gpu.indices.astype(cp.int32, copy=False), X_gpu.indptr.astype(cp.int32, copy=False), group_codes_gpu, + group_sizes_dev, rank_sums, tie_corr, n_rows=n_cells, diff --git a/tests/test_rank_genes_groups_wilcoxon.py b/tests/test_rank_genes_groups_wilcoxon.py index 3977fb35..e194d79f 100644 --- a/tests/test_rank_genes_groups_wilcoxon.py +++ b/tests/test_rank_genes_groups_wilcoxon.py @@ -508,6 +508,103 @@ def test_format_matches_scanpy(reference, fmt): ) +# ============================================================================ +# Negative values: centered/scaled data must match scanpy across all formats +# ============================================================================ + + +def _make_centered_adata(n_obs=200, n_vars=8, n_centers=3, seed=42): + """Create AnnData with centered (mean-zero) data containing negatives.""" + np.random.seed(seed) + adata = sc.datasets.blobs( + n_variables=n_vars, n_centers=n_centers, n_observations=n_obs + ) + adata.obs["blobs"] = adata.obs["blobs"].astype("category") + # Center each gene to produce negative values + adata.X = adata.X - adata.X.mean(axis=0) + return adata + + +@pytest.mark.parametrize("reference", ["rest", "1"]) +@pytest.mark.parametrize( + "fmt", + [ + pytest.param("scipy_csc", id="scipy_csc"), + pytest.param("cupy_dense", id="cupy_dense"), + pytest.param("cupy_csr", id="cupy_csr"), + pytest.param("cupy_csc", id="cupy_csc"), + ], +) +def test_negative_values_match_scanpy(reference, fmt): + """Centered data (with 
negatives) matches scanpy across all formats.""" + adata_gpu = _make_centered_adata() + adata_cpu = adata_gpu.copy() + + # Verify data actually has negatives + assert adata_gpu.X.min() < 0 + + adata_gpu.X = _to_format(adata_gpu.X, fmt) + + kw = { + "groupby": "blobs", + "method": "wilcoxon", + "use_raw": False, + "reference": reference, + "tie_correct": True, + } + rsc.tl.rank_genes_groups(adata_gpu, **kw) + sc.tl.rank_genes_groups(adata_cpu, **kw) + + gpu_result = adata_gpu.uns["rank_genes_groups"] + cpu_result = adata_cpu.uns["rank_genes_groups"] + + for group in gpu_result["names"].dtype.names: + assert list(gpu_result["names"][group]) == list(cpu_result["names"][group]) + + for field in ("scores", "pvals", "pvals_adj"): + for group in gpu_result[field].dtype.names: + np.testing.assert_allclose( + np.asarray(gpu_result[field][group], dtype=float), + np.asarray(cpu_result[field][group], dtype=float), + rtol=1e-13, + atol=1e-15, + ) + + +@pytest.mark.parametrize("reference", ["rest", "1"]) +def test_negative_sparse_matches_dense(reference): + """Sparse and dense paths give identical results for centered data.""" + adata_dense = _make_centered_adata() + adata_csr = adata_dense.copy() + adata_csc = adata_dense.copy() + + adata_csr.X = cpsp.csr_matrix(cp.asarray(adata_dense.X)) + adata_csc.X = cpsp.csc_matrix(cp.asarray(adata_dense.X)) + + kw = { + "groupby": "blobs", + "method": "wilcoxon", + "use_raw": False, + "reference": reference, + "tie_correct": True, + } + rsc.tl.rank_genes_groups(adata_dense, **kw) + rsc.tl.rank_genes_groups(adata_csr, **kw) + rsc.tl.rank_genes_groups(adata_csc, **kw) + + dense_result = adata_dense.uns["rank_genes_groups"] + csr_result = adata_csr.uns["rank_genes_groups"] + csc_result = adata_csc.uns["rank_genes_groups"] + + for field in ("scores", "pvals"): + for group in dense_result[field].dtype.names: + dense_vals = np.asarray(dense_result[field][group], dtype=float) + csr_vals = np.asarray(csr_result[field][group], dtype=float) + 
csc_vals = np.asarray(csc_result[field][group], dtype=float) + np.testing.assert_allclose(csr_vals, dense_vals, rtol=1e-13, atol=1e-15) + np.testing.assert_allclose(csc_vals, dense_vals, rtol=1e-13, atol=1e-15) + + # ============================================================================ # pre_load: GPU transfer before wilcoxon must match default (lazy transfer) # ============================================================================ From 88bcdd5faee9d7284f0cae38e61fa68f8dbee811 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Wed, 15 Apr 2026 18:32:42 +0200 Subject: [PATCH 16/21] remove dead code --- src/rapids_singlecell/_utils/_csr_to_csc.py | 117 -------------------- 1 file changed, 117 deletions(-) delete mode 100644 src/rapids_singlecell/_utils/_csr_to_csc.py diff --git a/src/rapids_singlecell/_utils/_csr_to_csc.py b/src/rapids_singlecell/_utils/_csr_to_csc.py deleted file mode 100644 index 5f4700b0..00000000 --- a/src/rapids_singlecell/_utils/_csr_to_csc.py +++ /dev/null @@ -1,117 +0,0 @@ -"""Fast parallel CSR to CSC conversion using Numba.""" - -from __future__ import annotations - -import numpy as np -import scipy.sparse as sp -from numba import get_num_threads, njit, prange - - -@njit(parallel=True, boundscheck=False) -def _csr_to_csc_kernel(csr_data, csr_indices, csr_indptr, n_cols): - """ - Numba kernel for parallel CSR to CSC conversion. - - Uses a tiled approach with parallel histogram + scatter phases, - targeting ~256 MB per block buffer for L3 cache efficiency. 
- """ - num_threads = get_num_threads() - nnz = len(csr_data) - n_rows = len(csr_indptr) - 1 - - # Allocate output arrays (int64 for matrices > 2GB) - csc_data = np.empty(nnz, dtype=csr_data.dtype) - csc_indices = np.empty(nnz, dtype=np.int64) - csc_indptr = np.empty(n_cols + 1, dtype=np.int64) - csc_indptr[0] = 0 - - # Block size targeting 256 MB per block buffer (L3 cache target) - TARGET_MEM_BYTES = 256 * 1024 * 1024 - MIN_BLOCK_SIZE = 1000 - block_size = TARGET_MEM_BYTES // (num_threads * 8) - if block_size < MIN_BLOCK_SIZE: - block_size = MIN_BLOCK_SIZE - - # Workspace: threads x block_width - counts = np.zeros((num_threads, block_size), dtype=np.int64) - row_chunk_size = (n_rows + num_threads - 1) // num_threads - - current_global_offset = 0 - - # Tiled execution over column blocks - for col_start in range(0, n_cols, block_size): - col_end = min(col_start + block_size, n_cols) - current_block_width = col_end - col_start - - # 1. Zero counters for this block - counts[:, :current_block_width] = 0 - - # 2. Parallel histogram (count items per column) - for t in prange(num_threads): - r_start = t * row_chunk_size - r_end = min((t + 1) * row_chunk_size, n_rows) - for r in range(r_start, r_end): - for i in range(csr_indptr[r], csr_indptr[r + 1]): - c = csr_indices[i] - if c >= col_start and c < col_end: - counts[t, c - col_start] += 1 - - # 3. Compute write offsets (sequential, fast) - for c in range(current_block_width): - total_in_col = 0 - for t in range(num_threads): - count = counts[t, c] - counts[t, c] = current_global_offset + total_in_col - total_in_col += count - - current_global_offset += total_in_col - csc_indptr[col_start + c + 1] = current_global_offset - - # 4. 
Parallel scatter (write data) - for t in prange(num_threads): - r_start = t * row_chunk_size - r_end = min((t + 1) * row_chunk_size, n_rows) - for r in range(r_start, r_end): - for i in range(csr_indptr[r], csr_indptr[r + 1]): - c = csr_indices[i] - if c >= col_start and c < col_end: - local_c = c - col_start - dest = counts[t, local_c] - - csc_data[dest] = csr_data[i] - csc_indices[dest] = r - - counts[t, local_c] += 1 - - return csc_data, csc_indices, csc_indptr - - -def _fast_csr_to_csc(mat_csr: sp.csr_matrix) -> sp.csc_matrix: - """ - Convert a SciPy CSR matrix to CSC using parallel Numba kernel. - - Uses a tiled multi-threaded approach for better cache utilization - compared to scipy's default conversion. - - Parameters - ---------- - mat_csr - Input CSR matrix. - - Returns - ------- - CSC matrix with the same data. - """ - if not sp.issparse(mat_csr) or mat_csr.format != "csr": - raise TypeError("Input must be a SciPy CSR matrix") - - rows, cols = mat_csr.shape - - data, indices, indptr = _csr_to_csc_kernel( - mat_csr.data, - mat_csr.indices, - mat_csr.indptr, - cols, - ) - - return sp.csc_matrix((data, indices, indptr), shape=mat_csr.shape) From 4127ac6f150b8d5fbc2271efd23d153aa0fdfab0 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Wed, 15 Apr 2026 21:49:44 +0200 Subject: [PATCH 17/21] fix overflows --- .../_cuda/wilcoxon/wilcoxon_common.cuh | 2 +- .../_cuda/wilcoxon/wilcoxon_ovo.cu | 88 +++++++++++++++---- .../_cuda/wilcoxon/wilcoxon_ovr.cu | 83 +++++++++++------ .../tools/_rank_genes_groups/_wilcoxon.py | 27 ++++-- 4 files changed, 148 insertions(+), 52 deletions(-) diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_common.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_common.cuh index 497e98ae..4363fd49 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_common.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_common.cuh @@ -72,7 +72,7 @@ __global__ void csr_extract_dense_kernel(const T* __restrict__ data, int lo = rs, hi = re; 
while (lo < hi) { - int m = (lo + hi) >> 1; + int m = lo + ((hi - lo) >> 1); if (indices[m] < col_start) lo = m + 1; else diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo.cu index bc0d1119..868b3dd2 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo.cu @@ -600,8 +600,9 @@ static void ovo_streaming_csc_impl( * columns is transferred to GPU. Row maps + group offsets are uploaded once. * Results are written back to host per sub-batch. */ +template static void ovo_streaming_csc_host_impl( - const float* h_data, const int* h_indices, const int* h_indptr, + const float* h_data, const int* h_indices, const IndptrT* h_indptr, const int* h_ref_row_map, const int* h_grp_row_map, const int* h_grp_offsets, double* h_rank_sums, double* h_tie_corr, int n_ref, int n_all_grp, int n_rows, int n_cols, int n_groups, @@ -721,10 +722,9 @@ static void ovo_streaming_csc_host_impl( round_up_to_warp(std::min(max_grp_size, MAX_THREADS_PER_BLOCK)); // Pin host memory for async transfers - cudaHostRegister(const_cast(h_data), - (size_t)h_indptr[n_cols] * sizeof(float), 0); - cudaHostRegister(const_cast(h_indices), - (size_t)h_indptr[n_cols] * sizeof(int), 0); + size_t total_nnz = (size_t)h_indptr[n_cols]; + cudaHostRegister(const_cast(h_data), total_nnz * sizeof(float), 0); + cudaHostRegister(const_cast(h_indices), total_nnz * sizeof(int), 0); cudaHostRegister(h_rank_sums, (size_t)n_groups * n_cols * sizeof(double), 0); cudaHostRegister(h_tie_corr, (size_t)n_groups * n_cols * sizeof(double), 0); @@ -740,8 +740,8 @@ static void ovo_streaming_csc_host_impl( auto& buf = bufs[s]; // ---- H2D: sparse data for this column range ---- - int ptr_start = h_indptr[col]; - int ptr_end = h_indptr[col + sb_cols]; + IndptrT ptr_start = h_indptr[col]; + IndptrT ptr_end = h_indptr[col + sb_cols]; size_t nnz = (size_t)(ptr_end - ptr_start); 
cudaMemcpyAsync(buf.d_sparse_data, h_data + ptr_start, nnz * sizeof(float), cudaMemcpyHostToDevice, stream); @@ -750,7 +750,7 @@ static void ovo_streaming_csc_host_impl( { std::vector h_adj(sb_cols + 1); for (int i = 0; i <= sb_cols; i++) - h_adj[i] = h_indptr[col + i] - ptr_start; + h_adj[i] = (int)(h_indptr[col + i] - ptr_start); cudaMemcpy(buf.d_indptr, h_adj.data(), (sb_cols + 1) * sizeof(int), cudaMemcpyHostToDevice); } @@ -853,11 +853,12 @@ static void ovo_streaming_csc_host_impl( * The reference is sorted once (not per sub-batch), saving ~50% of the * per-sub-batch extraction + sort work. */ +template static void ovo_streaming_csr_host_impl( - const float* h_data, const int* h_indices, const int* h_indptr, + const float* h_data, const int* h_indices, const IndptrT* h_indptr, const int* h_ref_row_ids, const int* h_grp_row_ids, const int* h_grp_offsets, double* h_rank_sums, double* h_tie_corr, - int n_ref, int n_all_grp, int n_rows, int n_cols, int n_groups, int nnz, + int n_ref, int n_all_grp, int n_rows, int n_cols, int n_groups, size_t nnz, bool compute_tie_corr, int sub_batch_cols) { if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; @@ -921,15 +922,18 @@ static void ovo_streaming_csr_host_impl( int* d_grp_row_ids = pool.alloc(n_all_grp); int* d_grp_offsets = pool.alloc(n_groups + 1); - cudaHostRegister(const_cast(h_data), (size_t)nnz * sizeof(float), - 0); - cudaHostRegister(const_cast(h_indices), (size_t)nnz * sizeof(int), 0); - cudaMemcpyAsync(d_data, h_data, (size_t)nnz * sizeof(float), - cudaMemcpyHostToDevice, streams[0]); - cudaMemcpyAsync(d_indices, h_indices, (size_t)nnz * sizeof(int), + cudaHostRegister(const_cast(h_data), nnz * sizeof(float), 0); + cudaHostRegister(const_cast(h_indices), nnz * sizeof(int), 0); + cudaMemcpyAsync(d_data, h_data, nnz * sizeof(float), cudaMemcpyHostToDevice, + streams[0]); + cudaMemcpyAsync(d_indices, h_indices, nnz * sizeof(int), cudaMemcpyHostToDevice, streams[0]); - cudaMemcpy(d_indptr, h_indptr, (n_rows + 
1) * sizeof(int), - cudaMemcpyHostToDevice); + { + std::vector h_indptr32(n_rows + 1); + for (int i = 0; i <= n_rows; i++) h_indptr32[i] = (int)h_indptr[i]; + cudaMemcpy(d_indptr, h_indptr32.data(), (n_rows + 1) * sizeof(int), + cudaMemcpyHostToDevice); + } cudaMemcpy(d_ref_row_ids, h_ref_row_ids, n_ref * sizeof(int), cudaMemcpyHostToDevice); cudaMemcpy(d_grp_row_ids, h_grp_row_ids, n_all_grp * sizeof(int), @@ -1476,6 +1480,29 @@ NB_MODULE(_wilcoxon_ovo_cuda, m) { "n_groups"_a, "compute_tie_corr"_a, "sub_batch_cols"_a = SUB_BATCH_COLS); + m.def( + "ovo_streaming_csc_host_i64", + [](host_array h_data, host_array h_indices, + host_array h_indptr, + host_array h_ref_row_map, + host_array h_grp_row_map, + host_array h_grp_offsets, + host_array_2d h_rank_sums, host_array_2d h_tie_corr, + int n_ref, int n_all_grp, int n_rows, int n_cols, int n_groups, + bool compute_tie_corr, int sub_batch_cols) { + ovo_streaming_csc_host_impl( + h_data.data(), h_indices.data(), h_indptr.data(), + h_ref_row_map.data(), h_grp_row_map.data(), + h_grp_offsets.data(), h_rank_sums.data(), h_tie_corr.data(), + n_ref, n_all_grp, n_rows, n_cols, n_groups, compute_tie_corr, + sub_batch_cols); + }, + "h_data"_a, "h_indices"_a, "h_indptr"_a, "h_ref_row_map"_a, + "h_grp_row_map"_a, "h_grp_offsets"_a, "h_rank_sums"_a, "h_tie_corr"_a, + nb::kw_only(), "n_ref"_a, "n_all_grp"_a, "n_rows"_a, "n_cols"_a, + "n_groups"_a, "compute_tie_corr"_a, + "sub_batch_cols"_a = SUB_BATCH_COLS); + m.def( "ovo_streaming_csr_host", [](host_array h_data, host_array h_indices, @@ -1484,7 +1511,30 @@ NB_MODULE(_wilcoxon_ovo_cuda, m) { host_array h_grp_offsets, host_array_2d h_rank_sums, host_array_2d h_tie_corr, int n_ref, int n_all_grp, int n_rows, int n_cols, int n_groups, - int nnz, bool compute_tie_corr, int sub_batch_cols) { + size_t nnz, bool compute_tie_corr, int sub_batch_cols) { + ovo_streaming_csr_host_impl( + h_data.data(), h_indices.data(), h_indptr.data(), + h_ref_row_ids.data(), h_grp_row_ids.data(), + 
h_grp_offsets.data(), h_rank_sums.data(), h_tie_corr.data(), + n_ref, n_all_grp, n_rows, n_cols, n_groups, nnz, + compute_tie_corr, sub_batch_cols); + }, + "h_data"_a, "h_indices"_a, "h_indptr"_a, "h_ref_row_ids"_a, + "h_grp_row_ids"_a, "h_grp_offsets"_a, "h_rank_sums"_a, "h_tie_corr"_a, + nb::kw_only(), "n_ref"_a, "n_all_grp"_a, "n_rows"_a, "n_cols"_a, + "n_groups"_a, "nnz"_a, "compute_tie_corr"_a, + "sub_batch_cols"_a = SUB_BATCH_COLS); + + m.def( + "ovo_streaming_csr_host_i64", + [](host_array h_data, host_array h_indices, + host_array h_indptr, + host_array h_ref_row_ids, + host_array h_grp_row_ids, + host_array h_grp_offsets, + host_array_2d h_rank_sums, host_array_2d h_tie_corr, + int n_ref, int n_all_grp, int n_rows, int n_cols, int n_groups, + size_t nnz, bool compute_tie_corr, int sub_batch_cols) { ovo_streaming_csr_host_impl( h_data.data(), h_indices.data(), h_indptr.data(), h_ref_row_ids.data(), h_grp_row_ids.data(), diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr.cu index 2f2c2d60..77cf2d04 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr.cu @@ -53,10 +53,10 @@ __global__ void csr_scatter_to_csc_kernel( if (row >= n_rows) return; int rs = indptr[row]; int re = indptr[row + 1]; - // Binary search for col_start + // Binary search for col_start (overflow-safe midpoint) int lo = rs, hi = re; while (lo < hi) { - int m = (lo + hi) >> 1; + int m = lo + ((hi - lo) >> 1); if (indices[m] < col_start) lo = m + 1; else @@ -239,8 +239,9 @@ static void ovr_streaming_impl(const float* block, const int* group_codes, * instead of extracting dense blocks. GPU memory is O(max_batch_nnz) instead * of O(sub_batch * n_rows), and sort work is proportional to nnz, not n_rows. 
*/ +template static void ovr_sparse_csc_host_streaming_impl( - const float* h_data, const int* h_indices, const int* h_indptr, + const float* h_data, const int* h_indices, const IndptrT* h_indptr, const int* h_group_codes, const double* h_group_sizes, double* h_rank_sums, double* h_tie_corr, int n_rows, int n_cols, int n_groups, bool compute_tie_corr, int sub_batch_cols) { @@ -307,12 +308,9 @@ static void ovr_sparse_csc_host_streaming_impl( size_t smem_bytes = (size_t)(2 * n_groups + 32) * sizeof(double); // Pin host memory for async transfers - cudaHostRegister(const_cast(h_data), - (size_t)h_indptr[n_cols] * sizeof(float), 0); - cudaHostRegister(const_cast(h_indices), - (size_t)h_indptr[n_cols] * sizeof(int), 0); - cudaHostRegister(const_cast(h_indptr), - (size_t)(n_cols + 1) * sizeof(int), 0); + size_t total_nnz = (size_t)h_indptr[n_cols]; + cudaHostRegister(const_cast(h_data), total_nnz * sizeof(float), 0); + cudaHostRegister(const_cast(h_indices), total_nnz * sizeof(int), 0); cudaHostRegister(h_rank_sums, (size_t)n_groups * n_cols * sizeof(double), 0); cudaHostRegister(h_tie_corr, n_cols * sizeof(double), 0); @@ -327,9 +325,9 @@ static void ovr_sparse_csc_host_streaming_impl( auto stream = streams[s]; auto& buf = bufs[s]; - int ptr_start = h_indptr[col]; - int ptr_end = h_indptr[col + sb_cols]; - int batch_nnz = ptr_end - ptr_start; + IndptrT ptr_start = h_indptr[col]; + IndptrT ptr_end = h_indptr[col + sb_cols]; + int batch_nnz = (int)(ptr_end - ptr_start); // H2D: transfer sparse data for this column range if (batch_nnz > 0) { @@ -341,12 +339,15 @@ static void ovr_sparse_csc_host_streaming_impl( cudaMemcpyHostToDevice, stream); } - // Async transfer indptr slice, then rebase on GPU - cudaMemcpyAsync(buf.d_seg_offsets, h_indptr + col, - (sb_cols + 1) * sizeof(int), cudaMemcpyHostToDevice, - stream); - subtract_scalar_kernel<<<1, sb_cols + 1, 0, stream>>>( - buf.d_seg_offsets, ptr_start, sb_cols + 1); + // Rebase indptr slice on host → int32 per-batch offsets 
+ { + std::vector h_seg(sb_cols + 1); + for (int i = 0; i <= sb_cols; i++) + h_seg[i] = (int)(h_indptr[col + i] - ptr_start); + cudaMemcpyAsync(buf.d_seg_offsets, h_seg.data(), + (sb_cols + 1) * sizeof(int), cudaMemcpyHostToDevice, + stream); + } // CUB sort only stored nonzeros if (batch_nnz > 0) { @@ -389,7 +390,6 @@ static void ovr_sparse_csc_host_streaming_impl( cudaHostUnregister(const_cast(h_data)); cudaHostUnregister(const_cast(h_indices)); - cudaHostUnregister(const_cast(h_indptr)); cudaHostUnregister(h_rank_sums); cudaHostUnregister(h_tie_corr); @@ -697,11 +697,10 @@ static void ovr_sparse_csr_streaming_impl( // Per-batch prefix sums on host int n_batches = (n_cols + sub_batch_cols - 1) / sub_batch_cols; size_t max_batch_nnz = 0; - size_t total_nnz = 0; // Flat array: n_batches × (sub_batch_cols + 1) offsets std::vector h_all_offsets((size_t)n_batches * (sub_batch_cols + 1), 0); - std::vector h_batch_nnz(n_batches); + std::vector h_batch_nnz(n_batches); for (int b = 0; b < n_batches; b++) { int col_start = b * sub_batch_cols; @@ -710,10 +709,8 @@ static void ovr_sparse_csr_streaming_impl( off[0] = 0; for (int i = 0; i < sb_cols; i++) off[i + 1] = off[i] + h_col_counts[col_start + i]; - h_batch_nnz[b] = off[sb_cols]; - total_nnz += h_batch_nnz[b]; - if ((size_t)h_batch_nnz[b] > max_batch_nnz) - max_batch_nnz = h_batch_nnz[b]; + h_batch_nnz[b] = (size_t)off[sb_cols]; + if (h_batch_nnz[b] > max_batch_nnz) max_batch_nnz = h_batch_nnz[b]; } // Upload all batch offsets to GPU in one shot (~20 KB) @@ -735,6 +732,21 @@ static void ovr_sparse_csr_streaming_impl( int n_streams = N_STREAMS; if (n_batches < n_streams) n_streams = n_batches; + // CSR path needs 4 sort arrays per stream (scatter intermediates + + // CUB output). Fit stream count to available GPU memory. 
+ size_t per_stream_bytes = + max_batch_nnz * (2 * sizeof(float) + 2 * sizeof(int)) + + (sub_batch_cols + 1 + sub_batch_cols) * sizeof(int) + cub_temp_bytes + + (size_t)n_groups * sub_batch_cols * sizeof(double) + + sub_batch_cols * sizeof(double); + + size_t free_mem = 0, total_mem = 0; + cudaMemGetInfo(&free_mem, &total_mem); + constexpr double MEM_BUDGET_FRAC = 0.8; + size_t budget = (size_t)(free_mem * MEM_BUDGET_FRAC); + while (n_streams > 1 && (size_t)n_streams * per_stream_bytes > budget) + n_streams--; + std::vector streams(n_streams); for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); @@ -776,7 +788,7 @@ static void ovr_sparse_csr_streaming_impl( int s = b % n_streams; auto stream = streams[s]; auto& buf = bufs[s]; - int batch_nnz = h_batch_nnz[b]; + int batch_nnz = (int)h_batch_nnz[b]; // D2D copy pre-computed col_offsets for this batch int* src = d_all_offsets + (size_t)b * (sub_batch_cols + 1); @@ -922,6 +934,25 @@ NB_MODULE(_wilcoxon_ovr_cuda, m) { "n_rows"_a, "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, "sub_batch_cols"_a = SUB_BATCH_COLS); + m.def( + "ovr_sparse_csc_host_i64", + [](host_array h_data, host_array h_indices, + host_array h_indptr, + host_array h_group_codes, + host_array h_group_sizes, host_array_2d h_rank_sums, + host_array h_tie_corr, int n_rows, int n_cols, int n_groups, + bool compute_tie_corr, int sub_batch_cols) { + ovr_sparse_csc_host_streaming_impl( + h_data.data(), h_indices.data(), h_indptr.data(), + h_group_codes.data(), h_group_sizes.data(), h_rank_sums.data(), + h_tie_corr.data(), n_rows, n_cols, n_groups, compute_tie_corr, + sub_batch_cols); + }, + "h_data"_a, "h_indices"_a, "h_indptr"_a, "h_group_codes"_a, + "h_group_sizes"_a, "h_rank_sums"_a, "h_tie_corr"_a, nb::kw_only(), + "n_rows"_a, "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, + "sub_batch_cols"_a = SUB_BATCH_COLS); + m.def( "ovr_streaming_dense_host", [](host_array_2d h_block, diff --git 
a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py index 0141ad5f..7ff9b8db 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py @@ -283,10 +283,15 @@ def _wilcoxon_vs_rest( if host_csc: group_sizes_np = group_sizes.astype(np.float64, copy=False) - _ovr.ovr_sparse_csc_host( + _csc_host_fn = ( + _ovr.ovr_sparse_csc_host_i64 + if X.indptr.dtype == np.int64 + else _ovr.ovr_sparse_csc_host + ) + _csc_host_fn( X.data.astype(np.float32, copy=False), X.indices.astype(np.int32, copy=False), - X.indptr.astype(np.int32, copy=False), + X.indptr, group_codes, group_sizes_np, rank_sums_np, @@ -474,10 +479,15 @@ def _wilcoxon_with_reference( tie_corr_np = np.ones((n_test, n_total_genes), dtype=np.float64) if host_sparse and X.format == "csc": - _wc.ovo_streaming_csc_host( + _csc_host_fn = ( + _wc.ovo_streaming_csc_host_i64 + if X.indptr.dtype == np.int64 + else _wc.ovo_streaming_csc_host + ) + _csc_host_fn( X.data.astype(np.float32, copy=False), X.indices.astype(np.int32, copy=False), - X.indptr.astype(np.int32, copy=False), + X.indptr, ref_row_map_np, grp_row_map_np, offsets_np, @@ -493,10 +503,15 @@ def _wilcoxon_with_reference( ) elif host_sparse: csr = X.tocsr() if X.format != "csr" else X - _wc.ovo_streaming_csr_host( + _csr_host_fn = ( + _wc.ovo_streaming_csr_host_i64 + if csr.indptr.dtype == np.int64 + else _wc.ovo_streaming_csr_host + ) + _csr_host_fn( csr.data.astype(np.float32, copy=False), csr.indices.astype(np.int32, copy=False), - csr.indptr.astype(np.int32, copy=False), + csr.indptr, ref_row_ids_np.astype(np.int32, copy=False), all_grp_row_ids_np.astype(np.int32, copy=False), offsets_np, From 8644ba95588ca0004ac126c3e90befcfb35c6650 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Mon, 20 Apr 2026 13:21:15 +0200 Subject: [PATCH 18/21] fix streaming --- .../_cuda/cooc/kernels_cooc.cuh | 2 +- 
.../_cuda/wilcoxon/kernels_wilcoxon.cuh | 162 ++++++- .../_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh | 26 +- .../_cuda/wilcoxon/wilcoxon_ovo.cu | 456 +++++++++++------- .../_cuda/wilcoxon/wilcoxon_ovr.cu | 282 +++++++---- .../tools/_rank_genes_groups/_wilcoxon.py | 288 ++++++++--- 6 files changed, 852 insertions(+), 364 deletions(-) diff --git a/src/rapids_singlecell/_cuda/cooc/kernels_cooc.cuh b/src/rapids_singlecell/_cuda/cooc/kernels_cooc.cuh index 4786647d..d5f9b553 100644 --- a/src/rapids_singlecell/_cuda/cooc/kernels_cooc.cuh +++ b/src/rapids_singlecell/_cuda/cooc/kernels_cooc.cuh @@ -353,7 +353,7 @@ __global__ void occur_count_kernel_csr_catpairs_tiled( // Binary search for threshold bin int lo = 0, hi = l_val; while (lo < hi) { - int mid = (lo + hi) >> 1; + int mid = lo + ((hi - lo) >> 1); if (dist_sq <= thresholds[mid]) hi = mid; else diff --git a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh index b7aa4a40..0c6b7505 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh @@ -90,7 +90,7 @@ __global__ void rank_sums_from_sorted_kernel( if (i == my_start && i > 0 && sv[i - 1] == val) { int lo = 0, hi = i; while (lo < hi) { - int mid = (lo + hi) / 2; + int mid = lo + (hi - lo) / 2; if (sv[mid] < val) lo = mid + 1; else @@ -239,7 +239,7 @@ __global__ void rank_sums_sparse_ovr_kernel( // Binary search: first index where sv[i] >= 0.0 int lo = 0, hi = nnz_stored; while (lo < hi) { - int mid = (lo + hi) >> 1; + int mid = lo + ((hi - lo) >> 1); if (sv[mid] < 0.0f) lo = mid + 1; else @@ -249,7 +249,7 @@ __global__ void rank_sums_sparse_ovr_kernel( // Binary search: first index where sv[i] > 0.0 hi = nnz_stored; while (lo < hi) { - int mid = (lo + hi) >> 1; + int mid = lo + ((hi - lo) >> 1); if (sv[mid] <= 0.0f) lo = mid + 1; else @@ -322,7 +322,7 @@ __global__ void rank_sums_sparse_ovr_kernel( int search_lo = (val < 
0.0f) ? 0 : pos_start; int lo = search_lo, hi = i; while (lo < hi) { - int mid = (lo + hi) >> 1; + int mid = lo + ((hi - lo) >> 1); if (sv[mid] < val) lo = mid + 1; else @@ -420,3 +420,157 @@ __global__ void rank_sums_sparse_ovr_kernel( } } } + +/** + * Pre-sort cast-and-accumulate kernel for dense OVR host streaming. + * + * Reads a sub-batch block in its native host dtype (InT = float or double), + * writes a float32 copy used as the sort input, and accumulates per-group + * sum, sum-of-squares and nonzero counts in float64. Stats are derived + * from the original-precision values so float64 host input keeps its + * precision while the sort still runs on float32 keys. + * + * Block-per-column layout (grid: (sb_cols,), block: (tpb,)). + * Shared memory: 3 * n_groups doubles (s_sum, s_sq, s_nnz). + */ +template +__global__ void ovr_cast_and_accumulate_dense_kernel( + const InT* __restrict__ block_in, float* __restrict__ block_f32_out, + const int* __restrict__ group_codes, double* __restrict__ group_sums, + double* __restrict__ group_sq_sums, double* __restrict__ group_nnz, + int n_rows, int sb_cols, int n_groups) { + int col = blockIdx.x; + if (col >= sb_cols) return; + + extern __shared__ double smem[]; + double* s_sum = smem; + double* s_sq = smem + n_groups; + double* s_nnz = smem + 2 * n_groups; + + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + s_sum[g] = 0.0; + s_sq[g] = 0.0; + s_nnz[g] = 0.0; + } + __syncthreads(); + + const InT* src = block_in + (size_t)col * n_rows; + float* dst = block_f32_out + (size_t)col * n_rows; + + for (int r = threadIdx.x; r < n_rows; r += blockDim.x) { + InT v_in = src[r]; + double v = (double)v_in; + dst[r] = (float)v_in; + int g = group_codes[r]; + if (g < n_groups) { + atomicAdd(&s_sum[g], v); + atomicAdd(&s_sq[g], v * v); + if (v != 0.0) atomicAdd(&s_nnz[g], 1.0); + } + } + __syncthreads(); + + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + group_sums[(size_t)g * sb_cols + col] = s_sum[g]; + 
group_sq_sums[(size_t)g * sb_cols + col] = s_sq[g]; + group_nnz[(size_t)g * sb_cols + col] = s_nnz[g]; + } +} + +/** + * One-shot cast-and-accumulate kernel for CSR-layout host streaming. + * + * The OVO CSR host path uploads the full CSR once; this kernel walks the + * uploaded data row-by-row, writes a float32 copy of the values, and + * accumulates per-group sum/sum-sq/nnz directly into a full-size + * (n_groups_stats, n_cols) output using global atomics. stats_codes[row] + * must be in [0, n_groups_stats) to contribute; other values (e.g. the + * sentinel for unselected cells) are skipped. + * + * Grid: (ceil(n_rows/tpb),), Block: (tpb,). + */ +template +__global__ void cast_and_accumulate_csr_kernel( + const InT* __restrict__ data_in, float* __restrict__ data_f32_out, + const int* __restrict__ indices, const int* __restrict__ indptr, + const int* __restrict__ stats_codes, double* __restrict__ group_sums, + double* __restrict__ group_sq_sums, double* __restrict__ group_nnz, + int n_rows, int n_cols, int n_groups_stats) { + int row = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= n_rows) return; + int slot = stats_codes[row]; + int rs = indptr[row]; + int re = indptr[row + 1]; + bool accumulate = (slot >= 0 && slot < n_groups_stats); + for (int p = rs; p < re; p++) { + InT v_in = data_in[p]; + double v = (double)v_in; + data_f32_out[p] = (float)v_in; + if (accumulate) { + int c = indices[p]; + atomicAdd(&group_sums[(size_t)slot * n_cols + c], v); + atomicAdd(&group_sq_sums[(size_t)slot * n_cols + c], v * v); + if (v != 0.0) { + atomicAdd(&group_nnz[(size_t)slot * n_cols + c], 1.0); + } + } + } +} + +/** + * Pre-sort cast-and-accumulate kernel for sparse OVR host streaming. + * + * Sub-batch CSC data is laid out contiguously: values for column c live + * at positions [col_seg_offsets[c], col_seg_offsets[c+1]). For each + * stored value, read the native-dtype InT, write a float32 copy for the + * CUB sort, and accumulate per-group sum/sum-sq/nnz in float64. 
Implicit + * zeros contribute nothing to any of these stats. + * + * Block-per-column layout (grid: (sb_cols,), block: (tpb,)). + * Shared memory: 3 * n_groups doubles. + */ +template +__global__ void ovr_cast_and_accumulate_sparse_kernel( + const InT* __restrict__ data_in, float* __restrict__ data_f32_out, + const int* __restrict__ indices, const int* __restrict__ col_seg_offsets, + const int* __restrict__ group_codes, double* __restrict__ group_sums, + double* __restrict__ group_sq_sums, double* __restrict__ group_nnz, + int sb_cols, int n_groups) { + int col = blockIdx.x; + if (col >= sb_cols) return; + + int seg_start = col_seg_offsets[col]; + int seg_end = col_seg_offsets[col + 1]; + + extern __shared__ double smem[]; + double* s_sum = smem; + double* s_sq = smem + n_groups; + double* s_nnz = smem + 2 * n_groups; + + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + s_sum[g] = 0.0; + s_sq[g] = 0.0; + s_nnz[g] = 0.0; + } + __syncthreads(); + + for (int i = seg_start + threadIdx.x; i < seg_end; i += blockDim.x) { + InT v_in = data_in[i]; + double v = (double)v_in; + data_f32_out[i] = (float)v_in; + int row = indices[i]; + int g = group_codes[row]; + if (g < n_groups) { + atomicAdd(&s_sum[g], v); + atomicAdd(&s_sq[g], v * v); + if (v != 0.0) atomicAdd(&s_nnz[g], 1.0); + } + } + __syncthreads(); + + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + group_sums[(size_t)g * sb_cols + col] = s_sum[g]; + group_sq_sums[(size_t)g * sb_cols + col] = s_sq[g]; + group_nnz[(size_t)g * sb_cols + col] = s_nnz[g]; + } +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh index fac8816a..ea7586b0 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh @@ -56,7 +56,7 @@ __device__ __forceinline__ void compute_tie_correction_parallel( // Count in ref: upper_bound from i+1 int lo = i + 1, 
hi = n_ref; while (lo < hi) { - int m = (lo + hi) >> 1; + int m = lo + ((hi - lo) >> 1); if (ref_col[m] <= v) lo = m + 1; else @@ -68,7 +68,7 @@ __device__ __forceinline__ void compute_tie_correction_parallel( lo = grp_lb; hi = n_grp; while (lo < hi) { - int m = (lo + hi) >> 1; + int m = lo + ((hi - lo) >> 1); if (grp_col[m] < v) lo = m + 1; else @@ -80,7 +80,7 @@ __device__ __forceinline__ void compute_tie_correction_parallel( lo = (grp_ub > lb) ? grp_ub : lb; hi = n_grp; while (lo < hi) { - int m = (lo + hi) >> 1; + int m = lo + ((hi - lo) >> 1); if (grp_col[m] <= v) lo = m + 1; else @@ -106,7 +106,7 @@ __device__ __forceinline__ void compute_tie_correction_parallel( // Incremental lower_bound in ref int lo = ref_lb, hi = n_ref; while (lo < hi) { - int m = (lo + hi) >> 1; + int m = lo + ((hi - lo) >> 1); if (ref_col[m] < v) lo = m + 1; else @@ -119,7 +119,7 @@ __device__ __forceinline__ void compute_tie_correction_parallel( lo = i + 1; hi = n_grp; while (lo < hi) { - int m = (lo + hi) >> 1; + int m = lo + ((hi - lo) >> 1); if (grp_col[m] <= v) lo = m + 1; else @@ -190,7 +190,7 @@ __global__ void batched_rank_sums_presorted_kernel( lo = ref_lb; hi = n_ref; while (lo < hi) { - int m = (lo + hi) >> 1; + int m = lo + ((hi - lo) >> 1); if (ref_col[m] < v) lo = m + 1; else @@ -203,7 +203,7 @@ __global__ void batched_rank_sums_presorted_kernel( lo = (ref_ub > n_lt_ref) ? ref_ub : n_lt_ref; hi = n_ref; while (lo < hi) { - int m = (lo + hi) >> 1; + int m = lo + ((hi - lo) >> 1); if (ref_col[m] <= v) lo = m + 1; else @@ -216,7 +216,7 @@ __global__ void batched_rank_sums_presorted_kernel( lo = grp_lb; hi = n_grp; while (lo < hi) { - int m = (lo + hi) >> 1; + int m = lo + ((hi - lo) >> 1); if (grp_col[m] < v) lo = m + 1; else @@ -229,7 +229,7 @@ __global__ void batched_rank_sums_presorted_kernel( lo = (grp_ub > n_lt_grp) ? 
grp_ub : n_lt_grp; hi = n_grp; while (lo < hi) { - int m = (lo + hi) >> 1; + int m = lo + ((hi - lo) >> 1); if (grp_col[m] <= v) lo = m + 1; else @@ -330,7 +330,7 @@ __global__ void ovo_fused_sort_rank_kernel( lo = ref_lb; hi = n_ref; while (lo < hi) { - int m = (lo + hi) >> 1; + int m = lo + ((hi - lo) >> 1); if (ref_col[m] < v) lo = m + 1; else @@ -342,7 +342,7 @@ __global__ void ovo_fused_sort_rank_kernel( lo = (ref_ub > n_lt_ref) ? ref_ub : n_lt_ref; hi = n_ref; while (lo < hi) { - int m = (lo + hi) >> 1; + int m = lo + ((hi - lo) >> 1); if (ref_col[m] <= v) lo = m + 1; else @@ -354,7 +354,7 @@ __global__ void ovo_fused_sort_rank_kernel( lo = grp_lb; hi = n_grp; while (lo < hi) { - int m = (lo + hi) >> 1; + int m = lo + ((hi - lo) >> 1); if (grp_smem[m] < v) lo = m + 1; else @@ -366,7 +366,7 @@ __global__ void ovo_fused_sort_rank_kernel( lo = (grp_ub > n_lt_grp) ? grp_ub : n_lt_grp; hi = n_grp; while (lo < hi) { - int m = (lo + hi) >> 1; + int m = lo + ((hi - lo) >> 1); if (grp_smem[m] <= v) lo = m + 1; else diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo.cu index 868b3dd2..a2a513cf 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo.cu @@ -5,6 +5,7 @@ #include "../nb_types.h" #include "wilcoxon_common.cuh" +#include "kernels_wilcoxon.cuh" #include "kernels_wilcoxon_ovo.cuh" using namespace nb::literals; @@ -600,13 +601,15 @@ static void ovo_streaming_csc_impl( * columns is transferred to GPU. Row maps + group offsets are uploaded once. * Results are written back to host per sub-batch. 
*/ -template +template static void ovo_streaming_csc_host_impl( - const float* h_data, const int* h_indices, const IndptrT* h_indptr, + const InT* h_data, const int* h_indices, const IndptrT* h_indptr, const int* h_ref_row_map, const int* h_grp_row_map, - const int* h_grp_offsets, double* h_rank_sums, double* h_tie_corr, - int n_ref, int n_all_grp, int n_rows, int n_cols, int n_groups, - bool compute_tie_corr, int sub_batch_cols) { + const int* h_grp_offsets, const int* h_stats_codes, double* d_rank_sums, + double* d_tie_corr, double* d_group_sums, double* d_group_sq_sums, + double* d_group_nnz, int n_ref, int n_all_grp, int n_rows, int n_cols, + int n_groups, int n_groups_stats, bool compute_tie_corr, + int sub_batch_cols) { if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; // ---- Tier dispatch from host offsets ---- @@ -669,19 +672,23 @@ static void ovo_streaming_csc_host_impl( RmmPool pool; - // GPU copies of row maps + group offsets (uploaded once) + // GPU copies of row maps + group offsets + stats codes (uploaded once) int* d_ref_row_map = pool.alloc(n_rows); int* d_grp_row_map = pool.alloc(n_rows); int* d_grp_offsets = pool.alloc(n_groups + 1); + int* d_stats_codes = pool.alloc(n_rows); cudaMemcpy(d_ref_row_map, h_ref_row_map, n_rows * sizeof(int), cudaMemcpyHostToDevice); cudaMemcpy(d_grp_row_map, h_grp_row_map, n_rows * sizeof(int), cudaMemcpyHostToDevice); cudaMemcpy(d_grp_offsets, h_grp_offsets, (n_groups + 1) * sizeof(int), cudaMemcpyHostToDevice); + cudaMemcpy(d_stats_codes, h_stats_codes, n_rows * sizeof(int), + cudaMemcpyHostToDevice); struct StreamBuf { - float* d_sparse_data; + InT* d_sparse_data_orig; + float* d_sparse_data_f32; int* d_sparse_indices; int* d_indptr; float* ref_dense; @@ -693,10 +700,14 @@ static void ovo_streaming_csc_host_impl( uint8_t* cub_temp; double* d_rank_sums; double* d_tie_corr; + double* d_group_sums; + double* d_group_sq_sums; + double* d_group_nnz; }; std::vector bufs(n_streams); for (int s = 0; s < n_streams; 
s++) { - bufs[s].d_sparse_data = pool.alloc(max_nnz); + bufs[s].d_sparse_data_orig = pool.alloc(max_nnz); + bufs[s].d_sparse_data_f32 = pool.alloc(max_nnz); bufs[s].d_sparse_indices = pool.alloc(max_nnz); bufs[s].d_indptr = pool.alloc(sub_batch_cols + 1); bufs[s].ref_dense = pool.alloc(sub_ref_items); @@ -708,6 +719,12 @@ static void ovo_streaming_csc_host_impl( pool.alloc((size_t)n_groups * sub_batch_cols); bufs[s].d_tie_corr = pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].d_group_sums = + pool.alloc((size_t)n_groups_stats * sub_batch_cols); + bufs[s].d_group_sq_sums = + pool.alloc((size_t)n_groups_stats * sub_batch_cols); + bufs[s].d_group_nnz = + pool.alloc((size_t)n_groups_stats * sub_batch_cols); if (!use_tier1) { bufs[s].grp_sorted = pool.alloc(sub_grp_items); int max_grp_seg = n_groups * sub_batch_cols; @@ -720,14 +737,12 @@ static void ovo_streaming_csc_host_impl( int tpb_rank = round_up_to_warp(std::min(max_grp_size, MAX_THREADS_PER_BLOCK)); + size_t smem_cast = (size_t)(3 * n_groups_stats) * sizeof(double); - // Pin host memory for async transfers + // Pin only the sparse input arrays; outputs live on the device. 
size_t total_nnz = (size_t)h_indptr[n_cols]; - cudaHostRegister(const_cast(h_data), total_nnz * sizeof(float), 0); + cudaHostRegister(const_cast(h_data), total_nnz * sizeof(InT), 0); cudaHostRegister(const_cast(h_indices), total_nnz * sizeof(int), 0); - cudaHostRegister(h_rank_sums, (size_t)n_groups * n_cols * sizeof(double), - 0); - cudaHostRegister(h_tie_corr, (size_t)n_groups * n_cols * sizeof(double), 0); int col = 0; int batch_idx = 0; @@ -739,12 +754,12 @@ static void ovo_streaming_csc_host_impl( auto stream = streams[s]; auto& buf = bufs[s]; - // ---- H2D: sparse data for this column range ---- + // ---- H2D: sparse data for this column range (native dtype) ---- IndptrT ptr_start = h_indptr[col]; IndptrT ptr_end = h_indptr[col + sb_cols]; size_t nnz = (size_t)(ptr_end - ptr_start); - cudaMemcpyAsync(buf.d_sparse_data, h_data + ptr_start, - nnz * sizeof(float), cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(buf.d_sparse_data_orig, h_data + ptr_start, + nnz * sizeof(InT), cudaMemcpyHostToDevice, stream); cudaMemcpyAsync(buf.d_sparse_indices, h_indices + ptr_start, nnz * sizeof(int), cudaMemcpyHostToDevice, stream); { @@ -755,11 +770,19 @@ static void ovo_streaming_csc_host_impl( cudaMemcpyHostToDevice); } + // ---- Cast to float32 for sort + accumulate stats in float64 ---- + ovr_cast_and_accumulate_sparse_kernel + <<>>( + buf.d_sparse_data_orig, buf.d_sparse_data_f32, + buf.d_sparse_indices, buf.d_indptr, d_stats_codes, + buf.d_group_sums, buf.d_group_sq_sums, buf.d_group_nnz, sb_cols, + n_groups_stats); + // ---- Extract ref from CSC via row_map, sort ---- cudaMemsetAsync(buf.ref_dense, 0, sb_ref_actual * sizeof(float), stream); csc_extract_mapped_kernel<<>>( - buf.d_sparse_data, buf.d_sparse_indices, buf.d_indptr, + buf.d_sparse_data_f32, buf.d_sparse_indices, buf.d_indptr, d_ref_row_map, buf.ref_dense, n_ref, 0); upload_linear_offsets(buf.ref_seg_offsets, sb_cols, n_ref, stream); { @@ -774,7 +797,7 @@ static void ovo_streaming_csc_host_impl( 
cudaMemsetAsync(buf.grp_dense, 0, sb_grp_actual * sizeof(float), stream); csc_extract_mapped_kernel<<>>( - buf.d_sparse_data, buf.d_sparse_indices, buf.d_indptr, + buf.d_sparse_data_f32, buf.d_sparse_indices, buf.d_indptr, d_grp_row_map, buf.grp_dense, n_all_grp, 0); // ---- Tier dispatch: sort grp + rank ---- @@ -810,17 +833,29 @@ static void ovo_streaming_csc_host_impl( } } - // ---- D2H: scatter results ---- - cudaMemcpy2DAsync(h_rank_sums + col, n_cols * sizeof(double), + // ---- D2D: scatter sub-batch results into caller's GPU buffers ---- + cudaMemcpy2DAsync(d_rank_sums + col, n_cols * sizeof(double), buf.d_rank_sums, sb_cols * sizeof(double), sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToHost, stream); + cudaMemcpyDeviceToDevice, stream); if (compute_tie_corr) { - cudaMemcpy2DAsync(h_tie_corr + col, n_cols * sizeof(double), + cudaMemcpy2DAsync(d_tie_corr + col, n_cols * sizeof(double), buf.d_tie_corr, sb_cols * sizeof(double), sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToHost, stream); + cudaMemcpyDeviceToDevice, stream); } + cudaMemcpy2DAsync(d_group_sums + col, n_cols * sizeof(double), + buf.d_group_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups_stats, + cudaMemcpyDeviceToDevice, stream); + cudaMemcpy2DAsync(d_group_sq_sums + col, n_cols * sizeof(double), + buf.d_group_sq_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups_stats, + cudaMemcpyDeviceToDevice, stream); + cudaMemcpy2DAsync(d_group_nnz + col, n_cols * sizeof(double), + buf.d_group_nnz, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups_stats, + cudaMemcpyDeviceToDevice, stream); col += sb_cols; batch_idx++; @@ -834,10 +869,8 @@ static void ovo_streaming_csc_host_impl( cudaGetErrorString(err)); } - cudaHostUnregister(const_cast(h_data)); + cudaHostUnregister(const_cast(h_data)); cudaHostUnregister(const_cast(h_indices)); - cudaHostUnregister(h_rank_sums); - cudaHostUnregister(h_tie_corr); for (int s = 0; s < n_streams; s++) 
cudaStreamDestroy(streams[s]); } @@ -853,13 +886,15 @@ static void ovo_streaming_csc_host_impl( * The reference is sorted once (not per sub-batch), saving ~50% of the * per-sub-batch extraction + sort work. */ -template +template static void ovo_streaming_csr_host_impl( - const float* h_data, const int* h_indices, const IndptrT* h_indptr, + const InT* h_data, const int* h_indices, const IndptrT* h_indptr, const int* h_ref_row_ids, const int* h_grp_row_ids, - const int* h_grp_offsets, double* h_rank_sums, double* h_tie_corr, - int n_ref, int n_all_grp, int n_rows, int n_cols, int n_groups, size_t nnz, - bool compute_tie_corr, int sub_batch_cols) { + const int* h_grp_offsets, const int* h_stats_codes, double* d_rank_sums, + double* d_tie_corr, double* d_group_sums, double* d_group_sq_sums, + double* d_group_nnz, int n_ref, int n_all_grp, int n_rows, int n_cols, + int n_groups, int n_groups_stats, size_t nnz, bool compute_tie_corr, + int sub_batch_cols) { if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; // ---- Tier dispatch from host offsets ---- @@ -914,18 +949,20 @@ static void ovo_streaming_csr_host_impl( RmmPool pool; - // ---- Phase 1: Transfer CSR, extract + sort reference (all columns) ---- + // ---- Phase 1: Transfer CSR (native dtype), cast + accumulate stats ---- + InT* d_data_orig = pool.alloc(nnz); float* d_data = pool.alloc(nnz); int* d_indices = pool.alloc(nnz); int* d_indptr = pool.alloc(n_rows + 1); int* d_ref_row_ids = pool.alloc(n_ref); int* d_grp_row_ids = pool.alloc(n_all_grp); int* d_grp_offsets = pool.alloc(n_groups + 1); + int* d_stats_codes = pool.alloc(n_rows); - cudaHostRegister(const_cast(h_data), nnz * sizeof(float), 0); + cudaHostRegister(const_cast(h_data), nnz * sizeof(InT), 0); cudaHostRegister(const_cast(h_indices), nnz * sizeof(int), 0); - cudaMemcpyAsync(d_data, h_data, nnz * sizeof(float), cudaMemcpyHostToDevice, - streams[0]); + cudaMemcpyAsync(d_data_orig, h_data, nnz * sizeof(InT), + cudaMemcpyHostToDevice, streams[0]); 
cudaMemcpyAsync(d_indices, h_indices, nnz * sizeof(int), cudaMemcpyHostToDevice, streams[0]); { @@ -940,6 +977,27 @@ static void ovo_streaming_csr_host_impl( cudaMemcpyHostToDevice); cudaMemcpy(d_grp_offsets, h_grp_offsets, (n_groups + 1) * sizeof(int), cudaMemcpyHostToDevice); + cudaMemcpy(d_stats_codes, h_stats_codes, n_rows * sizeof(int), + cudaMemcpyHostToDevice); + + // Zero caller's stats buffers before the atomicAdd-based kernel. + cudaMemsetAsync(d_group_sums, 0, + (size_t)n_groups_stats * n_cols * sizeof(double), + streams[0]); + cudaMemsetAsync(d_group_sq_sums, 0, + (size_t)n_groups_stats * n_cols * sizeof(double), + streams[0]); + cudaMemsetAsync(d_group_nnz, 0, + (size_t)n_groups_stats * n_cols * sizeof(double), + streams[0]); + { + int tpb = 256; + int blk = (n_rows + tpb - 1) / tpb; + cast_and_accumulate_csr_kernel<<>>( + d_data_orig, d_data, d_indices, d_indptr, d_stats_codes, + d_group_sums, d_group_sq_sums, d_group_nnz, n_rows, n_cols, + n_groups_stats); + } cudaStreamSynchronize(streams[0]); // Extract ref for ALL columns, sort once @@ -995,10 +1053,6 @@ static void ovo_streaming_csr_host_impl( int tpb_rank = round_up_to_warp(std::min(max_grp_size, MAX_THREADS_PER_BLOCK)); - cudaHostRegister(h_rank_sums, (size_t)n_groups * n_cols * sizeof(double), - 0); - cudaHostRegister(h_tie_corr, (size_t)n_groups * n_cols * sizeof(double), 0); - int col = 0; int batch_idx = 0; while (col < n_cols) { @@ -1052,16 +1106,16 @@ static void ovo_streaming_csr_host_impl( } } - // D2H results - cudaMemcpy2DAsync(h_rank_sums + col, n_cols * sizeof(double), + // D2D: scatter rank_sums / tie_corr into caller's GPU buffers + cudaMemcpy2DAsync(d_rank_sums + col, n_cols * sizeof(double), buf.d_rank_sums, sb_cols * sizeof(double), sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToHost, stream); + cudaMemcpyDeviceToDevice, stream); if (compute_tie_corr) { - cudaMemcpy2DAsync(h_tie_corr + col, n_cols * sizeof(double), + cudaMemcpy2DAsync(d_tie_corr + col, n_cols * 
sizeof(double), buf.d_tie_corr, sb_cols * sizeof(double), sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToHost, stream); + cudaMemcpyDeviceToDevice, stream); } col += sb_cols; @@ -1076,10 +1130,8 @@ static void ovo_streaming_csr_host_impl( cudaGetErrorString(err)); } - cudaHostUnregister(const_cast(h_data)); + cudaHostUnregister(const_cast(h_data)); cudaHostUnregister(const_cast(h_indices)); - cudaHostUnregister(h_rank_sums); - cudaHostUnregister(h_tie_corr); for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); } @@ -1109,11 +1161,14 @@ __global__ void dense_gather_rows_kernel(const float* __restrict__ in, * transferred, then ref/grp rows are gathered, sorted, and ranked. * Results D2H per sub-batch. */ +template static void ovo_streaming_dense_host_impl( - const float* h_block, const int* h_ref_row_ids, const int* h_grp_row_ids, - const int* h_grp_offsets, double* h_rank_sums, double* h_tie_corr, - int n_ref, int n_all_grp, int n_rows, int n_cols, int n_groups, - bool compute_tie_corr, int sub_batch_cols) { + const InT* h_block, const int* h_ref_row_ids, const int* h_grp_row_ids, + const int* h_grp_offsets, const int* h_stats_codes, double* d_rank_sums, + double* d_tie_corr, double* d_group_sums, double* d_group_sq_sums, + double* d_group_nnz, int n_ref, int n_all_grp, int n_rows, int n_cols, + int n_groups, int n_groups_stats, bool compute_tie_corr, + int sub_batch_cols) { if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; // ---- Tier dispatch from host offsets ---- @@ -1169,19 +1224,23 @@ static void ovo_streaming_dense_host_impl( RmmPool pool; - // GPU copies of row_ids + group offsets (uploaded once) + // GPU copies of row_ids + group offsets + stats codes (uploaded once) int* d_ref_row_ids = pool.alloc(n_ref); int* d_grp_row_ids = pool.alloc(n_all_grp); int* d_grp_offsets = pool.alloc(n_groups + 1); + int* d_stats_codes = pool.alloc(n_rows); cudaMemcpy(d_ref_row_ids, h_ref_row_ids, n_ref * sizeof(int), cudaMemcpyHostToDevice); 
cudaMemcpy(d_grp_row_ids, h_grp_row_ids, n_all_grp * sizeof(int), cudaMemcpyHostToDevice); cudaMemcpy(d_grp_offsets, h_grp_offsets, (n_groups + 1) * sizeof(int), cudaMemcpyHostToDevice); + cudaMemcpy(d_stats_codes, h_stats_codes, n_rows * sizeof(int), + cudaMemcpyHostToDevice); struct StreamBuf { - float* d_block; // H2D sub-batch (all rows) + InT* d_block_orig; + float* d_block_f32; float* ref_dense; float* ref_sorted; float* grp_dense; @@ -1191,10 +1250,14 @@ static void ovo_streaming_dense_host_impl( uint8_t* cub_temp; double* d_rank_sums; double* d_tie_corr; + double* d_group_sums; + double* d_group_sq_sums; + double* d_group_nnz; }; std::vector bufs(n_streams); for (int s = 0; s < n_streams; s++) { - bufs[s].d_block = pool.alloc(sub_dense); + bufs[s].d_block_orig = pool.alloc(sub_dense); + bufs[s].d_block_f32 = pool.alloc(sub_dense); bufs[s].ref_dense = pool.alloc(sub_ref_items); bufs[s].ref_sorted = pool.alloc(sub_ref_items); bufs[s].grp_dense = pool.alloc(sub_grp_items); @@ -1204,6 +1267,12 @@ static void ovo_streaming_dense_host_impl( pool.alloc((size_t)n_groups * sub_batch_cols); bufs[s].d_tie_corr = pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].d_group_sums = + pool.alloc((size_t)n_groups_stats * sub_batch_cols); + bufs[s].d_group_sq_sums = + pool.alloc((size_t)n_groups_stats * sub_batch_cols); + bufs[s].d_group_nnz = + pool.alloc((size_t)n_groups_stats * sub_batch_cols); if (!use_tier1) { bufs[s].grp_sorted = pool.alloc(sub_grp_items); int max_grp_seg = n_groups * sub_batch_cols; @@ -1216,13 +1285,11 @@ static void ovo_streaming_dense_host_impl( int tpb_rank = round_up_to_warp(std::min(max_grp_size, MAX_THREADS_PER_BLOCK)); + size_t smem_cast = (size_t)(3 * n_groups_stats) * sizeof(double); - // Pin host memory - cudaHostRegister(const_cast(h_block), - (size_t)n_rows * n_cols * sizeof(float), 0); - cudaHostRegister(h_rank_sums, (size_t)n_groups * n_cols * sizeof(double), - 0); - cudaHostRegister(h_tie_corr, (size_t)n_groups * n_cols * 
sizeof(double), 0); + // Pin only the host input; outputs live on the device. + cudaHostRegister(const_cast(h_block), + (size_t)n_rows * n_cols * sizeof(InT), 0); int col = 0; int batch_idx = 0; @@ -1235,14 +1302,21 @@ static void ovo_streaming_dense_host_impl( auto stream = streams[s]; auto& buf = bufs[s]; - // ---- H2D: dense column sub-batch (F-order, contiguous) ---- - cudaMemcpyAsync(buf.d_block, h_block + (long long)col * n_rows, - sb_dense * sizeof(float), cudaMemcpyHostToDevice, - stream); + // ---- H2D: dense column sub-batch (F-order, native dtype) ---- + cudaMemcpyAsync(buf.d_block_orig, h_block + (long long)col * n_rows, + sb_dense * sizeof(InT), cudaMemcpyHostToDevice, stream); + + // ---- Cast to float32 for sort + accumulate stats in float64 ---- + ovr_cast_and_accumulate_dense_kernel + <<>>( + buf.d_block_orig, buf.d_block_f32, d_stats_codes, + buf.d_group_sums, buf.d_group_sq_sums, buf.d_group_nnz, n_rows, + sb_cols, n_groups_stats); // ---- Gather ref rows, sort ---- dense_gather_rows_kernel<<>>( - buf.d_block, d_ref_row_ids, buf.ref_dense, n_rows, n_ref, sb_cols); + buf.d_block_f32, d_ref_row_ids, buf.ref_dense, n_rows, n_ref, + sb_cols); upload_linear_offsets(buf.ref_seg_offsets, sb_cols, n_ref, stream); { size_t temp = cub_temp_bytes; @@ -1254,7 +1328,7 @@ static void ovo_streaming_dense_host_impl( // ---- Gather grp rows ---- dense_gather_rows_kernel<<>>( - buf.d_block, d_grp_row_ids, buf.grp_dense, n_rows, n_all_grp, + buf.d_block_f32, d_grp_row_ids, buf.grp_dense, n_rows, n_all_grp, sb_cols); // ---- Tier dispatch: sort grp + rank ---- @@ -1290,17 +1364,29 @@ static void ovo_streaming_dense_host_impl( } } - // ---- D2H: scatter results ---- - cudaMemcpy2DAsync(h_rank_sums + col, n_cols * sizeof(double), + // ---- D2D: scatter sub-batch results into caller's GPU buffers ---- + cudaMemcpy2DAsync(d_rank_sums + col, n_cols * sizeof(double), buf.d_rank_sums, sb_cols * sizeof(double), sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToHost, 
stream); + cudaMemcpyDeviceToDevice, stream); if (compute_tie_corr) { - cudaMemcpy2DAsync(h_tie_corr + col, n_cols * sizeof(double), + cudaMemcpy2DAsync(d_tie_corr + col, n_cols * sizeof(double), buf.d_tie_corr, sb_cols * sizeof(double), sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToHost, stream); + cudaMemcpyDeviceToDevice, stream); } + cudaMemcpy2DAsync(d_group_sums + col, n_cols * sizeof(double), + buf.d_group_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups_stats, + cudaMemcpyDeviceToDevice, stream); + cudaMemcpy2DAsync(d_group_sq_sums + col, n_cols * sizeof(double), + buf.d_group_sq_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups_stats, + cudaMemcpyDeviceToDevice, stream); + cudaMemcpy2DAsync(d_group_nnz + col, n_cols * sizeof(double), + buf.d_group_nnz, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups_stats, + cudaMemcpyDeviceToDevice, stream); col += sb_cols; batch_idx++; @@ -1314,9 +1400,7 @@ static void ovo_streaming_dense_host_impl( cudaGetErrorString(err)); } - cudaHostUnregister(const_cast(h_block)); - cudaHostUnregister(h_rank_sums); - cudaHostUnregister(h_tie_corr); + cudaHostUnregister(const_cast(h_block)); for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); } @@ -1453,118 +1537,122 @@ void register_bindings(nb::module_& m) { "tie_corr"_a, nb::kw_only(), "n_ref"_a, "n_all_grp"_a, "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, "sub_batch_cols"_a = SUB_BATCH_COLS); + + // ---- Host-streaming pipelines (host inputs, device outputs) ---- + +#define RSC_OVO_CSC_HOST_BINDING(NAME, InT, IndptrT) \ + m.def( \ + NAME, \ + [](host_array h_data, host_array h_indices, \ + host_array h_indptr, \ + host_array h_ref_row_map, \ + host_array h_grp_row_map, \ + host_array h_grp_offsets, \ + host_array h_stats_codes, \ + gpu_array_c d_rank_sums, \ + gpu_array_c d_tie_corr, \ + gpu_array_c d_group_sums, \ + gpu_array_c d_group_sq_sums, \ + gpu_array_c d_group_nnz, int n_ref, int n_all_grp, 
\ + int n_rows, int n_cols, int n_groups, int n_groups_stats, \ + bool compute_tie_corr, int sub_batch_cols) { \ + ovo_streaming_csc_host_impl( \ + h_data.data(), h_indices.data(), h_indptr.data(), \ + h_ref_row_map.data(), h_grp_row_map.data(), \ + h_grp_offsets.data(), h_stats_codes.data(), \ + d_rank_sums.data(), d_tie_corr.data(), d_group_sums.data(), \ + d_group_sq_sums.data(), d_group_nnz.data(), n_ref, n_all_grp, \ + n_rows, n_cols, n_groups, n_groups_stats, compute_tie_corr, \ + sub_batch_cols); \ + }, \ + "h_data"_a, "h_indices"_a, "h_indptr"_a, "h_ref_row_map"_a, \ + "h_grp_row_map"_a, "h_grp_offsets"_a, "h_stats_codes"_a, \ + "d_rank_sums"_a, "d_tie_corr"_a, "d_group_sums"_a, \ + "d_group_sq_sums"_a, "d_group_nnz"_a, nb::kw_only(), "n_ref"_a, \ + "n_all_grp"_a, "n_rows"_a, "n_cols"_a, "n_groups"_a, \ + "n_groups_stats"_a, "compute_tie_corr"_a, \ + "sub_batch_cols"_a = SUB_BATCH_COLS) + + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host", float, int); + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_i64", float, int64_t); + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_f64", double, int); + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_f64_i64", double, int64_t); +#undef RSC_OVO_CSC_HOST_BINDING + +#define RSC_OVO_CSR_HOST_BINDING(NAME, InT, IndptrT) \ + m.def( \ + NAME, \ + [](host_array h_data, host_array h_indices, \ + host_array h_indptr, \ + host_array h_ref_row_ids, \ + host_array h_grp_row_ids, \ + host_array h_grp_offsets, \ + host_array h_stats_codes, \ + gpu_array_c d_rank_sums, \ + gpu_array_c d_tie_corr, \ + gpu_array_c d_group_sums, \ + gpu_array_c d_group_sq_sums, \ + gpu_array_c d_group_nnz, int n_ref, int n_all_grp, \ + int n_rows, int n_cols, int n_groups, int n_groups_stats, \ + size_t nnz, bool compute_tie_corr, int sub_batch_cols) { \ + ovo_streaming_csr_host_impl( \ + h_data.data(), h_indices.data(), h_indptr.data(), \ + h_ref_row_ids.data(), h_grp_row_ids.data(), \ + h_grp_offsets.data(), h_stats_codes.data(), \ + 
d_rank_sums.data(), d_tie_corr.data(), d_group_sums.data(), \ + d_group_sq_sums.data(), d_group_nnz.data(), n_ref, n_all_grp, \ + n_rows, n_cols, n_groups, n_groups_stats, nnz, \ + compute_tie_corr, sub_batch_cols); \ + }, \ + "h_data"_a, "h_indices"_a, "h_indptr"_a, "h_ref_row_ids"_a, \ + "h_grp_row_ids"_a, "h_grp_offsets"_a, "h_stats_codes"_a, \ + "d_rank_sums"_a, "d_tie_corr"_a, "d_group_sums"_a, \ + "d_group_sq_sums"_a, "d_group_nnz"_a, nb::kw_only(), "n_ref"_a, \ + "n_all_grp"_a, "n_rows"_a, "n_cols"_a, "n_groups"_a, \ + "n_groups_stats"_a, "nnz"_a, "compute_tie_corr"_a, \ + "sub_batch_cols"_a = SUB_BATCH_COLS) + + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host", float, int); + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_i64", float, int64_t); + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_f64", double, int); + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_f64_i64", double, int64_t); +#undef RSC_OVO_CSR_HOST_BINDING + +#define RSC_OVO_DENSE_HOST_BINDING(NAME, InT) \ + m.def( \ + NAME, \ + [](host_array_2d h_block, \ + host_array h_ref_row_ids, \ + host_array h_grp_row_ids, \ + host_array h_grp_offsets, \ + host_array h_stats_codes, \ + gpu_array_c d_rank_sums, \ + gpu_array_c d_tie_corr, \ + gpu_array_c d_group_sums, \ + gpu_array_c d_group_sq_sums, \ + gpu_array_c d_group_nnz, int n_ref, int n_all_grp, \ + int n_rows, int n_cols, int n_groups, int n_groups_stats, \ + bool compute_tie_corr, int sub_batch_cols) { \ + ovo_streaming_dense_host_impl( \ + h_block.data(), h_ref_row_ids.data(), h_grp_row_ids.data(), \ + h_grp_offsets.data(), h_stats_codes.data(), \ + d_rank_sums.data(), d_tie_corr.data(), d_group_sums.data(), \ + d_group_sq_sums.data(), d_group_nnz.data(), n_ref, n_all_grp, \ + n_rows, n_cols, n_groups, n_groups_stats, compute_tie_corr, \ + sub_batch_cols); \ + }, \ + "h_block"_a, "h_ref_row_ids"_a, "h_grp_row_ids"_a, "h_grp_offsets"_a, \ + "h_stats_codes"_a, "d_rank_sums"_a, "d_tie_corr"_a, "d_group_sums"_a, \ + "d_group_sq_sums"_a, 
"d_group_nnz"_a, nb::kw_only(), "n_ref"_a, \ + "n_all_grp"_a, "n_rows"_a, "n_cols"_a, "n_groups"_a, \ + "n_groups_stats"_a, "compute_tie_corr"_a, \ + "sub_batch_cols"_a = SUB_BATCH_COLS) + + RSC_OVO_DENSE_HOST_BINDING("ovo_streaming_dense_host", float); + RSC_OVO_DENSE_HOST_BINDING("ovo_streaming_dense_host_f64", double); +#undef RSC_OVO_DENSE_HOST_BINDING } NB_MODULE(_wilcoxon_ovo_cuda, m) { REGISTER_GPU_BINDINGS(register_bindings, m); - - m.def( - "ovo_streaming_csc_host", - [](host_array h_data, host_array h_indices, - host_array h_indptr, host_array h_ref_row_map, - host_array h_grp_row_map, - host_array h_grp_offsets, - host_array_2d h_rank_sums, host_array_2d h_tie_corr, - int n_ref, int n_all_grp, int n_rows, int n_cols, int n_groups, - bool compute_tie_corr, int sub_batch_cols) { - ovo_streaming_csc_host_impl( - h_data.data(), h_indices.data(), h_indptr.data(), - h_ref_row_map.data(), h_grp_row_map.data(), - h_grp_offsets.data(), h_rank_sums.data(), h_tie_corr.data(), - n_ref, n_all_grp, n_rows, n_cols, n_groups, compute_tie_corr, - sub_batch_cols); - }, - "h_data"_a, "h_indices"_a, "h_indptr"_a, "h_ref_row_map"_a, - "h_grp_row_map"_a, "h_grp_offsets"_a, "h_rank_sums"_a, "h_tie_corr"_a, - nb::kw_only(), "n_ref"_a, "n_all_grp"_a, "n_rows"_a, "n_cols"_a, - "n_groups"_a, "compute_tie_corr"_a, - "sub_batch_cols"_a = SUB_BATCH_COLS); - - m.def( - "ovo_streaming_csc_host_i64", - [](host_array h_data, host_array h_indices, - host_array h_indptr, - host_array h_ref_row_map, - host_array h_grp_row_map, - host_array h_grp_offsets, - host_array_2d h_rank_sums, host_array_2d h_tie_corr, - int n_ref, int n_all_grp, int n_rows, int n_cols, int n_groups, - bool compute_tie_corr, int sub_batch_cols) { - ovo_streaming_csc_host_impl( - h_data.data(), h_indices.data(), h_indptr.data(), - h_ref_row_map.data(), h_grp_row_map.data(), - h_grp_offsets.data(), h_rank_sums.data(), h_tie_corr.data(), - n_ref, n_all_grp, n_rows, n_cols, n_groups, compute_tie_corr, - sub_batch_cols); - 
}, - "h_data"_a, "h_indices"_a, "h_indptr"_a, "h_ref_row_map"_a, - "h_grp_row_map"_a, "h_grp_offsets"_a, "h_rank_sums"_a, "h_tie_corr"_a, - nb::kw_only(), "n_ref"_a, "n_all_grp"_a, "n_rows"_a, "n_cols"_a, - "n_groups"_a, "compute_tie_corr"_a, - "sub_batch_cols"_a = SUB_BATCH_COLS); - - m.def( - "ovo_streaming_csr_host", - [](host_array h_data, host_array h_indices, - host_array h_indptr, host_array h_ref_row_ids, - host_array h_grp_row_ids, - host_array h_grp_offsets, - host_array_2d h_rank_sums, host_array_2d h_tie_corr, - int n_ref, int n_all_grp, int n_rows, int n_cols, int n_groups, - size_t nnz, bool compute_tie_corr, int sub_batch_cols) { - ovo_streaming_csr_host_impl( - h_data.data(), h_indices.data(), h_indptr.data(), - h_ref_row_ids.data(), h_grp_row_ids.data(), - h_grp_offsets.data(), h_rank_sums.data(), h_tie_corr.data(), - n_ref, n_all_grp, n_rows, n_cols, n_groups, nnz, - compute_tie_corr, sub_batch_cols); - }, - "h_data"_a, "h_indices"_a, "h_indptr"_a, "h_ref_row_ids"_a, - "h_grp_row_ids"_a, "h_grp_offsets"_a, "h_rank_sums"_a, "h_tie_corr"_a, - nb::kw_only(), "n_ref"_a, "n_all_grp"_a, "n_rows"_a, "n_cols"_a, - "n_groups"_a, "nnz"_a, "compute_tie_corr"_a, - "sub_batch_cols"_a = SUB_BATCH_COLS); - - m.def( - "ovo_streaming_csr_host_i64", - [](host_array h_data, host_array h_indices, - host_array h_indptr, - host_array h_ref_row_ids, - host_array h_grp_row_ids, - host_array h_grp_offsets, - host_array_2d h_rank_sums, host_array_2d h_tie_corr, - int n_ref, int n_all_grp, int n_rows, int n_cols, int n_groups, - size_t nnz, bool compute_tie_corr, int sub_batch_cols) { - ovo_streaming_csr_host_impl( - h_data.data(), h_indices.data(), h_indptr.data(), - h_ref_row_ids.data(), h_grp_row_ids.data(), - h_grp_offsets.data(), h_rank_sums.data(), h_tie_corr.data(), - n_ref, n_all_grp, n_rows, n_cols, n_groups, nnz, - compute_tie_corr, sub_batch_cols); - }, - "h_data"_a, "h_indices"_a, "h_indptr"_a, "h_ref_row_ids"_a, - "h_grp_row_ids"_a, "h_grp_offsets"_a, 
"h_rank_sums"_a, "h_tie_corr"_a, - nb::kw_only(), "n_ref"_a, "n_all_grp"_a, "n_rows"_a, "n_cols"_a, - "n_groups"_a, "nnz"_a, "compute_tie_corr"_a, - "sub_batch_cols"_a = SUB_BATCH_COLS); - - m.def( - "ovo_streaming_dense_host", - [](host_array_2d h_block, - host_array h_ref_row_ids, - host_array h_grp_row_ids, - host_array h_grp_offsets, - host_array_2d h_rank_sums, host_array_2d h_tie_corr, - int n_ref, int n_all_grp, int n_rows, int n_cols, int n_groups, - bool compute_tie_corr, int sub_batch_cols) { - ovo_streaming_dense_host_impl( - h_block.data(), h_ref_row_ids.data(), h_grp_row_ids.data(), - h_grp_offsets.data(), h_rank_sums.data(), h_tie_corr.data(), - n_ref, n_all_grp, n_rows, n_cols, n_groups, compute_tie_corr, - sub_batch_cols); - }, - "h_block"_a, "h_ref_row_ids"_a, "h_grp_row_ids"_a, "h_grp_offsets"_a, - "h_rank_sums"_a, "h_tie_corr"_a, nb::kw_only(), "n_ref"_a, - "n_all_grp"_a, "n_rows"_a, "n_cols"_a, "n_groups"_a, - "compute_tie_corr"_a, "sub_batch_cols"_a = SUB_BATCH_COLS); } diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr.cu index 77cf2d04..58fda6d0 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr.cu @@ -239,11 +239,12 @@ static void ovr_streaming_impl(const float* block, const int* group_codes, * instead of extracting dense blocks. GPU memory is O(max_batch_nnz) instead * of O(sub_batch * n_rows), and sort work is proportional to nnz, not n_rows. 
*/ -template +template static void ovr_sparse_csc_host_streaming_impl( - const float* h_data, const int* h_indices, const IndptrT* h_indptr, - const int* h_group_codes, const double* h_group_sizes, double* h_rank_sums, - double* h_tie_corr, int n_rows, int n_cols, int n_groups, + const InT* h_data, const int* h_indices, const IndptrT* h_indptr, + const int* h_group_codes, const double* h_group_sizes, double* d_rank_sums, + double* d_tie_corr, double* d_group_sums, double* d_group_sq_sums, + double* d_group_nnz, int n_rows, int n_cols, int n_groups, bool compute_tie_corr, int sub_batch_cols) { if (n_rows == 0 || n_cols == 0) return; @@ -276,7 +277,8 @@ static void ovr_sparse_csc_host_streaming_impl( int* d_group_codes = pool.alloc(n_rows); double* d_group_sizes = pool.alloc(n_groups); struct StreamBuf { - float* d_sparse_data; + InT* d_sparse_data_orig; + float* d_sparse_data_f32; int* d_sparse_indices; int* d_seg_offsets; float* keys_out; @@ -284,10 +286,14 @@ static void ovr_sparse_csc_host_streaming_impl( uint8_t* cub_temp; double* d_rank_sums; double* d_tie_corr; + double* d_group_sums; + double* d_group_sq_sums; + double* d_group_nnz; }; std::vector bufs(n_streams); for (int s = 0; s < n_streams; s++) { - bufs[s].d_sparse_data = pool.alloc(max_nnz); + bufs[s].d_sparse_data_orig = pool.alloc(max_nnz); + bufs[s].d_sparse_data_f32 = pool.alloc(max_nnz); bufs[s].d_sparse_indices = pool.alloc(max_nnz); bufs[s].d_seg_offsets = pool.alloc(sub_batch_cols + 1); bufs[s].keys_out = pool.alloc(max_nnz); @@ -296,6 +302,12 @@ static void ovr_sparse_csc_host_streaming_impl( bufs[s].d_rank_sums = pool.alloc((size_t)n_groups * sub_batch_cols); bufs[s].d_tie_corr = pool.alloc(sub_batch_cols); + bufs[s].d_group_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].d_group_sq_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].d_group_nnz = + pool.alloc((size_t)n_groups * sub_batch_cols); } // Transfer group codes + sizes once @@ -306,14 +318,12 @@ static 
void ovr_sparse_csc_host_streaming_impl( int tpb = 256; size_t smem_bytes = (size_t)(2 * n_groups + 32) * sizeof(double); + size_t smem_cast = (size_t)(3 * n_groups) * sizeof(double); - // Pin host memory for async transfers + // Pin only the host input arrays; outputs live on the device. size_t total_nnz = (size_t)h_indptr[n_cols]; - cudaHostRegister(const_cast(h_data), total_nnz * sizeof(float), 0); + cudaHostRegister(const_cast(h_data), total_nnz * sizeof(InT), 0); cudaHostRegister(const_cast(h_indices), total_nnz * sizeof(int), 0); - cudaHostRegister(h_rank_sums, (size_t)n_groups * n_cols * sizeof(double), - 0); - cudaHostRegister(h_tie_corr, n_cols * sizeof(double), 0); cudaDeviceSynchronize(); @@ -329,10 +339,10 @@ static void ovr_sparse_csc_host_streaming_impl( IndptrT ptr_end = h_indptr[col + sb_cols]; int batch_nnz = (int)(ptr_end - ptr_start); - // H2D: transfer sparse data for this column range + // H2D: transfer sparse data for this column range (native dtype) if (batch_nnz > 0) { - cudaMemcpyAsync(buf.d_sparse_data, h_data + ptr_start, - (size_t)batch_nnz * sizeof(float), + cudaMemcpyAsync(buf.d_sparse_data_orig, h_data + ptr_start, + (size_t)batch_nnz * sizeof(InT), cudaMemcpyHostToDevice, stream); cudaMemcpyAsync(buf.d_sparse_indices, h_indices + ptr_start, (size_t)batch_nnz * sizeof(int), @@ -349,32 +359,52 @@ static void ovr_sparse_csc_host_streaming_impl( stream); } - // CUB sort only stored nonzeros + // Cast to float32 for sort + accumulate stats in float64 + ovr_cast_and_accumulate_sparse_kernel + <<>>( + buf.d_sparse_data_orig, buf.d_sparse_data_f32, + buf.d_sparse_indices, buf.d_seg_offsets, d_group_codes, + buf.d_group_sums, buf.d_group_sq_sums, buf.d_group_nnz, sb_cols, + n_groups); + + // CUB sort only stored nonzeros (float32 keys) if (batch_nnz > 0) { size_t temp = cub_temp_bytes; cub::DeviceSegmentedRadixSort::SortPairs( - buf.cub_temp, temp, buf.d_sparse_data, buf.keys_out, + buf.cub_temp, temp, buf.d_sparse_data_f32, buf.keys_out, 
buf.d_sparse_indices, buf.vals_out, batch_nnz, sb_cols, buf.d_seg_offsets, buf.d_seg_offsets + 1, BEGIN_BIT, END_BIT, stream); } - // Sparse rank kernel + // Sparse rank kernel (stats already captured above) rank_sums_sparse_ovr_kernel<<>>( buf.keys_out, buf.vals_out, buf.d_seg_offsets, d_group_codes, d_group_sizes, buf.d_rank_sums, buf.d_tie_corr, n_rows, sb_cols, n_groups, compute_tie_corr); - // D2H: scatter results - cudaMemcpy2DAsync(h_rank_sums + col, n_cols * sizeof(double), + // D2D: scatter sub-batch results into caller's GPU buffers + cudaMemcpy2DAsync(d_rank_sums + col, n_cols * sizeof(double), buf.d_rank_sums, sb_cols * sizeof(double), sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToHost, stream); + cudaMemcpyDeviceToDevice, stream); if (compute_tie_corr) { - cudaMemcpyAsync(h_tie_corr + col, buf.d_tie_corr, - sb_cols * sizeof(double), cudaMemcpyDeviceToHost, + cudaMemcpyAsync(d_tie_corr + col, buf.d_tie_corr, + sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, stream); } + cudaMemcpy2DAsync(d_group_sums + col, n_cols * sizeof(double), + buf.d_group_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + cudaMemcpy2DAsync(d_group_sq_sums + col, n_cols * sizeof(double), + buf.d_group_sq_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + cudaMemcpy2DAsync(d_group_nnz + col, n_cols * sizeof(double), + buf.d_group_nnz, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); col += sb_cols; batch_idx++; @@ -388,10 +418,8 @@ static void ovr_sparse_csc_host_streaming_impl( cudaGetErrorString(err)); } - cudaHostUnregister(const_cast(h_data)); + cudaHostUnregister(const_cast(h_data)); cudaHostUnregister(const_cast(h_indices)); - cudaHostUnregister(h_rank_sums); - cudaHostUnregister(h_tie_corr); for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); } @@ -399,12 +427,24 @@ static void 
ovr_sparse_csc_host_streaming_impl( /** * Host-streaming dense OVR pipeline. * - * Dense F-order float32 block lives on host. Sub-batches of 64 columns - * are transferred to GPU per stream, so GPU memory is O(sub_batch * n_rows). + * Templated on the host dtype (InT = float or double). Each sub-batch is + * copied to the device in its native dtype once; a fused cast+accumulate + * kernel writes a float32 view for the sort and accumulates per-group + * sum/sum-sq/nnz in float64 from the original-precision values. The + * existing sort + rank pipeline then runs on the float32 keys. + * + * Output pointers ({d_rank_sums, d_tie_corr, d_group_sums, d_group_sq_sums, + * d_group_nnz}) point to caller-provided CuPy memory of the full output + * shape; sub-batch kernels scatter directly into them via D2D. + * + * GPU memory stays at O(sub_batch * n_rows), now with a small extra + * InT-sized sub-batch buffer per stream. */ +template static void ovr_streaming_dense_host_impl( - const float* h_block, const int* h_group_codes, double* h_rank_sums, - double* h_tie_corr, int n_rows, int n_cols, int n_groups, + const InT* h_block, const int* h_group_codes, double* d_rank_sums, + double* d_tie_corr, double* d_group_sums, double* d_group_sq_sums, + double* d_group_nnz, int n_rows, int n_cols, int n_groups, bool compute_tie_corr, int sub_batch_cols) { if (n_rows == 0 || n_cols == 0) return; @@ -429,7 +469,8 @@ static void ovr_streaming_dense_host_impl( RmmPool pool; int* d_group_codes = pool.alloc(n_rows); struct StreamBuf { - float* d_block; + InT* d_block_orig; + float* d_block_f32; float* keys_out; int* vals_in; int* vals_out; @@ -437,10 +478,14 @@ static void ovr_streaming_dense_host_impl( uint8_t* cub_temp; double* d_rank_sums; double* d_tie_corr; + double* d_group_sums; + double* d_group_sq_sums; + double* d_group_nnz; }; std::vector bufs(n_streams); for (int s = 0; s < n_streams; s++) { - bufs[s].d_block = pool.alloc(sub_items); + bufs[s].d_block_orig = 
pool.alloc(sub_items); + bufs[s].d_block_f32 = pool.alloc(sub_items); bufs[s].keys_out = pool.alloc(sub_items); bufs[s].vals_in = pool.alloc(sub_items); bufs[s].vals_out = pool.alloc(sub_items); @@ -449,6 +494,12 @@ static void ovr_streaming_dense_host_impl( bufs[s].d_rank_sums = pool.alloc((size_t)n_groups * sub_batch_cols); bufs[s].d_tie_corr = pool.alloc(sub_batch_cols); + bufs[s].d_group_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].d_group_sq_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].d_group_nnz = + pool.alloc((size_t)n_groups * sub_batch_cols); } // Group codes on GPU (transferred once) @@ -458,13 +509,12 @@ static void ovr_streaming_dense_host_impl( int tpb_rank = round_up_to_warp(n_rows); bool use_gmem = false; size_t smem_rank = ovr_smem_config(n_groups, use_gmem); + int tpb_cast = 256; + size_t smem_cast = (size_t)(3 * n_groups) * sizeof(double); - // Pin host memory - cudaHostRegister(const_cast(h_block), - (size_t)n_rows * n_cols * sizeof(float), 0); - cudaHostRegister(h_rank_sums, (size_t)n_groups * n_cols * sizeof(double), - 0); - cudaHostRegister(h_tie_corr, n_cols * sizeof(double), 0); + // Pin only the host input. Outputs live on the device (caller-owned). 
+ cudaHostRegister(const_cast(h_block), + (size_t)n_rows * n_cols * sizeof(InT), 0); int col = 0; int batch_idx = 0; @@ -475,10 +525,16 @@ static void ovr_streaming_dense_host_impl( auto stream = streams[s]; auto& buf = bufs[s]; - // H2D: column sub-batch (F-order → contiguous) - cudaMemcpyAsync(buf.d_block, h_block + (long long)col * n_rows, - sb_items * sizeof(float), cudaMemcpyHostToDevice, - stream); + // H2D: column sub-batch in native dtype (F-order → contiguous) + cudaMemcpyAsync(buf.d_block_orig, h_block + (long long)col * n_rows, + sb_items * sizeof(InT), cudaMemcpyHostToDevice, stream); + + // Cast to float32 for sort + accumulate stats in float64 + ovr_cast_and_accumulate_dense_kernel + <<>>( + buf.d_block_orig, buf.d_block_f32, d_group_codes, + buf.d_group_sums, buf.d_group_sq_sums, buf.d_group_nnz, n_rows, + sb_cols, n_groups); // Fill segment offsets + row indices upload_linear_offsets(buf.seg_offsets, sb_cols, n_rows, stream); @@ -488,11 +544,11 @@ static void ovr_streaming_dense_host_impl( // Sort size_t temp = cub_temp_bytes; cub::DeviceSegmentedRadixSort::SortPairs( - buf.cub_temp, temp, buf.d_block, buf.keys_out, buf.vals_in, + buf.cub_temp, temp, buf.d_block_f32, buf.keys_out, buf.vals_in, buf.vals_out, sb_items, sb_cols, buf.seg_offsets, buf.seg_offsets + 1, BEGIN_BIT, END_BIT, stream); - // Fused rank sums + // Fused rank sums (stats already captured by the cast kernel) if (use_gmem) { cudaMemsetAsync(buf.d_rank_sums, 0, (size_t)n_groups * sb_cols * sizeof(double), @@ -503,16 +559,28 @@ static void ovr_streaming_dense_host_impl( buf.d_tie_corr, nullptr, nullptr, nullptr, n_rows, sb_cols, n_groups, compute_tie_corr, false, use_gmem); - // D2H: scatter results - cudaMemcpy2DAsync(h_rank_sums + col, n_cols * sizeof(double), + // D2D: scatter sub-batch results into the caller's GPU buffers + cudaMemcpy2DAsync(d_rank_sums + col, n_cols * sizeof(double), buf.d_rank_sums, sb_cols * sizeof(double), sb_cols * sizeof(double), n_groups, - 
cudaMemcpyDeviceToHost, stream); + cudaMemcpyDeviceToDevice, stream); if (compute_tie_corr) { - cudaMemcpyAsync(h_tie_corr + col, buf.d_tie_corr, - sb_cols * sizeof(double), cudaMemcpyDeviceToHost, + cudaMemcpyAsync(d_tie_corr + col, buf.d_tie_corr, + sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, stream); } + cudaMemcpy2DAsync(d_group_sums + col, n_cols * sizeof(double), + buf.d_group_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + cudaMemcpy2DAsync(d_group_sq_sums + col, n_cols * sizeof(double), + buf.d_group_sq_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + cudaMemcpy2DAsync(d_group_nnz + col, n_cols * sizeof(double), + buf.d_group_nnz, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); col += sb_cols; batch_idx++; @@ -526,9 +594,7 @@ static void ovr_streaming_dense_host_impl( cudaGetErrorString(err)); } - cudaHostUnregister(const_cast(h_block)); - cudaHostUnregister(h_rank_sums); - cudaHostUnregister(h_tie_corr); + cudaHostUnregister(const_cast(h_block)); for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); } @@ -911,61 +977,69 @@ void register_bindings(nb::module_& m) { "group_sizes"_a, "rank_sums"_a, "tie_corr"_a, nb::kw_only(), "n_rows"_a, "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, "sub_batch_cols"_a = SUB_BATCH_COLS); + + // ---- Host-streaming pipelines (host inputs, device outputs) ---- + +#define RSC_OVR_SPARSE_CSC_HOST_BINDING(NAME, InT, IndptrT) \ + m.def( \ + NAME, \ + [](host_array h_data, host_array h_indices, \ + host_array h_indptr, \ + host_array h_group_codes, \ + host_array h_group_sizes, \ + gpu_array_c d_rank_sums, \ + gpu_array_c d_tie_corr, \ + gpu_array_c d_group_sums, \ + gpu_array_c d_group_sq_sums, \ + gpu_array_c d_group_nnz, int n_rows, int n_cols, \ + int n_groups, bool compute_tie_corr, int sub_batch_cols) { \ + 
ovr_sparse_csc_host_streaming_impl( \ + h_data.data(), h_indices.data(), h_indptr.data(), \ + h_group_codes.data(), h_group_sizes.data(), \ + d_rank_sums.data(), d_tie_corr.data(), d_group_sums.data(), \ + d_group_sq_sums.data(), d_group_nnz.data(), n_rows, n_cols, \ + n_groups, compute_tie_corr, sub_batch_cols); \ + }, \ + "h_data"_a, "h_indices"_a, "h_indptr"_a, "h_group_codes"_a, \ + "h_group_sizes"_a, "d_rank_sums"_a, "d_tie_corr"_a, "d_group_sums"_a, \ + "d_group_sq_sums"_a, "d_group_nnz"_a, nb::kw_only(), "n_rows"_a, \ + "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, \ + "sub_batch_cols"_a = SUB_BATCH_COLS) + + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host", float, int); + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_i64", float, int64_t); + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_f64", double, int); + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_f64_i64", double, + int64_t); +#undef RSC_OVR_SPARSE_CSC_HOST_BINDING + +#define RSC_OVR_DENSE_HOST_BINDING(NAME, InT) \ + m.def( \ + NAME, \ + [](host_array_2d h_block, \ + host_array h_group_codes, \ + gpu_array_c d_rank_sums, \ + gpu_array_c d_tie_corr, \ + gpu_array_c d_group_sums, \ + gpu_array_c d_group_sq_sums, \ + gpu_array_c d_group_nnz, int n_rows, int n_cols, \ + int n_groups, bool compute_tie_corr, int sub_batch_cols) { \ + ovr_streaming_dense_host_impl( \ + h_block.data(), h_group_codes.data(), d_rank_sums.data(), \ + d_tie_corr.data(), d_group_sums.data(), \ + d_group_sq_sums.data(), d_group_nnz.data(), n_rows, n_cols, \ + n_groups, compute_tie_corr, sub_batch_cols); \ + }, \ + "h_block"_a, "h_group_codes"_a, "d_rank_sums"_a, "d_tie_corr"_a, \ + "d_group_sums"_a, "d_group_sq_sums"_a, "d_group_nnz"_a, nb::kw_only(), \ + "n_rows"_a, "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, \ + "sub_batch_cols"_a = SUB_BATCH_COLS) + + RSC_OVR_DENSE_HOST_BINDING("ovr_streaming_dense_host", float); + RSC_OVR_DENSE_HOST_BINDING("ovr_streaming_dense_host_f64", double); +#undef 
RSC_OVR_DENSE_HOST_BINDING } NB_MODULE(_wilcoxon_ovr_cuda, m) { REGISTER_GPU_BINDINGS(register_bindings, m); - - m.def( - "ovr_sparse_csc_host", - [](host_array h_data, host_array h_indices, - host_array h_indptr, host_array h_group_codes, - host_array h_group_sizes, host_array_2d h_rank_sums, - host_array h_tie_corr, int n_rows, int n_cols, int n_groups, - bool compute_tie_corr, int sub_batch_cols) { - ovr_sparse_csc_host_streaming_impl( - h_data.data(), h_indices.data(), h_indptr.data(), - h_group_codes.data(), h_group_sizes.data(), h_rank_sums.data(), - h_tie_corr.data(), n_rows, n_cols, n_groups, compute_tie_corr, - sub_batch_cols); - }, - "h_data"_a, "h_indices"_a, "h_indptr"_a, "h_group_codes"_a, - "h_group_sizes"_a, "h_rank_sums"_a, "h_tie_corr"_a, nb::kw_only(), - "n_rows"_a, "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, - "sub_batch_cols"_a = SUB_BATCH_COLS); - - m.def( - "ovr_sparse_csc_host_i64", - [](host_array h_data, host_array h_indices, - host_array h_indptr, - host_array h_group_codes, - host_array h_group_sizes, host_array_2d h_rank_sums, - host_array h_tie_corr, int n_rows, int n_cols, int n_groups, - bool compute_tie_corr, int sub_batch_cols) { - ovr_sparse_csc_host_streaming_impl( - h_data.data(), h_indices.data(), h_indptr.data(), - h_group_codes.data(), h_group_sizes.data(), h_rank_sums.data(), - h_tie_corr.data(), n_rows, n_cols, n_groups, compute_tie_corr, - sub_batch_cols); - }, - "h_data"_a, "h_indices"_a, "h_indptr"_a, "h_group_codes"_a, - "h_group_sizes"_a, "h_rank_sums"_a, "h_tie_corr"_a, nb::kw_only(), - "n_rows"_a, "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, - "sub_batch_cols"_a = SUB_BATCH_COLS); - - m.def( - "ovr_streaming_dense_host", - [](host_array_2d h_block, - host_array h_group_codes, - host_array_2d h_rank_sums, host_array h_tie_corr, - int n_rows, int n_cols, int n_groups, bool compute_tie_corr, - int sub_batch_cols) { - ovr_streaming_dense_host_impl(h_block.data(), h_group_codes.data(), - h_rank_sums.data(), 
def _fill_basic_stats_from_accumulators(
    rg: _RankGenes,
    group_sums: cp.ndarray,
    group_sq_sums: cp.ndarray,
    group_nnz: cp.ndarray,
    group_sizes: np.ndarray,
    *,
    n_cells: int,
) -> None:
    """Populate rg.means/vars/pts (+ *_rest) from streamed accumulators.

    Mirrors the Aggregate-based path in :meth:`_RankGenes._basic_stats`
    but consumes per-group sums / sums of squares / nonzero counts that the
    host-streaming kernels write into caller-provided CuPy buffers. All
    arithmetic runs on the GPU; only the derived arrays are copied to host.

    Parameters
    ----------
    rg
        Object whose ``means``, ``vars``, ``pts``, ``means_rest``,
        ``vars_rest``, ``pts_rest`` and ``_compute_stats_in_chunks``
        attributes are overwritten. ``rg.comp_pts`` and ``rg.ireference``
        are read.
    group_sums, group_sq_sums, group_nnz
        2-D device arrays with one row per group (axis 0 is reduced to
        build the totals used for the "rest" statistics).
    group_sizes
        1-D array with the number of cells behind each row.
    n_cells
        Total cell count; ``n_cells - group_sizes`` is used as the size of
        each group's complement.
    """
    # Column vector so it broadcasts against the (group, gene) accumulators.
    n = cp.asarray(group_sizes, dtype=cp.float64)[:, None]
    means = group_sums / n
    # Sum of squared deviations recovered from raw moments. Rounding can
    # push it slightly below zero, hence the clamp; the divisor is n - 1
    # floored at 1 so a single-cell group does not divide by zero.
    group_ss = group_sq_sums - n * means**2
    vars_ = cp.maximum(group_ss / cp.maximum(n - 1, 1), 0)

    rg.means = cp.asnumpy(means)
    rg.vars = cp.asnumpy(vars_)
    rg.pts = cp.asnumpy(group_nnz / n) if rg.comp_pts else None

    if rg.ireference is None:
        # No reference group: derive "rest" statistics for every group as
        # (total over all rows) minus (that group's own accumulator).
        # NOTE(review): the totals are sums over the group rows, so they
        # match the all-cells totals only if each of the n_cells cells is
        # counted in exactly one row -- confirm for callers that select a
        # subset of groups.
        n_rest = cp.float64(n_cells) - n
        total_sum = group_sums.sum(axis=0, keepdims=True)
        total_sq_sum = group_sq_sums.sum(axis=0, keepdims=True)
        rest_sums = total_sum - group_sums
        rest_means = rest_sums / n_rest
        rest_ss = (total_sq_sum - group_sq_sums) - n_rest * rest_means**2
        rg.means_rest = cp.asnumpy(rest_means)
        rg.vars_rest = cp.asnumpy(cp.maximum(rest_ss / cp.maximum(n_rest - 1, 1), 0))
        if rg.comp_pts:
            total_nnz = group_nnz.sum(axis=0, keepdims=True)
            rg.pts_rest = cp.asnumpy((total_nnz - group_nnz) / n_rest)
        else:
            rg.pts_rest = None
    else:
        rg.means_rest = None
        rg.vars_rest = None
        rg.pts_rest = None

    # Stats are now filled in, so clear the flag that requests computing
    # them chunk by chunk.
    rg._compute_stats_in_chunks = False


def _fill_ovo_stats_from_accumulators(
    rg: _RankGenes,
    group_sums_slots: cp.ndarray,
    group_sq_sums_slots: cp.ndarray,
    group_nnz_slots: cp.ndarray,
    *,
    group_sizes: NDArray,
    test_group_indices: list[int],
    n_ref: int,
) -> None:
    """Populate rg.means/vars/pts from OVO stats slots.

    Slot ordering: 0..n_test-1 are test groups (in ``test_group_indices``
    order); slot n_test is the reference group. The statistics of all
    populated slots are computed in one batched GPU expression and copied
    to host once per output array, instead of once per group.

    Parameters
    ----------
    rg
        Object whose ``means``, ``vars``, ``pts``, ``means_rest``,
        ``vars_rest``, ``pts_rest`` and ``_compute_stats_in_chunks``
        attributes are overwritten. ``rg.groups_order``, ``rg.comp_pts``
        and ``rg.ireference`` are read.
    group_sums_slots, group_sq_sums_slots, group_nnz_slots
        2-D device arrays indexed as ``[slot, gene]``.
    group_sizes
        Cell count per group, indexed by group index.
    test_group_indices
        Group index stored in each of the first ``n_test`` slots; also the
        destination row in ``rg.means`` / ``rg.vars`` / ``rg.pts``.
    n_ref
        Number of cells in the reference group (slot ``n_test``).

    Notes
    -----
    Rows of groups without a populated slot, and of slots with a
    non-positive size, stay zero. Variance stays zero for slots holding a
    single cell.
    """
    n_test = len(test_group_indices)
    n_genes = int(group_sums_slots.shape[1])
    n_groups = len(rg.groups_order)

    rg.means = np.zeros((n_groups, n_genes), dtype=np.float64)
    rg.vars = np.zeros((n_groups, n_genes), dtype=np.float64)
    rg.pts = np.zeros((n_groups, n_genes), dtype=np.float64) if rg.comp_pts else None

    # One (slot, destination row, cell count) triple per slot: the test
    # groups first, then the reference group in the last slot.
    slots = [
        (slot, gi, int(group_sizes[gi]))
        for slot, gi in enumerate(test_group_indices)
    ]
    slots.append((n_test, rg.ireference, int(n_ref)))
    # Empty slots are skipped entirely so their rows keep the zero fill.
    live = [entry for entry in slots if entry[2] > 0]

    if live:
        slot_ids, rows, sizes = (list(column) for column in zip(*live))
        sizes_host = np.asarray(sizes, dtype=np.float64)
        take = cp.asarray(slot_ids)
        # Column vector so it broadcasts against the (slot, gene) arrays.
        n = cp.asarray(sizes_host)[:, None]

        means = group_sums_slots[take] / n
        # Sum of squared deviations from raw moments, clamped at zero to
        # absorb rounding; divisor is n - 1 floored at 1.
        ss = group_sq_sums_slots[take] - n * means**2
        vars_host = cp.asnumpy(cp.maximum(ss / cp.maximum(n - 1, 1), 0))
        # A single observation has no sample variance: keep those rows at
        # exactly zero rather than whatever the clamped expression yields.
        vars_host[sizes_host <= 1] = 0.0

        rg.means[rows] = cp.asnumpy(means)
        rg.vars[rows] = vars_host
        if rg.comp_pts:
            rg.pts[rows] = cp.asnumpy(group_nnz_slots[take] / n)

    # With an explicit reference group there are no "rest" statistics.
    rg.means_rest = None
    rg.vars_rest = None
    rg.pts_rest = None
    rg._compute_stats_in_chunks = False
The + # kernel also emits per-group sum, sum-of-squares, and nonzero + # counts into caller-provided CuPy buffers, so means/vars/pts can + # be derived without uploading the full matrix. Outputs live on + # the GPU and feed directly into the z-score / p-value math below. + rank_sums = cp.empty((n_groups, n_total_genes), dtype=cp.float64) + tie_corr = cp.ones(n_total_genes, dtype=cp.float64) + group_sums = cp.empty((n_groups, n_total_genes), dtype=cp.float64) + group_sq_sums = cp.empty((n_groups, n_total_genes), dtype=cp.float64) + group_nnz = cp.empty((n_groups, n_total_genes), dtype=cp.float64) if host_csc: group_sizes_np = group_sizes.astype(np.float64, copy=False) - _csc_host_fn = ( - _ovr.ovr_sparse_csc_host_i64 - if X.indptr.dtype == np.int64 - else _ovr.ovr_sparse_csc_host - ) + # Native host dtype is preserved and uploaded once per sub-batch; + # a pre-sort kernel casts to float32 for the sort keys while + # accumulating stats in float64 from the original values. + is_f64 = X.data.dtype == np.float64 + is_i64 = X.indptr.dtype == np.int64 + if is_f64 and is_i64: + _csc_host_fn = _ovr.ovr_sparse_csc_host_f64_i64 + elif is_f64: + _csc_host_fn = _ovr.ovr_sparse_csc_host_f64 + elif is_i64: + _csc_host_fn = _ovr.ovr_sparse_csc_host_i64 + else: + _csc_host_fn = _ovr.ovr_sparse_csc_host + data_arr = X.data if is_f64 else X.data.astype(np.float32, copy=False) _csc_host_fn( - X.data.astype(np.float32, copy=False), + data_arr, X.indices.astype(np.int32, copy=False), X.indptr, group_codes, group_sizes_np, - rank_sums_np, - tie_corr_np, + rank_sums, + tie_corr, + group_sums, + group_sq_sums, + group_nnz, n_rows=n_cells, n_cols=n_total_genes, n_groups=n_groups, @@ -303,11 +407,21 @@ def _wilcoxon_vs_rest( sub_batch_cols=STREAMING_SUB_BATCH, ) else: - _ovr.ovr_streaming_dense_host( - np.asfortranarray(X.astype(np.float32, copy=False)), + is_f64 = X.dtype == np.float64 + _dense_host_fn = ( + _ovr.ovr_streaming_dense_host_f64 + if is_f64 + else _ovr.ovr_streaming_dense_host 
+ ) + block = X if is_f64 else X.astype(np.float32, copy=False) + _dense_host_fn( + np.asfortranarray(block), group_codes, - rank_sums_np, - tie_corr_np, + rank_sums, + tie_corr, + group_sums, + group_sq_sums, + group_nnz, n_rows=n_cells, n_cols=n_total_genes, n_groups=n_groups, @@ -315,8 +429,15 @@ def _wilcoxon_vs_rest( sub_batch_cols=STREAMING_SUB_BATCH, ) - rank_sums = cp.asarray(rank_sums_np) - tie_corr = cp.asarray(tie_corr_np) + if rg._compute_stats_in_chunks: + _fill_basic_stats_from_accumulators( + rg, + group_sums, + group_sq_sums, + group_nnz, + group_sizes.astype(np.float64, copy=False), + n_cells=n_cells, + ) else: # GPU data or host CSR → transfer to GPU, use GPU kernels X_gpu = _to_gpu_native(X, n_cells, n_total_genes) @@ -468,83 +589,134 @@ def _wilcoxon_with_reference( host_sparse = isinstance(X, sp.spmatrix | sp.sparray) host_dense = isinstance(X, np.ndarray) if host_sparse or host_dense: - if rg._compute_stats_in_chunks: - X_gpu_tmp = _to_gpu_native(X, n_cells, n_total_genes) - rg.X = X_gpu_tmp - rg._compute_stats_in_chunks = False - rg._basic_stats() - del X_gpu_tmp + # Output buffers live on the GPU (caller-provided CuPy memory); + # kernels write directly into them, and rank_sums / tie_corr + # feed the z-score math below without any D2H → H2D round-trip. + rank_sums = cp.empty((n_test, n_total_genes), dtype=cp.float64) + tie_corr_arr = cp.ones((n_test, n_total_genes), dtype=cp.float64) + + # Stats slots: 0..n_test-1 = test groups, slot n_test = reference. + # Unselected cells carry the sentinel (n_groups_stats) which the + # kernel skips. 
+ n_groups_stats = n_test + 1 + stats_codes_np = np.full(n_cells, n_groups_stats, dtype=np.int32) + for i, gi in enumerate(test_group_indices): + stats_codes_np[codes == gi] = i + stats_codes_np[codes == ireference] = n_test - rank_sums_np = np.empty((n_test, n_total_genes), dtype=np.float64) - tie_corr_np = np.ones((n_test, n_total_genes), dtype=np.float64) + group_sums = cp.empty((n_groups_stats, n_total_genes), dtype=cp.float64) + group_sq_sums = cp.empty((n_groups_stats, n_total_genes), dtype=cp.float64) + group_nnz = cp.empty((n_groups_stats, n_total_genes), dtype=cp.float64) if host_sparse and X.format == "csc": - _csc_host_fn = ( - _wc.ovo_streaming_csc_host_i64 - if X.indptr.dtype == np.int64 - else _wc.ovo_streaming_csc_host - ) + is_f64 = X.data.dtype == np.float64 + is_i64 = X.indptr.dtype == np.int64 + if is_f64 and is_i64: + _csc_host_fn = _wc.ovo_streaming_csc_host_f64_i64 + elif is_f64: + _csc_host_fn = _wc.ovo_streaming_csc_host_f64 + elif is_i64: + _csc_host_fn = _wc.ovo_streaming_csc_host_i64 + else: + _csc_host_fn = _wc.ovo_streaming_csc_host + data_arr = X.data if is_f64 else X.data.astype(np.float32, copy=False) _csc_host_fn( - X.data.astype(np.float32, copy=False), + data_arr, X.indices.astype(np.int32, copy=False), X.indptr, ref_row_map_np, grp_row_map_np, offsets_np, - rank_sums_np, - tie_corr_np, + stats_codes_np, + rank_sums, + tie_corr_arr, + group_sums, + group_sq_sums, + group_nnz, n_ref=n_ref, n_all_grp=n_all_grp, n_rows=n_cells, n_cols=n_total_genes, n_groups=n_test, + n_groups_stats=n_groups_stats, compute_tie_corr=tie_correct, sub_batch_cols=STREAMING_SUB_BATCH, ) elif host_sparse: csr = X.tocsr() if X.format != "csr" else X - _csr_host_fn = ( - _wc.ovo_streaming_csr_host_i64 - if csr.indptr.dtype == np.int64 - else _wc.ovo_streaming_csr_host - ) + is_f64 = csr.data.dtype == np.float64 + is_i64 = csr.indptr.dtype == np.int64 + if is_f64 and is_i64: + _csr_host_fn = _wc.ovo_streaming_csr_host_f64_i64 + elif is_f64: + _csr_host_fn = 
_wc.ovo_streaming_csr_host_f64 + elif is_i64: + _csr_host_fn = _wc.ovo_streaming_csr_host_i64 + else: + _csr_host_fn = _wc.ovo_streaming_csr_host + data_arr = csr.data if is_f64 else csr.data.astype(np.float32, copy=False) _csr_host_fn( - csr.data.astype(np.float32, copy=False), + data_arr, csr.indices.astype(np.int32, copy=False), csr.indptr, ref_row_ids_np.astype(np.int32, copy=False), all_grp_row_ids_np.astype(np.int32, copy=False), offsets_np, - rank_sums_np, - tie_corr_np, + stats_codes_np, + rank_sums, + tie_corr_arr, + group_sums, + group_sq_sums, + group_nnz, n_ref=n_ref, n_all_grp=n_all_grp, n_rows=n_cells, n_cols=n_total_genes, n_groups=n_test, + n_groups_stats=n_groups_stats, nnz=csr.nnz, compute_tie_corr=tie_correct, sub_batch_cols=STREAMING_SUB_BATCH, ) else: - _wc.ovo_streaming_dense_host( - np.asfortranarray(X.astype(np.float32, copy=False)), + is_f64 = X.dtype == np.float64 + _dense_host_fn = ( + _wc.ovo_streaming_dense_host_f64 + if is_f64 + else _wc.ovo_streaming_dense_host + ) + block = X if is_f64 else X.astype(np.float32, copy=False) + _dense_host_fn( + np.asfortranarray(block), ref_row_ids_np.astype(np.int32, copy=False), all_grp_row_ids_np.astype(np.int32, copy=False), offsets_np, - rank_sums_np, - tie_corr_np, + stats_codes_np, + rank_sums, + tie_corr_arr, + group_sums, + group_sq_sums, + group_nnz, n_ref=n_ref, n_all_grp=n_all_grp, n_rows=n_cells, n_cols=n_total_genes, n_groups=n_test, + n_groups_stats=n_groups_stats, compute_tie_corr=tie_correct, sub_batch_cols=STREAMING_SUB_BATCH, ) - rank_sums = cp.asarray(rank_sums_np) - tie_corr_arr = cp.asarray(tie_corr_np) + if rg._compute_stats_in_chunks: + _fill_ovo_stats_from_accumulators( + rg, + group_sums, + group_sq_sums, + group_nnz, + group_sizes=group_sizes, + test_group_indices=test_group_indices, + n_ref=n_ref, + ) else: # ---- GPU path: transfer once, then dispatch ---- From 12de3b0680069c4a01bfd7394d5ad5796f5f53fb Mon Sep 17 00:00:00 2001 From: Intron7 Date: Mon, 20 Apr 2026 14:10:26 
+0200 Subject: [PATCH 19/21] update kernels --- .../_cuda/wilcoxon/kernels_wilcoxon.cuh | 115 ++++---- .../_cuda/wilcoxon/wilcoxon_common.cuh | 61 ++++- .../_cuda/wilcoxon/wilcoxon_ovo.cu | 254 +++++++++--------- .../_cuda/wilcoxon/wilcoxon_ovr.cu | 204 ++++++++++---- .../tools/_rank_genes_groups/_wilcoxon.py | 11 +- 5 files changed, 385 insertions(+), 260 deletions(-) diff --git a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh index 0c6b7505..a6dc74f9 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh @@ -10,57 +10,35 @@ * tie groups by adjacent comparison (sequential access, no binary search). * Cross-boundary ties are resolved via binary search at chunk boundaries. * - * When use_gmem is false (default), per-group accumulators live in shared - * memory (fast atomics, limited to ~750 groups on 48 KB devices). - * When use_gmem is true, accumulators write directly to the output arrays - * in global memory, supporting an arbitrary number of groups. The caller - * must pre-zero rank_sums (and group_sums/group_sq_sums/group_nnz if - * compute_stats) before launching. + * When use_gmem is false, per-group accumulators live in shared memory + * (fast atomics, limited to ~1500 groups on 48 KB devices). When use_gmem + * is true, accumulators write directly to ``rank_sums`` in global memory, + * supporting an arbitrary number of groups. The caller must pre-zero + * ``rank_sums`` before launching in the gmem path. 
* * Shared memory layout: - * use_gmem=false: (4 * n_groups + 32) doubles (accumulators + warp buf) - * use_gmem=true: 32 doubles (warp buf only) + * use_gmem=false: (n_groups + 32) doubles (accumulators + warp buf) + * use_gmem=true: 32 doubles (warp buf only) */ __global__ void rank_sums_from_sorted_kernel( const float* __restrict__ sorted_vals, const int* __restrict__ sorted_row_idx, const int* __restrict__ group_codes, - double* __restrict__ rank_sums, double* __restrict__ tie_corr, - double* __restrict__ group_sums, double* __restrict__ group_sq_sums, - double* __restrict__ group_nnz, int n_rows, int n_cols, int n_groups, - bool compute_tie_corr, bool compute_stats, bool use_gmem) { + double* __restrict__ rank_sums, double* __restrict__ tie_corr, int n_rows, + int n_cols, int n_groups, bool compute_tie_corr, bool use_gmem) { int col = blockIdx.x; if (col >= n_cols) return; extern __shared__ double smem[]; - // Accumulator pointers: shared memory (fast) or global memory (large - // groups) double* grp_sums; - double* s_sum; - double* s_sq; - double* s_nnz; - if (use_gmem) { - // Global memory path: write directly to output arrays (must be - // pre-zeroed) + // Global memory path: write directly to output (must be pre-zeroed) grp_sums = rank_sums + (size_t)col; // stride: n_cols - s_sum = group_sums ? group_sums + (size_t)col : nullptr; - s_sq = group_sq_sums ? group_sq_sums + (size_t)col : nullptr; - s_nnz = group_nnz ? 
group_nnz + (size_t)col : nullptr; } else { // Shared memory path: per-block accumulators grp_sums = smem; - s_sum = smem + n_groups; - s_sq = smem + 2 * n_groups; - s_nnz = smem + 3 * n_groups; - for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { grp_sums[g] = 0.0; - if (compute_stats) { - s_sum[g] = 0.0; - s_sq[g] = 0.0; - s_nnz[g] = 0.0; - } } __syncthreads(); } @@ -104,7 +82,7 @@ __global__ void rank_sums_from_sorted_kernel( sv[tie_local_end] == val) { int lo = tie_local_end, hi = n_rows - 1; while (lo < hi) { - int mid = (lo + hi + 1) / 2; + int mid = hi - ((hi - lo) >> 1); if (sv[mid] > val) hi = mid - 1; else @@ -120,12 +98,6 @@ __global__ void rank_sums_from_sorted_kernel( int grp = group_codes[si[j]]; if (grp < n_groups) { atomicAdd(&grp_sums[grp * acc_stride], avg_rank); - if (compute_stats) { - double v = (double)sv[j]; - atomicAdd(&s_sum[grp * acc_stride], v); - atomicAdd(&s_sq[grp * acc_stride], v * v); - if (v != 0.0) atomicAdd(&s_nnz[grp * acc_stride], 1.0); - } } } @@ -143,19 +115,14 @@ __global__ void rank_sums_from_sorted_kernel( if (!use_gmem) { for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { rank_sums[(size_t)g * n_cols + col] = grp_sums[g]; - if (compute_stats) { - group_sums[(size_t)g * n_cols + col] = s_sum[g]; - group_sq_sums[(size_t)g * n_cols + col] = s_sq[g]; - group_nnz[(size_t)g * n_cols + col] = s_nnz[g]; - } } } if (compute_tie_corr) { - // Warp buf sits after all accumulator arrays in shared memory. - // gmem path: accumulators are in global mem, warp buf starts at - // smem[0]. smem path: 4 arrays of n_groups doubles, then warp buf. - int warp_buf_off = use_gmem ? 0 : (compute_stats ? 4 : 1) * n_groups; + // Warp buf sits after accumulator array in shared memory. + // gmem path: warp buf starts at smem[0]. + // smem path: n_groups doubles, then warp buf. + int warp_buf_off = use_gmem ? 
0 : n_groups; double* warp_buf = smem + warp_buf_off; #pragma unroll for (int off = 16; off > 0; off >>= 1) @@ -209,8 +176,9 @@ __global__ void rank_sums_sparse_ovr_kernel( const int* __restrict__ sorted_row_idx, const int* __restrict__ col_seg_offsets, const int* __restrict__ group_codes, const double* __restrict__ group_sizes, - double* __restrict__ rank_sums, double* __restrict__ tie_corr, int n_rows, - int sb_cols, int n_groups, bool compute_tie_corr) { + double* __restrict__ rank_sums, double* __restrict__ tie_corr, + double* __restrict__ nz_count_scratch, int n_rows, int sb_cols, + int n_groups, bool compute_tie_corr, bool use_gmem) { int col = blockIdx.x; if (col >= sb_cols) return; @@ -222,14 +190,27 @@ __global__ void rank_sums_sparse_ovr_kernel( const int* si = sorted_row_idx + seg_start; extern __shared__ double smem[]; - double* grp_sums = smem; - double* grp_nz_count = smem + n_groups; + double* grp_sums; + double* grp_nz_count; + // Accumulator stride: 1 for shared mem (dense per-block), sb_cols for + // gmem (row-major layout (n_groups, sb_cols) shared across blocks). + int acc_stride; - for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { - grp_sums[g] = 0.0; - grp_nz_count[g] = 0.0; + if (use_gmem) { + // Output rank_sums doubles as accumulator (pre-zeroed by caller). 
+ grp_sums = rank_sums + (size_t)col; + grp_nz_count = nz_count_scratch + (size_t)col; + acc_stride = sb_cols; + } else { + grp_sums = smem; + grp_nz_count = smem + n_groups; + acc_stride = 1; + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + grp_sums[g] = 0.0; + grp_nz_count[g] = 0.0; + } + __syncthreads(); } - __syncthreads(); // --- Find zero range: neg_end = first val >= 0, pos_start = first val > 0 // --- @@ -278,7 +259,7 @@ __global__ void rank_sums_sparse_ovr_kernel( if (i < neg_end || i >= pos_start) { // skip stored zeros int grp = group_codes[si[i]]; if (grp < n_groups) { - atomicAdd(&grp_nz_count[grp], 1.0); + atomicAdd(&grp_nz_count[grp * acc_stride], 1.0); } } } @@ -286,8 +267,8 @@ __global__ void rank_sums_sparse_ovr_kernel( // --- Zero-rank contribution per group --- for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { - double n_zero_in_g = group_sizes[g] - grp_nz_count[g]; - grp_sums[g] = n_zero_in_g * zero_avg_rank; + double n_zero_in_g = group_sizes[g] - grp_nz_count[g * acc_stride]; + grp_sums[g * acc_stride] = n_zero_in_g * zero_avg_rank; } __syncthreads(); @@ -343,7 +324,7 @@ __global__ void rank_sums_sparse_ovr_kernel( int search_hi = (val < 0.0f) ? 
(neg_end - 1) : (nnz_stored - 1); int lo = tie_local_end, hi = search_hi; while (lo < hi) { - int mid = (lo + hi + 1) >> 1; + int mid = hi - ((hi - lo) >> 1); if (sv[mid] > val) hi = mid - 1; else @@ -368,7 +349,7 @@ __global__ void rank_sums_sparse_ovr_kernel( for (int j = i; j < tie_local_end; ++j) { int grp = group_codes[si[j]]; if (grp < n_groups) { - atomicAdd(&grp_sums[grp], avg_rank); + atomicAdd(&grp_sums[grp * acc_stride], avg_rank); } } @@ -382,9 +363,11 @@ __global__ void rank_sums_sparse_ovr_kernel( __syncthreads(); - // Write rank sums to global output - for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { - rank_sums[(size_t)g * sb_cols + col] = grp_sums[g]; + // Write rank sums to global output (smem path only — gmem path is direct) + if (!use_gmem) { + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + rank_sums[(size_t)g * sb_cols + col] = grp_sums[g]; + } } // Tie correction: warp + block reduction @@ -395,7 +378,9 @@ __global__ void rank_sums_sparse_ovr_kernel( local_tie_sum += tz * tz * tz - tz; } - int warp_buf_off = 2 * n_groups; + // smem path: warp buf after both accumulator arrays (2 * n_groups). + // gmem path: accumulators are in gmem, warp buf starts at smem[0]. + int warp_buf_off = use_gmem ? 
0 : 2 * n_groups; double* warp_buf = smem + warp_buf_off; #pragma unroll diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_common.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_common.cuh index 4363fd49..5ac8ade7 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_common.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_common.cuh @@ -11,12 +11,51 @@ #include // rmm 25.x #endif +#include "../nb_types.h" // for CUDA_CHECK_LAST_ERROR + constexpr int WARP_SIZE = 32; constexpr int MAX_THREADS_PER_BLOCK = 512; constexpr int N_STREAMS = 4; constexpr int SUB_BATCH_COLS = 64; constexpr int BEGIN_BIT = 0; constexpr int END_BIT = 32; +// Default thread-per-block for utility kernels (extract, gather, offsets, +// etc.). +constexpr int UTIL_BLOCK_SIZE = 256; +// Scratch slots for warp-level reduction (one slot per warp, 32 warps max). +constexpr int WARP_REDUCE_BUF = 32; +// Max group size for the fused smem-sort rank kernel (Tier 1 fast path). +// Beyond this, fall back to CUB segmented sort + binary-search rank kernel. +constexpr int TIER1_GROUP_THRESHOLD = 2500; + +// --------------------------------------------------------------------------- +// RAII guard for cudaHostRegister. Unregisters on scope exit even when an +// exception unwinds — prevents leaked host pinning on stream-sync failures. 
+// --------------------------------------------------------------------------- +struct HostRegisterGuard { + void* ptr = nullptr; + + HostRegisterGuard() = default; + HostRegisterGuard(void* p, size_t bytes, unsigned int flags = 0) : ptr(p) { + if (ptr) cudaHostRegister(ptr, bytes, flags); + } + ~HostRegisterGuard() { + if (ptr) cudaHostUnregister(ptr); + } + HostRegisterGuard(const HostRegisterGuard&) = delete; + HostRegisterGuard& operator=(const HostRegisterGuard&) = delete; + HostRegisterGuard(HostRegisterGuard&& other) noexcept : ptr(other.ptr) { + other.ptr = nullptr; + } + HostRegisterGuard& operator=(HostRegisterGuard&& other) noexcept { + if (this != &other) { + if (ptr) cudaHostUnregister(ptr); + ptr = other.ptr; + other.ptr = nullptr; + } + return *this; + } +}; // --------------------------------------------------------------------------- // RMM pool helper — allocate GPU buffers through the current RMM memory @@ -42,14 +81,24 @@ static inline int round_up_to_warp(int n) { return (rounded < MAX_THREADS_PER_BLOCK) ? rounded : MAX_THREADS_PER_BLOCK; } -/** Upload linear segment offsets [0, stride, 2*stride, ...] to device. - * Uses synchronous copy — the buffer is small (a few hundred bytes). */ +/** Fill linear segment offsets [0, stride, 2*stride, ..., n_segments*stride] + * on-device. One thread per output slot. */ +__global__ void fill_linear_offsets_kernel(int* __restrict__ out, + int n_segments, int stride) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i <= n_segments) out[i] = i * stride; +} + +/** Fill linear segment offsets [0, stride, 2*stride, ...] on device. + * Runs on the supplied stream so it doesn't serialize multi-stream pipelines. 
+ */ static inline void upload_linear_offsets(int* d_offsets, int n_segments, int stride, cudaStream_t stream) { - std::vector h(n_segments + 1); - for (int i = 0; i <= n_segments; i++) h[i] = i * stride; - cudaMemcpy(d_offsets, h.data(), (n_segments + 1) * sizeof(int), - cudaMemcpyHostToDevice); + int count = n_segments + 1; + int blk = (count + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + fill_linear_offsets_kernel<<>>( + d_offsets, n_segments, stride); + CUDA_CHECK_LAST_ERROR(fill_linear_offsets_kernel); } // ============================================================================ diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo.cu index a2a513cf..01722345 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo.cu @@ -66,6 +66,38 @@ static size_t get_seg_sort_temp_bytes(int n_items, int n_segments) { return bytes; } +/** + * Tier 1 dispatch: when the largest group fits in shared memory, a fused + * bitonic-sort + binary-search kernel handles the whole group per block. + * Otherwise we fall back to CUB segmented sort plus the pre-sorted rank + * kernel. This struct bundles the sizing knobs derived from the host-side + * group offsets so each streaming impl can drop a 15-line prep block. 
+ */ +struct Tier1Config { + int max_grp_size = 0; + bool use_tier1 = false; + int padded_grp_size = 0; + int tier1_tpb = 0; + size_t tier1_smem = 0; +}; + +static Tier1Config make_tier1_config(const int* h_grp_offsets, int n_groups) { + Tier1Config c; + for (int g = 0; g < n_groups; g++) { + int sz = h_grp_offsets[g + 1] - h_grp_offsets[g]; + if (sz > c.max_grp_size) c.max_grp_size = sz; + } + c.use_tier1 = (c.max_grp_size <= TIER1_GROUP_THRESHOLD); + if (c.use_tier1) { + c.padded_grp_size = 1; + while (c.padded_grp_size < c.max_grp_size) c.padded_grp_size <<= 1; + c.tier1_tpb = std::min(c.padded_grp_size, MAX_THREADS_PER_BLOCK); + c.tier1_smem = (size_t)c.padded_grp_size * sizeof(float) + + WARP_REDUCE_BUF * sizeof(double); + } + return c; +} + /** * Streaming OVO pipeline. * @@ -146,9 +178,10 @@ static void ovo_streaming_impl(const float* ref_sorted, const float* grp_data, // Build segment offsets on device { int total = sb_n_seg + 1; - int blk = (total + 255) / 256; - build_seg_offsets_kernel<<>>( + int blk = (total + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + build_seg_offsets_kernel<<>>( grp_offsets, buf.seg_offsets, n_all_grp, n_groups, sb_cols); + CUDA_CHECK_LAST_ERROR(build_seg_offsets_kernel); } // Sort group data for this sub-batch @@ -165,6 +198,7 @@ static void ovo_streaming_impl(const float* ref_sorted, const float* grp_data, ref_sub, buf.grp_sorted, grp_offsets, buf.sub_rank_sums, buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, compute_tie_corr); + CUDA_CHECK_LAST_ERROR(batched_rank_sums_presorted_kernel); // Scatter sub-batch results to global output cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), @@ -210,16 +244,15 @@ static void ovo_streaming_csr_impl( if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; // ---- Tier dispatch: read group offsets to determine max group size ---- - constexpr int TIER1_THRESHOLD = 2500; std::vector h_offsets(n_groups + 1); cudaMemcpy(h_offsets.data(), grp_offsets, (n_groups + 1) * sizeof(int), 
cudaMemcpyDeviceToHost); - int max_grp_size = 0; - for (int g = 0; g < n_groups; g++) { - int sz = h_offsets[g + 1] - h_offsets[g]; - if (sz > max_grp_size) max_grp_size = sz; - } - bool use_tier1 = (max_grp_size <= TIER1_THRESHOLD); + auto t1 = make_tier1_config(h_offsets.data(), n_groups); + int max_grp_size = t1.max_grp_size; + bool use_tier1 = t1.use_tier1; + int padded_grp_size = t1.padded_grp_size; + int tier1_tpb = t1.tier1_tpb; + size_t tier1_smem = t1.tier1_smem; int n_streams = N_STREAMS; if (n_cols < n_streams * sub_batch_cols) @@ -250,17 +283,6 @@ static void ovo_streaming_csr_impl( cub_temp_bytes = std::max(cub_ref_bytes, cub_grp_bytes); } - // Tier 1 precomputation - int padded_grp_size = 0; - int tier1_tpb = 0; - size_t tier1_smem = 0; - if (use_tier1) { - padded_grp_size = 1; - while (padded_grp_size < max_grp_size) padded_grp_size <<= 1; - tier1_tpb = std::min(padded_grp_size, MAX_THREADS_PER_BLOCK); - tier1_smem = padded_grp_size * sizeof(float) + 32 * sizeof(double); - } - std::vector streams(n_streams); for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); @@ -320,6 +342,7 @@ static void ovo_streaming_csr_impl( csr_extract_dense_kernel<<>>( csr_data, csr_indices, csr_indptr, ref_row_ids, buf.ref_dense, n_ref, col, col + sb_cols); + CUDA_CHECK_LAST_ERROR(csr_extract_dense_kernel); } upload_linear_offsets(buf.ref_seg_offsets, sb_cols, n_ref, stream); { @@ -338,6 +361,7 @@ static void ovo_streaming_csr_impl( csr_extract_dense_kernel<<>>( csr_data, csr_indices, csr_indptr, grp_row_ids, buf.grp_dense, n_all_grp, col, col + sb_cols); + CUDA_CHECK_LAST_ERROR(csr_extract_dense_kernel); } if (use_tier1) { @@ -347,15 +371,17 @@ static void ovo_streaming_csr_impl( buf.ref_sorted, buf.grp_dense, grp_offsets, buf.sub_rank_sums, buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, compute_tie_corr, padded_grp_size); + CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); } else { // ---- Tier 3: CUB segmented sort + binary search ---- int 
sb_grp_seg = n_groups * sb_cols; { int total = sb_grp_seg + 1; - int blk = (total + 255) / 256; - build_seg_offsets_kernel<<>>( + int blk = (total + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + build_seg_offsets_kernel<<>>( grp_offsets, buf.grp_seg_offsets, n_all_grp, n_groups, sb_cols); + CUDA_CHECK_LAST_ERROR(build_seg_offsets_kernel); } { size_t temp = cub_temp_bytes; @@ -371,6 +397,7 @@ static void ovo_streaming_csr_impl( buf.ref_sorted, buf.grp_sorted, grp_offsets, buf.sub_rank_sums, buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, compute_tie_corr); + CUDA_CHECK_LAST_ERROR(batched_rank_sums_presorted_kernel); } } @@ -415,16 +442,15 @@ static void ovo_streaming_csc_impl( if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; // ---- Tier dispatch ---- - constexpr int TIER1_THRESHOLD = 2500; std::vector h_offsets(n_groups + 1); cudaMemcpy(h_offsets.data(), grp_offsets, (n_groups + 1) * sizeof(int), cudaMemcpyDeviceToHost); - int max_grp_size = 0; - for (int g = 0; g < n_groups; g++) { - int sz = h_offsets[g + 1] - h_offsets[g]; - if (sz > max_grp_size) max_grp_size = sz; - } - bool use_tier1 = (max_grp_size <= TIER1_THRESHOLD); + auto t1 = make_tier1_config(h_offsets.data(), n_groups); + int max_grp_size = t1.max_grp_size; + bool use_tier1 = t1.use_tier1; + int padded_grp_size = t1.padded_grp_size; + int tier1_tpb = t1.tier1_tpb; + size_t tier1_smem = t1.tier1_smem; int n_streams = N_STREAMS; if (n_cols < n_streams * sub_batch_cols) @@ -454,17 +480,6 @@ static void ovo_streaming_csc_impl( cub_temp_bytes = std::max(cub_ref_bytes, cub_grp_bytes); } - // Tier 1 precomputation - int padded_grp_size = 0; - int tier1_tpb = 0; - size_t tier1_smem = 0; - if (use_tier1) { - padded_grp_size = 1; - while (padded_grp_size < max_grp_size) padded_grp_size <<= 1; - tier1_tpb = std::min(padded_grp_size, MAX_THREADS_PER_BLOCK); - tier1_smem = padded_grp_size * sizeof(float) + 32 * sizeof(double); - } - std::vector streams(n_streams); for (int i = 0; i < n_streams; i++) 
cudaStreamCreate(&streams[i]); @@ -517,9 +532,10 @@ static void ovo_streaming_csc_impl( // ---- Extract ref from CSC via row_map, then sort ---- cudaMemsetAsync(buf.ref_dense, 0, sb_ref_items_actual * sizeof(float), stream); - csc_extract_mapped_kernel<<>>( + csc_extract_mapped_kernel<<>>( csc_data, csc_indices, csc_indptr, ref_row_map, buf.ref_dense, n_ref, col); + CUDA_CHECK_LAST_ERROR(csc_extract_mapped_kernel); upload_linear_offsets(buf.ref_seg_offsets, sb_cols, n_ref, stream); { size_t temp = cub_temp_bytes; @@ -532,9 +548,10 @@ static void ovo_streaming_csc_impl( // ---- Extract grp from CSC via row_map ---- cudaMemsetAsync(buf.grp_dense, 0, sb_grp_items_actual * sizeof(float), stream); - csc_extract_mapped_kernel<<>>( + csc_extract_mapped_kernel<<>>( csc_data, csc_indices, csc_indptr, grp_row_map, buf.grp_dense, n_all_grp, col); + CUDA_CHECK_LAST_ERROR(csc_extract_mapped_kernel); if (use_tier1) { dim3 grid(sb_cols, n_groups); @@ -542,14 +559,16 @@ static void ovo_streaming_csc_impl( buf.ref_sorted, buf.grp_dense, grp_offsets, buf.sub_rank_sums, buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, compute_tie_corr, padded_grp_size); + CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); } else { int sb_grp_seg = n_groups * sb_cols; { int total = sb_grp_seg + 1; - int blk = (total + 255) / 256; - build_seg_offsets_kernel<<>>( + int blk = (total + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + build_seg_offsets_kernel<<>>( grp_offsets, buf.grp_seg_offsets, n_all_grp, n_groups, sb_cols); + CUDA_CHECK_LAST_ERROR(build_seg_offsets_kernel); } { size_t temp = cub_temp_bytes; @@ -565,6 +584,7 @@ static void ovo_streaming_csc_impl( buf.ref_sorted, buf.grp_sorted, grp_offsets, buf.sub_rank_sums, buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, compute_tie_corr); + CUDA_CHECK_LAST_ERROR(batched_rank_sums_presorted_kernel); } } @@ -613,13 +633,12 @@ static void ovo_streaming_csc_host_impl( if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; // ---- Tier dispatch 
from host offsets ---- - constexpr int TIER1_THRESHOLD = 2500; - int max_grp_size = 0; - for (int g = 0; g < n_groups; g++) { - int sz = h_grp_offsets[g + 1] - h_grp_offsets[g]; - if (sz > max_grp_size) max_grp_size = sz; - } - bool use_tier1 = (max_grp_size <= TIER1_THRESHOLD); + auto t1 = make_tier1_config(h_grp_offsets, n_groups); + int max_grp_size = t1.max_grp_size; + bool use_tier1 = t1.use_tier1; + int padded_grp_size = t1.padded_grp_size; + int tier1_tpb = t1.tier1_tpb; + size_t tier1_smem = t1.tier1_smem; int n_streams = N_STREAMS; if (n_cols < n_streams * sub_batch_cols) @@ -649,16 +668,6 @@ static void ovo_streaming_csc_host_impl( cub_temp_bytes = std::max(cub_ref_bytes, cub_grp_bytes); } - int padded_grp_size = 0; - int tier1_tpb = 0; - size_t tier1_smem = 0; - if (use_tier1) { - padded_grp_size = 1; - while (padded_grp_size < max_grp_size) padded_grp_size <<= 1; - tier1_tpb = std::min(padded_grp_size, MAX_THREADS_PER_BLOCK); - tier1_smem = padded_grp_size * sizeof(float) + 32 * sizeof(double); - } - // Max nnz across any sub-batch for sparse transfer buffer sizing size_t max_nnz = 0; for (int c = 0; c < n_cols; c += sub_batch_cols) { @@ -741,8 +750,10 @@ static void ovo_streaming_csc_host_impl( // Pin only the sparse input arrays; outputs live on the device. 
size_t total_nnz = (size_t)h_indptr[n_cols]; - cudaHostRegister(const_cast(h_data), total_nnz * sizeof(InT), 0); - cudaHostRegister(const_cast(h_indices), total_nnz * sizeof(int), 0); + HostRegisterGuard _pin_data(const_cast(h_data), + total_nnz * sizeof(InT)); + HostRegisterGuard _pin_indices(const_cast(h_indices), + total_nnz * sizeof(int)); int col = 0; int batch_idx = 0; @@ -772,18 +783,20 @@ static void ovo_streaming_csc_host_impl( // ---- Cast to float32 for sort + accumulate stats in float64 ---- ovr_cast_and_accumulate_sparse_kernel - <<>>( + <<>>( buf.d_sparse_data_orig, buf.d_sparse_data_f32, buf.d_sparse_indices, buf.d_indptr, d_stats_codes, buf.d_group_sums, buf.d_group_sq_sums, buf.d_group_nnz, sb_cols, n_groups_stats); + CUDA_CHECK_LAST_ERROR(ovr_cast_and_accumulate_sparse_kernel); // ---- Extract ref from CSC via row_map, sort ---- cudaMemsetAsync(buf.ref_dense, 0, sb_ref_actual * sizeof(float), stream); - csc_extract_mapped_kernel<<>>( + csc_extract_mapped_kernel<<>>( buf.d_sparse_data_f32, buf.d_sparse_indices, buf.d_indptr, d_ref_row_map, buf.ref_dense, n_ref, 0); + CUDA_CHECK_LAST_ERROR(csc_extract_mapped_kernel); upload_linear_offsets(buf.ref_seg_offsets, sb_cols, n_ref, stream); { size_t temp = cub_temp_bytes; @@ -796,9 +809,10 @@ static void ovo_streaming_csc_host_impl( // ---- Extract grp from CSC via row_map ---- cudaMemsetAsync(buf.grp_dense, 0, sb_grp_actual * sizeof(float), stream); - csc_extract_mapped_kernel<<>>( + csc_extract_mapped_kernel<<>>( buf.d_sparse_data_f32, buf.d_sparse_indices, buf.d_indptr, d_grp_row_map, buf.grp_dense, n_all_grp, 0); + CUDA_CHECK_LAST_ERROR(csc_extract_mapped_kernel); // ---- Tier dispatch: sort grp + rank ---- if (use_tier1) { @@ -807,14 +821,16 @@ static void ovo_streaming_csc_host_impl( buf.ref_sorted, buf.grp_dense, d_grp_offsets, buf.d_rank_sums, buf.d_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, compute_tie_corr, padded_grp_size); + CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); } else { int 
sb_grp_seg = n_groups * sb_cols; { int total = sb_grp_seg + 1; - int blk = (total + 255) / 256; - build_seg_offsets_kernel<<>>( + int blk = (total + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + build_seg_offsets_kernel<<>>( d_grp_offsets, buf.grp_seg_offsets, n_all_grp, n_groups, sb_cols); + CUDA_CHECK_LAST_ERROR(build_seg_offsets_kernel); } { size_t temp = cub_temp_bytes; @@ -830,6 +846,7 @@ static void ovo_streaming_csc_host_impl( buf.ref_sorted, buf.grp_sorted, d_grp_offsets, buf.d_rank_sums, buf.d_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, compute_tie_corr); + CUDA_CHECK_LAST_ERROR(batched_rank_sums_presorted_kernel); } } @@ -869,9 +886,6 @@ static void ovo_streaming_csc_host_impl( cudaGetErrorString(err)); } - cudaHostUnregister(const_cast(h_data)); - cudaHostUnregister(const_cast(h_indices)); - for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); } @@ -898,13 +912,12 @@ static void ovo_streaming_csr_host_impl( if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; // ---- Tier dispatch from host offsets ---- - constexpr int TIER1_THRESHOLD = 2500; - int max_grp_size = 0; - for (int g = 0; g < n_groups; g++) { - int sz = h_grp_offsets[g + 1] - h_grp_offsets[g]; - if (sz > max_grp_size) max_grp_size = sz; - } - bool use_tier1 = (max_grp_size <= TIER1_THRESHOLD); + auto t1 = make_tier1_config(h_grp_offsets, n_groups); + int max_grp_size = t1.max_grp_size; + bool use_tier1 = t1.use_tier1; + int padded_grp_size = t1.padded_grp_size; + int tier1_tpb = t1.tier1_tpb; + size_t tier1_smem = t1.tier1_smem; int n_streams = N_STREAMS; if (n_cols < n_streams * sub_batch_cols) @@ -934,16 +947,6 @@ static void ovo_streaming_csr_host_impl( cub_temp_bytes = std::max(cub_ref_bytes, cub_grp_bytes); } - int padded_grp_size = 0; - int tier1_tpb = 0; - size_t tier1_smem = 0; - if (use_tier1) { - padded_grp_size = 1; - while (padded_grp_size < max_grp_size) padded_grp_size <<= 1; - tier1_tpb = std::min(padded_grp_size, MAX_THREADS_PER_BLOCK); - tier1_smem = 
padded_grp_size * sizeof(float) + 32 * sizeof(double); - } - std::vector streams(n_streams); for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); @@ -959,8 +962,9 @@ static void ovo_streaming_csr_host_impl( int* d_grp_offsets = pool.alloc(n_groups + 1); int* d_stats_codes = pool.alloc(n_rows); - cudaHostRegister(const_cast(h_data), nnz * sizeof(InT), 0); - cudaHostRegister(const_cast(h_indices), nnz * sizeof(int), 0); + HostRegisterGuard _pin_data(const_cast(h_data), nnz * sizeof(InT)); + HostRegisterGuard _pin_indices(const_cast(h_indices), + nnz * sizeof(int)); cudaMemcpyAsync(d_data_orig, h_data, nnz * sizeof(InT), cudaMemcpyHostToDevice, streams[0]); cudaMemcpyAsync(d_indices, h_indices, nnz * sizeof(int), @@ -991,12 +995,13 @@ static void ovo_streaming_csr_host_impl( (size_t)n_groups_stats * n_cols * sizeof(double), streams[0]); { - int tpb = 256; - int blk = (n_rows + tpb - 1) / tpb; - cast_and_accumulate_csr_kernel<<>>( - d_data_orig, d_data, d_indices, d_indptr, d_stats_codes, - d_group_sums, d_group_sq_sums, d_group_nnz, n_rows, n_cols, - n_groups_stats); + int blk = (n_rows + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + cast_and_accumulate_csr_kernel + <<>>( + d_data_orig, d_data, d_indices, d_indptr, d_stats_codes, + d_group_sums, d_group_sq_sums, d_group_nnz, n_rows, n_cols, + n_groups_stats); + CUDA_CHECK_LAST_ERROR(cast_and_accumulate_csr_kernel); } cudaStreamSynchronize(streams[0]); @@ -1010,6 +1015,7 @@ static void ovo_streaming_csr_host_impl( csr_extract_dense_kernel<<>>(d_data, d_indices, d_indptr, d_ref_row_ids, ref_dense, n_ref, 0, n_cols); + CUDA_CHECK_LAST_ERROR(csr_extract_dense_kernel); } { int* ref_seg = pool.alloc(n_cols + 1); @@ -1070,6 +1076,7 @@ static void ovo_streaming_csr_host_impl( csr_extract_dense_kernel<<>>( d_data, d_indices, d_indptr, d_grp_row_ids, buf.grp_dense, n_all_grp, col, col + sb_cols); + CUDA_CHECK_LAST_ERROR(csr_extract_dense_kernel); } // Rank against pre-sorted ref (just slice into ref_sorted) @@ 
-1080,14 +1087,16 @@ static void ovo_streaming_csr_host_impl( ref_sub, buf.grp_dense, d_grp_offsets, buf.d_rank_sums, buf.d_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, compute_tie_corr, padded_grp_size); + CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); } else { int sb_grp_seg = n_groups * sb_cols; { int total = sb_grp_seg + 1; - int blk = (total + 255) / 256; - build_seg_offsets_kernel<<>>( + int blk = (total + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + build_seg_offsets_kernel<<>>( d_grp_offsets, buf.grp_seg_offsets, n_all_grp, n_groups, sb_cols); + CUDA_CHECK_LAST_ERROR(build_seg_offsets_kernel); } { size_t temp = cub_temp_bytes; @@ -1103,6 +1112,7 @@ static void ovo_streaming_csr_host_impl( ref_sub, buf.grp_sorted, d_grp_offsets, buf.d_rank_sums, buf.d_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, compute_tie_corr); + CUDA_CHECK_LAST_ERROR(batched_rank_sums_presorted_kernel); } } @@ -1130,9 +1140,6 @@ static void ovo_streaming_csr_host_impl( cudaGetErrorString(err)); } - cudaHostUnregister(const_cast(h_data)); - cudaHostUnregister(const_cast(h_indices)); - for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); } @@ -1172,13 +1179,12 @@ static void ovo_streaming_dense_host_impl( if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; // ---- Tier dispatch from host offsets ---- - constexpr int TIER1_THRESHOLD = 2500; - int max_grp_size = 0; - for (int g = 0; g < n_groups; g++) { - int sz = h_grp_offsets[g + 1] - h_grp_offsets[g]; - if (sz > max_grp_size) max_grp_size = sz; - } - bool use_tier1 = (max_grp_size <= TIER1_THRESHOLD); + auto t1 = make_tier1_config(h_grp_offsets, n_groups); + int max_grp_size = t1.max_grp_size; + bool use_tier1 = t1.use_tier1; + int padded_grp_size = t1.padded_grp_size; + int tier1_tpb = t1.tier1_tpb; + size_t tier1_smem = t1.tier1_smem; int n_streams = N_STREAMS; if (n_cols < n_streams * sub_batch_cols) @@ -1209,16 +1215,6 @@ static void ovo_streaming_dense_host_impl( cub_temp_bytes = std::max(cub_ref_bytes, 
cub_grp_bytes); } - int padded_grp_size = 0; - int tier1_tpb = 0; - size_t tier1_smem = 0; - if (use_tier1) { - padded_grp_size = 1; - while (padded_grp_size < max_grp_size) padded_grp_size <<= 1; - tier1_tpb = std::min(padded_grp_size, MAX_THREADS_PER_BLOCK); - tier1_smem = padded_grp_size * sizeof(float) + 32 * sizeof(double); - } - std::vector streams(n_streams); for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); @@ -1288,8 +1284,8 @@ static void ovo_streaming_dense_host_impl( size_t smem_cast = (size_t)(3 * n_groups_stats) * sizeof(double); // Pin only the host input; outputs live on the device. - cudaHostRegister(const_cast(h_block), - (size_t)n_rows * n_cols * sizeof(InT), 0); + HostRegisterGuard _pin_block(const_cast(h_block), + (size_t)n_rows * n_cols * sizeof(InT)); int col = 0; int batch_idx = 0; @@ -1308,15 +1304,17 @@ static void ovo_streaming_dense_host_impl( // ---- Cast to float32 for sort + accumulate stats in float64 ---- ovr_cast_and_accumulate_dense_kernel - <<>>( + <<>>( buf.d_block_orig, buf.d_block_f32, d_stats_codes, buf.d_group_sums, buf.d_group_sq_sums, buf.d_group_nnz, n_rows, sb_cols, n_groups_stats); + CUDA_CHECK_LAST_ERROR(ovr_cast_and_accumulate_dense_kernel); // ---- Gather ref rows, sort ---- - dense_gather_rows_kernel<<>>( + dense_gather_rows_kernel<<>>( buf.d_block_f32, d_ref_row_ids, buf.ref_dense, n_rows, n_ref, sb_cols); + CUDA_CHECK_LAST_ERROR(dense_gather_rows_kernel); upload_linear_offsets(buf.ref_seg_offsets, sb_cols, n_ref, stream); { size_t temp = cub_temp_bytes; @@ -1327,9 +1325,10 @@ static void ovo_streaming_dense_host_impl( } // ---- Gather grp rows ---- - dense_gather_rows_kernel<<>>( + dense_gather_rows_kernel<<>>( buf.d_block_f32, d_grp_row_ids, buf.grp_dense, n_rows, n_all_grp, sb_cols); + CUDA_CHECK_LAST_ERROR(dense_gather_rows_kernel); // ---- Tier dispatch: sort grp + rank ---- if (use_tier1) { @@ -1338,14 +1337,16 @@ static void ovo_streaming_dense_host_impl( buf.ref_sorted, buf.grp_dense, 
d_grp_offsets, buf.d_rank_sums, buf.d_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, compute_tie_corr, padded_grp_size); + CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); } else { int sb_grp_seg = n_groups * sb_cols; { int total = sb_grp_seg + 1; - int blk = (total + 255) / 256; - build_seg_offsets_kernel<<>>( + int blk = (total + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + build_seg_offsets_kernel<<>>( d_grp_offsets, buf.grp_seg_offsets, n_all_grp, n_groups, sb_cols); + CUDA_CHECK_LAST_ERROR(build_seg_offsets_kernel); } { size_t temp = cub_temp_bytes; @@ -1361,6 +1362,7 @@ static void ovo_streaming_dense_host_impl( buf.ref_sorted, buf.grp_sorted, d_grp_offsets, buf.d_rank_sums, buf.d_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, compute_tie_corr); + CUDA_CHECK_LAST_ERROR(batched_rank_sums_presorted_kernel); } } @@ -1400,8 +1402,6 @@ static void ovo_streaming_dense_host_impl( cudaGetErrorString(err)); } - cudaHostUnregister(const_cast(h_block)); - for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); } diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr.cu index 58fda6d0..6c261c15 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr.cu @@ -9,11 +9,12 @@ using namespace nb::literals; -/** Rebase a slice of indptr: out[i] = indptr[col + i] - indptr[col]. */ +/** Rebase a slice of indptr: out[i] = indptr[col + i] - indptr[col]. + * Grid-strided: supports arbitrary `count` (no single-block thread limit). */ __global__ void rebase_indptr_kernel(const int* __restrict__ indptr, int* __restrict__ out, int col, int count) { - int i = threadIdx.x; + int i = blockIdx.x * blockDim.x + threadIdx.x; if (i < count) out[i] = indptr[col + i] - indptr[col]; } @@ -75,16 +76,20 @@ __global__ void csr_scatter_to_csc_kernel( * Decide whether to use shared or global memory for OVR rank accumulators. 
* Returns the smem size to request and sets use_gmem accordingly. */ -static size_t ovr_smem_config(int n_groups, bool& use_gmem) { - size_t need = (size_t)(4 * n_groups + 32) * sizeof(double); - static int max_smem = -1; - if (max_smem < 0) { +static int query_max_smem_per_block() { + static int cached = -1; + if (cached < 0) { int device; cudaGetDevice(&device); - cudaDeviceGetAttribute(&max_smem, cudaDevAttrMaxSharedMemoryPerBlock, + cudaDeviceGetAttribute(&cached, cudaDevAttrMaxSharedMemoryPerBlock, device); } - if ((int)need <= max_smem) { + return cached; +} + +static size_t ovr_smem_config(int n_groups, bool& use_gmem) { + size_t need = (size_t)(n_groups + 32) * sizeof(double); + if ((int)need <= query_max_smem_per_block()) { use_gmem = false; return need; } @@ -93,6 +98,20 @@ static size_t ovr_smem_config(int n_groups, bool& use_gmem) { return 32 * sizeof(double); } +/** + * Decide smem-vs-gmem for the sparse OVR rank kernel. Two accumulator + * arrays (grp_sums + grp_nz_count) of size n_groups each plus warp buf. + */ +static size_t sparse_ovr_smem_config(int n_groups, bool& use_gmem) { + size_t need = (size_t)(2 * n_groups + 32) * sizeof(double); + if ((int)need <= query_max_smem_per_block()) { + use_gmem = false; + return need; + } + use_gmem = true; + return 32 * sizeof(double); +} + /** * Fill sort values with row indices [0,1,...,n_rows-1] per column. * Grid: (n_cols,), block: 256 threads. 
@@ -180,8 +199,9 @@ static void ovr_streaming_impl(const float* block, const int* group_codes, // Fill segment offsets + row indices upload_linear_offsets(buf.seg_offsets, sb_cols, n_rows, stream); - fill_row_indices_kernel<<>>(buf.vals_in, - n_rows, sb_cols); + fill_row_indices_kernel<<>>( + buf.vals_in, n_rows, sb_cols); + CUDA_CHECK_LAST_ERROR(fill_row_indices_kernel); // Sort: keys = block columns [col, col+sb_cols), already F-order const float* keys_in = block + (long long)col * n_rows; @@ -199,8 +219,9 @@ static void ovr_streaming_impl(const float* block, const int* group_codes, } rank_sums_from_sorted_kernel<<>>( buf.keys_out, buf.vals_out, group_codes, buf.sub_rank_sums, - buf.sub_tie_corr, nullptr, nullptr, nullptr, n_rows, sb_cols, - n_groups, compute_tie_corr, false, use_gmem); + buf.sub_tie_corr, n_rows, sb_cols, n_groups, compute_tie_corr, + use_gmem); + CUDA_CHECK_LAST_ERROR(rank_sums_from_sorted_kernel); // Copy sub-batch results to global output (row-major scatter) // rank_sums is (n_groups, n_cols) row-major: group g, col c → @@ -289,6 +310,7 @@ static void ovr_sparse_csc_host_streaming_impl( double* d_group_sums; double* d_group_sq_sums; double* d_group_nnz; + double* d_nz_scratch; // gmem-only; non-null when rank_use_gmem }; std::vector bufs(n_streams); for (int s = 0; s < n_streams; s++) { @@ -316,14 +338,45 @@ static void ovr_sparse_csc_host_streaming_impl( cudaMemcpy(d_group_sizes, h_group_sizes, n_groups * sizeof(double), cudaMemcpyHostToDevice); - int tpb = 256; - size_t smem_bytes = (size_t)(2 * n_groups + 32) * sizeof(double); + // Pre-compute rebased per-batch offsets and upload once (avoids per-batch + // H2D copy from a transient host buffer). 
+ int n_batches = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + std::vector h_all_offsets((size_t)n_batches * (sub_batch_cols + 1), 0); + for (int b = 0; b < n_batches; b++) { + int col_start = b * sub_batch_cols; + int sb = std::min(sub_batch_cols, n_cols - col_start); + IndptrT ptr_start = h_indptr[col_start]; + int* off = &h_all_offsets[(size_t)b * (sub_batch_cols + 1)]; + for (int i = 0; i <= sb; i++) + off[i] = (int)(h_indptr[col_start + i] - ptr_start); + } + int* d_all_offsets = + pool.alloc((size_t)n_batches * (sub_batch_cols + 1)); + cudaMemcpy(d_all_offsets, h_all_offsets.data(), + h_all_offsets.size() * sizeof(int), cudaMemcpyHostToDevice); + + int tpb = UTIL_BLOCK_SIZE; + bool rank_use_gmem = false; + size_t smem_bytes = sparse_ovr_smem_config(n_groups, rank_use_gmem); size_t smem_cast = (size_t)(3 * n_groups) * sizeof(double); + // In gmem mode the sparse rank kernel accumulates into rank_sums directly + // and needs a per-stream nz_count scratch buffer sized (n_groups, sb_cols). + for (int s = 0; s < n_streams; s++) { + if (rank_use_gmem) { + bufs[s].d_nz_scratch = + pool.alloc((size_t)n_groups * sub_batch_cols); + } else { + bufs[s].d_nz_scratch = nullptr; + } + } + // Pin only the host input arrays; outputs live on the device. 
size_t total_nnz = (size_t)h_indptr[n_cols]; - cudaHostRegister(const_cast(h_data), total_nnz * sizeof(InT), 0); - cudaHostRegister(const_cast(h_indices), total_nnz * sizeof(int), 0); + HostRegisterGuard _pin_data(const_cast(h_data), + total_nnz * sizeof(InT)); + HostRegisterGuard _pin_indices(const_cast(h_indices), + total_nnz * sizeof(int)); cudaDeviceSynchronize(); @@ -349,15 +402,10 @@ static void ovr_sparse_csc_host_streaming_impl( cudaMemcpyHostToDevice, stream); } - // Rebase indptr slice on host → int32 per-batch offsets - { - std::vector h_seg(sb_cols + 1); - for (int i = 0; i <= sb_cols; i++) - h_seg[i] = (int)(h_indptr[col + i] - ptr_start); - cudaMemcpyAsync(buf.d_seg_offsets, h_seg.data(), - (sb_cols + 1) * sizeof(int), cudaMemcpyHostToDevice, - stream); - } + // D2D: copy this batch's rebased offsets from the pre-uploaded buffer + int* src = d_all_offsets + (size_t)batch_idx * (sub_batch_cols + 1); + cudaMemcpyAsync(buf.d_seg_offsets, src, (sb_cols + 1) * sizeof(int), + cudaMemcpyDeviceToDevice, stream); // Cast to float32 for sort + accumulate stats in float64 ovr_cast_and_accumulate_sparse_kernel @@ -366,6 +414,7 @@ static void ovr_sparse_csc_host_streaming_impl( buf.d_sparse_indices, buf.d_seg_offsets, d_group_codes, buf.d_group_sums, buf.d_group_sq_sums, buf.d_group_nnz, sb_cols, n_groups); + CUDA_CHECK_LAST_ERROR(ovr_cast_and_accumulate_sparse_kernel); // CUB sort only stored nonzeros (float32 keys) if (batch_nnz > 0) { @@ -378,10 +427,19 @@ static void ovr_sparse_csc_host_streaming_impl( } // Sparse rank kernel (stats already captured above) + if (rank_use_gmem) { + cudaMemsetAsync(buf.d_rank_sums, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + cudaMemsetAsync(buf.d_nz_scratch, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + } rank_sums_sparse_ovr_kernel<<>>( buf.keys_out, buf.vals_out, buf.d_seg_offsets, d_group_codes, - d_group_sizes, buf.d_rank_sums, buf.d_tie_corr, n_rows, sb_cols, - n_groups, 
compute_tie_corr); + d_group_sizes, buf.d_rank_sums, buf.d_tie_corr, buf.d_nz_scratch, + n_rows, sb_cols, n_groups, compute_tie_corr, rank_use_gmem); + CUDA_CHECK_LAST_ERROR(rank_sums_sparse_ovr_kernel); // D2D: scatter sub-batch results into caller's GPU buffers cudaMemcpy2DAsync(d_rank_sums + col, n_cols * sizeof(double), @@ -418,9 +476,6 @@ static void ovr_sparse_csc_host_streaming_impl( cudaGetErrorString(err)); } - cudaHostUnregister(const_cast(h_data)); - cudaHostUnregister(const_cast(h_indices)); - for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); } @@ -509,12 +564,12 @@ static void ovr_streaming_dense_host_impl( int tpb_rank = round_up_to_warp(n_rows); bool use_gmem = false; size_t smem_rank = ovr_smem_config(n_groups, use_gmem); - int tpb_cast = 256; + int tpb_cast = UTIL_BLOCK_SIZE; size_t smem_cast = (size_t)(3 * n_groups) * sizeof(double); // Pin only the host input. Outputs live on the device (caller-owned). - cudaHostRegister(const_cast(h_block), - (size_t)n_rows * n_cols * sizeof(InT), 0); + HostRegisterGuard _pin_block(const_cast(h_block), + (size_t)n_rows * n_cols * sizeof(InT)); int col = 0; int batch_idx = 0; @@ -535,11 +590,13 @@ static void ovr_streaming_dense_host_impl( buf.d_block_orig, buf.d_block_f32, d_group_codes, buf.d_group_sums, buf.d_group_sq_sums, buf.d_group_nnz, n_rows, sb_cols, n_groups); + CUDA_CHECK_LAST_ERROR(ovr_cast_and_accumulate_dense_kernel); // Fill segment offsets + row indices upload_linear_offsets(buf.seg_offsets, sb_cols, n_rows, stream); - fill_row_indices_kernel<<>>(buf.vals_in, - n_rows, sb_cols); + fill_row_indices_kernel<<>>( + buf.vals_in, n_rows, sb_cols); + CUDA_CHECK_LAST_ERROR(fill_row_indices_kernel); // Sort size_t temp = cub_temp_bytes; @@ -556,8 +613,9 @@ static void ovr_streaming_dense_host_impl( } rank_sums_from_sorted_kernel<<>>( buf.keys_out, buf.vals_out, d_group_codes, buf.d_rank_sums, - buf.d_tie_corr, nullptr, nullptr, nullptr, n_rows, sb_cols, - n_groups, compute_tie_corr, 
false, use_gmem); + buf.d_tie_corr, n_rows, sb_cols, n_groups, compute_tie_corr, + use_gmem); + CUDA_CHECK_LAST_ERROR(rank_sums_from_sorted_kernel); // D2D: scatter sub-batch results into the caller's GPU buffers cudaMemcpy2DAsync(d_rank_sums + col, n_cols * sizeof(double), @@ -594,8 +652,6 @@ static void ovr_streaming_dense_host_impl( cudaGetErrorString(err)); } - cudaHostUnregister(const_cast(h_block)); - for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); } @@ -640,6 +696,10 @@ static void ovr_sparse_csc_streaming_impl( std::vector streams(n_streams); for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + int tpb = UTIL_BLOCK_SIZE; + bool rank_use_gmem = false; + size_t smem_bytes = sparse_ovr_smem_config(n_groups, rank_use_gmem); + RmmPool pool; struct StreamBuf { float* keys_out; @@ -648,6 +708,7 @@ static void ovr_sparse_csc_streaming_impl( uint8_t* cub_temp; double* sub_rank_sums; double* sub_tie_corr; + double* d_nz_scratch; // gmem-only }; std::vector bufs(n_streams); for (int s = 0; s < n_streams; s++) { @@ -658,11 +719,12 @@ static void ovr_sparse_csc_streaming_impl( bufs[s].sub_rank_sums = pool.alloc((size_t)n_groups * sub_batch_cols); bufs[s].sub_tie_corr = pool.alloc(sub_batch_cols); + bufs[s].d_nz_scratch = + rank_use_gmem + ? 
pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; } - int tpb = 256; - size_t smem_bytes = (size_t)(2 * n_groups + 32) * sizeof(double); - cudaDeviceSynchronize(); int col = 0; @@ -679,8 +741,13 @@ static void ovr_sparse_csc_streaming_impl( // Compute rebased segment offsets on GPU (avoids host pinned-buffer // race) - rebase_indptr_kernel<<<1, sb_cols + 1, 0, stream>>>( - csc_indptr, buf.seg_offsets, col, sb_cols + 1); + { + int count = sb_cols + 1; + int blk = (count + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + rebase_indptr_kernel<<>>( + csc_indptr, buf.seg_offsets, col, count); + CUDA_CHECK_LAST_ERROR(rebase_indptr_kernel); + } // Sort only stored values (keys=data, vals=row_indices) if (batch_nnz > 0) { @@ -693,10 +760,19 @@ static void ovr_sparse_csc_streaming_impl( } // Sparse rank kernel (handles implicit zeros analytically) + if (rank_use_gmem) { + cudaMemsetAsync(buf.sub_rank_sums, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + cudaMemsetAsync(buf.d_nz_scratch, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + } rank_sums_sparse_ovr_kernel<<>>( buf.keys_out, buf.vals_out, buf.seg_offsets, group_codes, - group_sizes, buf.sub_rank_sums, buf.sub_tie_corr, n_rows, sb_cols, - n_groups, compute_tie_corr); + group_sizes, buf.sub_rank_sums, buf.sub_tie_corr, buf.d_nz_scratch, + n_rows, sb_cols, n_groups, compute_tie_corr, rank_use_gmem); + CUDA_CHECK_LAST_ERROR(rank_sums_sparse_ovr_kernel); // Scatter results to global output cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), @@ -751,10 +827,10 @@ static void ovr_sparse_csr_streaming_impl( int* d_col_counts = pool.alloc(n_cols); cudaMemset(d_col_counts, 0, n_cols * sizeof(int)); { - int tpb = 256; - int blocks = (n_rows + tpb - 1) / tpb; - csr_col_histogram_kernel<<>>(csr_indices, csr_indptr, - d_col_counts, n_rows, n_cols); + int blocks = (n_rows + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + csr_col_histogram_kernel<<>>( + csr_indices, csr_indptr, d_col_counts, n_rows, 
n_cols); + CUDA_CHECK_LAST_ERROR(csr_col_histogram_kernel); } std::vector h_col_counts(n_cols); cudaMemcpy(h_col_counts.data(), d_col_counts, n_cols * sizeof(int), @@ -816,6 +892,11 @@ static void ovr_sparse_csr_streaming_impl( std::vector streams(n_streams); for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + int tpb = UTIL_BLOCK_SIZE; + bool rank_use_gmem = false; + size_t smem_bytes = sparse_ovr_smem_config(n_groups, rank_use_gmem); + int scatter_blocks = (n_rows + tpb - 1) / tpb; + struct StreamBuf { int* col_offsets; // [sub_batch_cols + 1] CSC-style offsets int* write_pos; // [sub_batch_cols] atomic write counters @@ -826,6 +907,7 @@ static void ovr_sparse_csr_streaming_impl( uint8_t* cub_temp; double* sub_rank_sums; double* sub_tie_corr; + double* d_nz_scratch; // gmem-only }; std::vector bufs(n_streams); for (int s = 0; s < n_streams; s++) { @@ -839,12 +921,12 @@ static void ovr_sparse_csr_streaming_impl( bufs[s].sub_rank_sums = pool.alloc((size_t)n_groups * sub_batch_cols); bufs[s].sub_tie_corr = pool.alloc(sub_batch_cols); + bufs[s].d_nz_scratch = + rank_use_gmem + ? 
pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; } - int tpb = 256; - size_t smem_bytes = (size_t)(2 * n_groups + 32) * sizeof(double); - int scatter_blocks = (n_rows + tpb - 1) / tpb; - cudaDeviceSynchronize(); // ---- Phase 2: Stream loop ---- @@ -870,6 +952,7 @@ static void ovr_sparse_csr_streaming_impl( csr_scatter_to_csc_kernel<<>>( csr_data, csr_indices, csr_indptr, buf.write_pos, buf.csc_vals, buf.csc_row_idx, n_rows, col, col + sb_cols); + CUDA_CHECK_LAST_ERROR(csr_scatter_to_csc_kernel); // CUB sort only the nonzeros size_t temp = cub_temp_bytes; @@ -880,10 +963,19 @@ static void ovr_sparse_csr_streaming_impl( } // Sparse rank kernel (handles implicit zeros analytically) + if (rank_use_gmem) { + cudaMemsetAsync(buf.sub_rank_sums, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + cudaMemsetAsync(buf.d_nz_scratch, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + } rank_sums_sparse_ovr_kernel<<>>( buf.keys_out, buf.vals_out, buf.col_offsets, group_codes, - group_sizes, buf.sub_rank_sums, buf.sub_tie_corr, n_rows, sb_cols, - n_groups, compute_tie_corr); + group_sizes, buf.sub_rank_sums, buf.sub_tie_corr, buf.d_nz_scratch, + n_rows, sb_cols, n_groups, compute_tie_corr, rank_use_gmem); + CUDA_CHECK_LAST_ERROR(rank_sums_sparse_ovr_kernel); // Scatter results to global output cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py index 52d1fcb6..1773c5cd 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py @@ -133,13 +133,12 @@ def _to_gpu_native(X, n_rows: int, n_cols: int): # Downcast indices to int32 on host before transfer (column indices # always fit in int32; scipy may use int64 when nnz > 2^31). 
if isinstance(X, sp.spmatrix | sp.sparray): - if sp.issparse(X) and X.format == "csc": - csc = X if X.format == "csc" else X.tocsc() + if X.format == "csc": return cpsp.csc_matrix( ( - cp.asarray(csc.data), - cp.asarray(csc.indices.astype(np.int32, copy=False)), - cp.asarray(csc.indptr), + cp.asarray(X.data), + cp.asarray(X.indices.astype(np.int32, copy=False)), + cp.asarray(X.indptr), ), shape=(n_rows, n_cols), ) @@ -805,7 +804,7 @@ def _wilcoxon_with_reference( if tie_correct: variance = variance[:, None] * tie_corr_arr else: - variance = cp.broadcast_to(variance[:, None], (n_test, n_total_genes)).copy() + variance = variance[:, None] diff = rank_sums - expected[:, None] if use_continuity: From e7f6274c9781f7f774992cfc535cade0eca23f19 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Fri, 24 Apr 2026 01:24:38 +0200 Subject: [PATCH 20/21] starting refactor --- .gitignore | 1 + .../_cuda/wilcoxon/kernels_wilcoxon.cuh | 82 +- .../_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh | 403 +++- .../_cuda/wilcoxon/wilcoxon_common.cuh | 125 +- .../_cuda/wilcoxon/wilcoxon_ovo.cu | 1648 +---------------- .../_cuda/wilcoxon/wilcoxon_ovo_bindings.cuh | 264 +++ .../wilcoxon/wilcoxon_ovo_device_dense.cuh | 190 ++ .../wilcoxon/wilcoxon_ovo_device_sparse.cuh | 474 +++++ .../wilcoxon/wilcoxon_ovo_host_dense.cuh | 312 ++++ .../wilcoxon/wilcoxon_ovo_host_sparse.cuh | 837 +++++++++ .../_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh | 165 ++ .../_cuda/wilcoxon/wilcoxon_ovr.cu | 1126 +---------- .../_cuda/wilcoxon/wilcoxon_ovr_bindings.cuh | 137 ++ .../_cuda/wilcoxon/wilcoxon_ovr_dense.cuh | 317 ++++ .../_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh | 102 + .../_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh | 587 ++++++ .../tools/_rank_genes_groups/__init__.py | 105 +- .../tools/_rank_genes_groups/_core.py | 292 +-- .../tools/_rank_genes_groups/_utils.py | 160 +- .../tools/_rank_genes_groups/_wilcoxon.py | 281 ++- 20 files changed, 4481 insertions(+), 3127 deletions(-) create mode 100644 
src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_bindings.cuh create mode 100644 src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_dense.cuh create mode 100644 src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh create mode 100644 src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_dense.cuh create mode 100644 src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh create mode 100644 src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh create mode 100644 src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_bindings.cuh create mode 100644 src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_dense.cuh create mode 100644 src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh create mode 100644 src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh diff --git a/.gitignore b/.gitignore index c0e83438..b7a8a4f6 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ __pycache__/ /data/ test-data/ .vscode/ +.codex # Distribution / packaging /dist/ diff --git a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh index a6dc74f9..bcd70dc7 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh @@ -423,7 +423,8 @@ __global__ void ovr_cast_and_accumulate_dense_kernel( const InT* __restrict__ block_in, float* __restrict__ block_f32_out, const int* __restrict__ group_codes, double* __restrict__ group_sums, double* __restrict__ group_sq_sums, double* __restrict__ group_nnz, - int n_rows, int sb_cols, int n_groups) { + int n_rows, int sb_cols, int n_groups, bool compute_sq_sums = true, + bool compute_nnz = true) { int col = blockIdx.x; if (col >= sb_cols) return; @@ -434,8 +435,8 @@ __global__ void ovr_cast_and_accumulate_dense_kernel( for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { s_sum[g] = 0.0; - s_sq[g] = 0.0; - s_nnz[g] = 0.0; + if (compute_sq_sums) s_sq[g] = 0.0; + if 
(compute_nnz) s_nnz[g] = 0.0; } __syncthreads(); @@ -449,55 +450,19 @@ __global__ void ovr_cast_and_accumulate_dense_kernel( int g = group_codes[r]; if (g < n_groups) { atomicAdd(&s_sum[g], v); - atomicAdd(&s_sq[g], v * v); - if (v != 0.0) atomicAdd(&s_nnz[g], 1.0); + if (compute_sq_sums) atomicAdd(&s_sq[g], v * v); + if (compute_nnz && v != 0.0) atomicAdd(&s_nnz[g], 1.0); } } __syncthreads(); for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { group_sums[(size_t)g * sb_cols + col] = s_sum[g]; - group_sq_sums[(size_t)g * sb_cols + col] = s_sq[g]; - group_nnz[(size_t)g * sb_cols + col] = s_nnz[g]; - } -} - -/** - * One-shot cast-and-accumulate kernel for CSR-layout host streaming. - * - * The OVO CSR host path uploads the full CSR once; this kernel walks the - * uploaded data row-by-row, writes a float32 copy of the values, and - * accumulates per-group sum/sum-sq/nnz directly into a full-size - * (n_groups_stats, n_cols) output using global atomics. stats_codes[row] - * must be in [0, n_groups_stats) to contribute; other values (e.g. the - * sentinel for unselected cells) are skipped. - * - * Grid: (ceil(n_rows/tpb),), Block: (tpb,). 
- */ -template -__global__ void cast_and_accumulate_csr_kernel( - const InT* __restrict__ data_in, float* __restrict__ data_f32_out, - const int* __restrict__ indices, const int* __restrict__ indptr, - const int* __restrict__ stats_codes, double* __restrict__ group_sums, - double* __restrict__ group_sq_sums, double* __restrict__ group_nnz, - int n_rows, int n_cols, int n_groups_stats) { - int row = blockIdx.x * blockDim.x + threadIdx.x; - if (row >= n_rows) return; - int slot = stats_codes[row]; - int rs = indptr[row]; - int re = indptr[row + 1]; - bool accumulate = (slot >= 0 && slot < n_groups_stats); - for (int p = rs; p < re; p++) { - InT v_in = data_in[p]; - double v = (double)v_in; - data_f32_out[p] = (float)v_in; - if (accumulate) { - int c = indices[p]; - atomicAdd(&group_sums[(size_t)slot * n_cols + c], v); - atomicAdd(&group_sq_sums[(size_t)slot * n_cols + c], v * v); - if (v != 0.0) { - atomicAdd(&group_nnz[(size_t)slot * n_cols + c], 1.0); - } + if (compute_sq_sums) { + group_sq_sums[(size_t)g * sb_cols + col] = s_sq[g]; + } + if (compute_nnz) { + group_nnz[(size_t)g * sb_cols + col] = s_nnz[g]; } } } @@ -514,13 +479,14 @@ __global__ void cast_and_accumulate_csr_kernel( * Block-per-column layout (grid: (sb_cols,), block: (tpb,)). * Shared memory: 3 * n_groups doubles. 
*/ -template +template __global__ void ovr_cast_and_accumulate_sparse_kernel( const InT* __restrict__ data_in, float* __restrict__ data_f32_out, - const int* __restrict__ indices, const int* __restrict__ col_seg_offsets, + const IndexT* __restrict__ indices, const int* __restrict__ col_seg_offsets, const int* __restrict__ group_codes, double* __restrict__ group_sums, double* __restrict__ group_sq_sums, double* __restrict__ group_nnz, - int sb_cols, int n_groups) { + int sb_cols, int n_groups, bool compute_sq_sums = true, + bool compute_nnz = true) { int col = blockIdx.x; if (col >= sb_cols) return; @@ -534,8 +500,8 @@ __global__ void ovr_cast_and_accumulate_sparse_kernel( for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { s_sum[g] = 0.0; - s_sq[g] = 0.0; - s_nnz[g] = 0.0; + if (compute_sq_sums) s_sq[g] = 0.0; + if (compute_nnz) s_nnz[g] = 0.0; } __syncthreads(); @@ -543,19 +509,23 @@ __global__ void ovr_cast_and_accumulate_sparse_kernel( InT v_in = data_in[i]; double v = (double)v_in; data_f32_out[i] = (float)v_in; - int row = indices[i]; + int row = (int)indices[i]; int g = group_codes[row]; if (g < n_groups) { atomicAdd(&s_sum[g], v); - atomicAdd(&s_sq[g], v * v); - if (v != 0.0) atomicAdd(&s_nnz[g], 1.0); + if (compute_sq_sums) atomicAdd(&s_sq[g], v * v); + if (compute_nnz && v != 0.0) atomicAdd(&s_nnz[g], 1.0); } } __syncthreads(); for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { group_sums[(size_t)g * sb_cols + col] = s_sum[g]; - group_sq_sums[(size_t)g * sb_cols + col] = s_sq[g]; - group_nnz[(size_t)g * sb_cols + col] = s_nnz[g]; + if (compute_sq_sums) { + group_sq_sums[(size_t)g * sb_cols + col] = s_sq[g]; + } + if (compute_nnz) { + group_nnz[(size_t)g * sb_cols + col] = s_nnz[g]; + } } } diff --git a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh index ea7586b0..4bb38b18 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh +++ 
b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh @@ -157,7 +157,7 @@ __global__ void batched_rank_sums_presorted_kernel( const float* __restrict__ ref_sorted, const float* __restrict__ grp_sorted, const int* __restrict__ grp_offsets, double* __restrict__ rank_sums, double* __restrict__ tie_corr, int n_ref, int n_all_grp, int n_cols, - int n_groups, bool compute_tie_corr) { + int n_groups, bool compute_tie_corr, int skip_n_grp_le /*= 0*/) { int col = blockIdx.x; int grp = blockIdx.y; if (col >= n_cols || grp >= n_groups) return; @@ -166,6 +166,9 @@ __global__ void batched_rank_sums_presorted_kernel( int g_end = grp_offsets[grp + 1]; int n_grp = g_end - g_start; + // Size-gated dispatch (see ovo_fused_sort_rank_kernel for the contract). + if (n_grp <= skip_n_grp_le) return; + if (n_grp == 0) { if (threadIdx.x == 0) { rank_sums[grp * n_cols + col] = 0.0; @@ -268,7 +271,7 @@ __global__ void ovo_fused_sort_rank_kernel( double* __restrict__ rank_sums, // (n_groups, n_cols) row-major double* __restrict__ tie_corr, // (n_groups, n_cols) row-major int n_ref, int n_all_grp, int n_cols, int n_groups, bool compute_tie_corr, - int padded_grp_size) { + int padded_grp_size, int skip_n_grp_le /*= 0*/) { int col = blockIdx.x; int grp = blockIdx.y; if (col >= n_cols || grp >= n_groups) return; @@ -277,6 +280,11 @@ __global__ void ovo_fused_sort_rank_kernel( int g_end = grp_offsets[grp + 1]; int n_grp = g_end - g_start; + // Size-gated dispatch: when co-launched with the Tier 0 warp kernel we + // skip groups it's already handling. Each group owns its own + // rank_sums row, so the two kernels' writes never alias. 
+ if (n_grp <= skip_n_grp_le) return; + if (n_grp == 0) { if (threadIdx.x == 0) { rank_sums[grp * n_cols + col] = 0.0; @@ -390,3 +398,394 @@ __global__ void ovo_fused_sort_rank_kernel( compute_tie_correction_parallel(ref_col, n_ref, grp_smem, n_grp, warp_buf, &tie_corr[grp * n_cols + col]); } + +// ============================================================================ +// Tier 2 helper: tie contribution of the sorted reference alone. +// One block per column. The medium unsorted-rank kernel uses this as a base +// and only adds group-only/overlap deltas from the unsorted group values. +// ============================================================================ + +__global__ void ref_tie_sum_kernel(const float* __restrict__ ref_sorted, + double* __restrict__ ref_tie_sums, int n_ref, + int n_cols) { + int col = blockIdx.x; + if (col >= n_cols) return; + const float* ref_col = ref_sorted + (long long)col * n_ref; + + double local_tie = 0.0; + for (int i = threadIdx.x; i < n_ref; i += blockDim.x) { + if (i == 0 || ref_col[i] != ref_col[i - 1]) { + float v = ref_col[i]; + int lo = i + 1, hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int cnt = lo - i; + if (cnt > 1) { + double t = (double)cnt; + local_tie += t * t * t - t; + } + } + } + + __shared__ double warp_buf[32]; + double total = block_reduce_sum(local_tie, warp_buf); + if (threadIdx.x == 0) ref_tie_sums[col] = total; +} + +// ============================================================================ +// Tier 2 fused kernel: no-sort direct rank for medium groups. +// +// Avoids the smem bitonic sort for groups in (skip_n_grp_le, +// max_n_grp_le]. Ranks are computed from ref binary searches plus an +// in-group scan over unsorted shared values. Tie correction starts from +// ref_tie_sums[col] and adds only group-only / ref-overlap deltas. 
+// ============================================================================ + +__global__ void ovo_medium_unsorted_rank_kernel( + const float* __restrict__ ref_sorted, const float* __restrict__ grp_dense, + const int* __restrict__ grp_offsets, + const double* __restrict__ ref_tie_sums, double* __restrict__ rank_sums, + double* __restrict__ tie_corr, int n_ref, int n_all_grp, int n_cols, + int n_groups, bool compute_tie_corr, int skip_n_grp_le, int max_n_grp_le) { + int col = blockIdx.x; + int grp = blockIdx.y; + if (col >= n_cols || grp >= n_groups) return; + + int g_start = grp_offsets[grp]; + int g_end = grp_offsets[grp + 1]; + int n_grp = g_end - g_start; + if (n_grp <= skip_n_grp_le || n_grp > max_n_grp_le) return; + + extern __shared__ char smem_raw[]; + float* grp_smem = (float*)smem_raw; + double* warp_buf = (double*)(smem_raw + max_n_grp_le * sizeof(float)); + + const float* grp_col = grp_dense + (long long)col * n_all_grp + g_start; + for (int i = threadIdx.x; i < n_grp; i += blockDim.x) + grp_smem[i] = grp_col[i]; + __syncthreads(); + + const float* ref_col = ref_sorted + (long long)col * n_ref; + double local_sum = 0.0; + double local_tie_delta = 0.0; + + for (int i = threadIdx.x; i < n_grp; i += blockDim.x) { + float v = grp_smem[i]; + + int lo = 0, hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] < v) + lo = m + 1; + else + hi = m; + } + int n_lt_ref = lo; + + hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int n_eq_ref = lo - n_lt_ref; + + int n_lt_grp = 0; + int n_eq_grp = 0; + bool first_in_grp = true; + for (int j = 0; j < n_grp; ++j) { + float w = grp_smem[j]; + if (w < v) ++n_lt_grp; + if (w == v) { + ++n_eq_grp; + if (j < i) first_in_grp = false; + } + } + + local_sum += (double)(n_lt_ref + n_lt_grp) + + ((double)(n_eq_ref + n_eq_grp) + 1.0) / 2.0; + + if (compute_tie_corr && first_in_grp) { + double cg = (double)n_eq_grp; + double 
cr = (double)n_eq_ref; + double group_tie = (cg > 1.0) ? (cg * cg * cg - cg) : 0.0; + local_tie_delta += group_tie; + if (cr > 0.0) { + double combined = cr + cg; + double ref_tie = (cr > 1.0) ? (cr * cr * cr - cr) : 0.0; + local_tie_delta += combined * combined * combined - combined - + ref_tie - group_tie; + } + } + } + + double total = block_reduce_sum(local_sum, warp_buf); + if (threadIdx.x == 0) rank_sums[grp * n_cols + col] = total; + + if (!compute_tie_corr) return; + __syncthreads(); + + double tie_delta = block_reduce_sum(local_tie_delta, warp_buf); + if (threadIdx.x == 0) { + int n = n_ref + n_grp; + double dn = (double)n; + double denom = dn * dn * dn - dn; + double tie_sum = ref_tie_sums[col] + tie_delta; + tie_corr[grp * n_cols + col] = + (denom > 0.0) ? (1.0 - tie_sum / denom) : 1.0; + } +} + +// ============================================================================ +// Warp-scoped tie correction for Tier 0. +// +// Sorted values live in a 32-lane register (one per lane, with unused lanes +// carrying +INF). Walks unique values via lane-step differentials and +// counts ties across the sorted ref column via binary search. All the +// sync is __syncwarp — no smem, no __syncthreads. +// ============================================================================ + +__device__ __forceinline__ double tier0_tie_sum_warp(const float* ref_col, + int n_ref, float v_lane, + int n_grp, + unsigned int active_mask) { + int lane = threadIdx.x & 31; + double local_tie = 0.0; + + // Pass 1: for each unique value in ref_col, count occurrences in ref and + // in the sorted group (held in register v_lane across 32 lanes). + for (int base = 0; base < n_ref; base += 32) { + int i = base + lane; + bool in_ref_lane = (i < n_ref); + float v = in_ref_lane ? 
ref_col[i] : 0.0f; + bool is_first = in_ref_lane && ((i == 0) || (v != ref_col[i - 1])); + int cnt_ref = 0; + if (is_first) { + // Count in ref: upper_bound from i+1 + int lo = i + 1, hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] <= v) + lo = m + 1; + else + hi = m; + } + cnt_ref = lo - i; + } + + // Count in grp: look up how many lanes hold v_lane == v. All lanes + // execute the shuffle loop; only lanes owning a unique ref value use + // the result. + int cnt_grp = 0; +#pragma unroll + for (int lane_i = 0; lane_i < TIER0_GROUP_THRESHOLD; ++lane_i) { + float vi = __shfl_sync(0xffffffff, v_lane, lane_i); + if (is_first && lane_i < n_grp && vi == v) ++cnt_grp; + } + + if (is_first) { + int cnt = cnt_ref + cnt_grp; + if (cnt > 1) { + double t = (double)cnt; + local_tie += t * t * t - t; + } + } + } + + // Pass 2: unique values in grp that are absent from ref. + // Walk lanes 0..n_grp-1; for each lane whose v differs from prev lane's, + // binary-search ref for v. If not present, count consecutive matching + // lanes (tie block). + if (lane < n_grp) { + float v = v_lane; + float prev_lane_v = + __shfl_sync(active_mask, v_lane, (lane > 0) ? lane - 1 : 0); + float v_prev = + (lane > 0) ? prev_lane_v : __int_as_float(0xff800000); // -INF + bool first_in_grp = (lane == 0) || (v != v_prev); + bool in_ref = false; + if (first_in_grp) { + // Binary search in ref. + int lo = 0, hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] < v) + lo = m + 1; + else + hi = m; + } + in_ref = (lo < n_ref) && (ref_col[lo] == v); + } + + // Count how many lanes ≥ this lane hold the same v. Keep the shuffle + // uniform across active lanes even though only unique, ref-absent + // group values consume the count. + int cnt = 0; +#pragma unroll + for (int lane_i = 0; lane_i < TIER0_GROUP_THRESHOLD; ++lane_i) { + int src_lane = (lane_i < n_grp) ? 
lane_i : 0; + float vi = __shfl_sync(active_mask, v_lane, src_lane); + if (first_in_grp && !in_ref && lane_i >= lane && lane_i < n_grp && + vi == v) { + ++cnt; + } + } + if (first_in_grp && !in_ref && cnt > 1) { + double t = (double)cnt; + local_tie += t * t * t - t; + } + } + + // Warp reduce. +#pragma unroll + for (int off = 16; off > 0; off >>= 1) + local_tie += __shfl_down_sync(0xffffffff, local_tie, off); + return local_tie; // meaningful on lane 0. +} + +// ============================================================================ +// Tier 0 fused kernel: warp-per-(col, group) pair, 8 warps packed per block. +// +// Each warp independently: +// 1. Loads ≤ 32 group values into a single register (one per lane, +// padded with +INF). +// 2. Bitonic-sorts via __shfl_xor_sync — no smem, no __syncthreads. +// 3. Binary-searches into sorted ref for each lane's value and +// accumulates the rank-sum term. +// 4. Warp-shuffle reduces to lane 0 and writes rank_sums / tie_corr. +// +// 8 (col, group) pairs per block cuts block count 8× vs the block-per-pair +// Tier 1, and the lack of __syncthreads / smem sort lets each warp run +// independently at full throughput. +// +// Grid: (n_cols, ceil(n_groups / 8)), Block: 256. 
+// ============================================================================ + +__global__ void ovo_warp_sort_rank_kernel(const float* __restrict__ ref_sorted, + const float* __restrict__ grp_dense, + const int* __restrict__ grp_offsets, + double* __restrict__ rank_sums, + double* __restrict__ tie_corr, + int n_ref, int n_all_grp, int n_cols, + int n_groups, bool compute_tie_corr) { + constexpr int WARPS_PER_BLOCK = 8; + int warp_id = threadIdx.x >> 5; + int lane = threadIdx.x & 31; + + int col = blockIdx.x; + int grp = blockIdx.y * WARPS_PER_BLOCK + warp_id; + if (col >= n_cols || grp >= n_groups) return; + + int g_start = grp_offsets[grp]; + int g_end = grp_offsets[grp + 1]; + int n_grp = g_end - g_start; + + // This kernel only handles groups that fit in a single warp (one value + // per lane). Larger groups are delegated to Tier 1/3 in a co-launched + // kernel; since each group owns its own row in rank_sums/tie_corr, the + // two kernels interlace into the output without conflict. + if (n_grp > TIER0_GROUP_THRESHOLD) return; + + if (n_grp == 0) { + if (lane == 0) { + rank_sums[grp * n_cols + col] = 0.0; + if (compute_tie_corr) tie_corr[grp * n_cols + col] = 1.0; + } + return; + } + + // One value per lane, pad with +INF so sort pushes them to the end. + const float POS_INF = __int_as_float(0x7f800000); + const float* grp_col = grp_dense + (long long)col * n_all_grp + g_start; + float x = (lane < n_grp) ? grp_col[lane] : POS_INF; + unsigned int active_mask = __ballot_sync(0xffffffff, lane < n_grp); + + // Warp-shuffle bitonic sort (ascending) — 32 elements in registers. + for (int k = 1; k <= 16; k <<= 1) { + for (int j = k; j > 0; j >>= 1) { + float y = __shfl_xor_sync(0xffffffff, x, j); + bool asc = (((lane & (k << 1)) == 0)); + bool take_min = (((lane & j) == 0) == asc); + x = take_min ? fminf(x, y) : fmaxf(x, y); + } + } + + // After sort, x[lane] holds the lane-th smallest group value (lanes + // ≥ n_grp hold +INF). 
Binary-search each value into the sorted ref. + const float* ref_col = ref_sorted + (long long)col * n_ref; + double local_sum = 0.0; + + if (lane < n_grp) { + float v = x; + // Lower bound in ref. + int lo = 0, hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] < v) + lo = m + 1; + else + hi = m; + } + int n_lt_ref = lo; + // Upper bound in ref. + hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int n_eq_ref = lo - n_lt_ref; + + // In-group counts: in the sorted warp-register x, count lanes < this + // one that hold strictly less, and lanes with equal value. + int n_lt_grp = 0; + int n_eq_grp_offset = 0; // tied lanes strictly before this one + int n_eq_grp_after = 1; // count self +#pragma unroll + for (int lane_i = 0; lane_i < TIER0_GROUP_THRESHOLD; ++lane_i) { + if (lane_i >= n_grp) continue; + float vi = __shfl_sync(active_mask, v, lane_i); + if (lane_i < lane) { + if (vi < v) + ++n_lt_grp; + else if (vi == v) + ++n_eq_grp_offset; + } else if (lane_i > lane) { + if (vi == v) ++n_eq_grp_after; + } + } + int n_eq_grp_total = n_eq_grp_offset + n_eq_grp_after; + // Contribution: rank = n_lt_ref + n_lt_grp + (n_eq_ref + + // n_eq_grp_total + 1) / 2, but we sum per lane so each tie lane + // gets the same mid-rank. This matches the Tier 1 accumulation. + local_sum = (double)(n_lt_ref + n_lt_grp) + + ((double)(n_eq_ref + n_eq_grp_total) + 1.0) / 2.0; + } + + // Warp reduce. +#pragma unroll + for (int off = 16; off > 0; off >>= 1) + local_sum += __shfl_down_sync(0xffffffff, local_sum, off); + if (lane == 0) rank_sums[grp * n_cols + col] = local_sum; + + if (!compute_tie_corr) return; + + // Warp-scoped tie correction. + double tie_sum = tier0_tie_sum_warp(ref_col, n_ref, x, n_grp, active_mask); + if (lane == 0) { + int n = n_ref + n_grp; + double dn = (double)n; + double denom = dn * dn * dn - dn; + tie_corr[grp * n_cols + col] = + (denom > 0.0) ? 
(1.0 - tie_sum / denom) : 1.0; + } +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_common.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_common.cuh index 5ac8ade7..c0456b3f 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_common.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_common.cuh @@ -24,9 +24,26 @@ constexpr int END_BIT = 32; constexpr int UTIL_BLOCK_SIZE = 256; // Scratch slots for warp-level reduction (one slot per warp, 32 warps max). constexpr int WARP_REDUCE_BUF = 32; +// Max group size for the super-fast "warp-per-(col,group)" fused kernel +// (Tier 0). Each warp sorts and ranks one (col, group) pair entirely in +// registers via warp-shuffle bitonic sort — no smem sort buffer, no +// __syncthreads(). Blocks pack 8 warps so block launch overhead is +// amortised 8× across (col, group) work items. This path is the fast +// route for per-celltype perturbation-style workloads where most test +// groups have only a few dozen cells. +constexpr int TIER0_GROUP_THRESHOLD = 32; +// Medium-group cutoff for the unsorted direct-rank kernel. For perturbation +// workloads most groups sit below this range, where avoiding a full smem +// bitonic sort wins despite the O(n^2) in-group count. +constexpr int TIER2_GROUP_THRESHOLD = 512; // Max group size for the fused smem-sort rank kernel (Tier 1 fast path). // Beyond this, fall back to CUB segmented sort + binary-search rank kernel. constexpr int TIER1_GROUP_THRESHOLD = 2500; +// Per-stream dense slab budget (float32 items). Dynamic sub-batching sizes +// each group's column batch so that (n_g × eff_sb_cols) ≤ this. Bigger = +// fewer kernel launches; smaller = less per-stream memory. 64M items × 4B = +// 256 MB per stream dense slab + same for sorted copy ≈ 512 MB / stream. +constexpr size_t GROUP_DENSE_BUDGET_ITEMS = 64 * 1024 * 1024; // --------------------------------------------------------------------------- // RAII guard for cudaHostRegister. 
Unregisters on scope exit even when an @@ -37,7 +54,24 @@ struct HostRegisterGuard { HostRegisterGuard() = default; HostRegisterGuard(void* p, size_t bytes, unsigned int flags = 0) : ptr(p) { - if (ptr) cudaHostRegister(ptr, bytes, flags); + if (ptr) { + cudaError_t err = cudaHostRegister(ptr, bytes, flags); + if (err != cudaSuccess) { + // Already-registered memory is fine; anything else means the + // subsequent kernels would read garbage from an unmapped + // pointer, so surface the error immediately. + if (err == cudaErrorHostMemoryAlreadyRegistered) { + cudaGetLastError(); // clear sticky error flag + } else { + ptr = nullptr; // don't unregister in dtor + throw std::runtime_error( + std::string("cudaHostRegister failed (") + + std::to_string((size_t)bytes) + + " bytes, flags=" + std::to_string(flags) + + "): " + cudaGetErrorString(err)); + } + } + } } ~HostRegisterGuard() { if (ptr) cudaHostUnregister(ptr); @@ -89,6 +123,95 @@ __global__ void fill_linear_offsets_kernel(int* __restrict__ out, if (i <= n_segments) out[i] = i * stride; } +/** Fill per-row stats codes for a pack of K groups. + * Given pack_grp_offsets (size K+1, relative to pack start), write + * stats_codes[r] = base_slot + group_idx_of_row_r for r in [0, pack_n_rows). + * Binary search within the K+1 offsets. */ +__global__ void fill_pack_stats_codes_kernel( + const int* __restrict__ pack_grp_offsets, int* __restrict__ stats_codes, + int K, int base_slot) { + int r = blockIdx.x * blockDim.x + threadIdx.x; + int pack_n_rows = pack_grp_offsets[K]; + if (r >= pack_n_rows) return; + int lo = 0, hi = K; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (pack_grp_offsets[m + 1] <= r) + lo = m + 1; + else + hi = m; + } + stats_codes[r] = base_slot + lo; +} + +/** Rebase a slice of indptr: out[i] = indptr[col + i] - indptr[col]. + * Grid-strided: supports arbitrary `count` (no single-block thread limit). 
+ * Templated so that 64-bit global indptrs can produce 32-bit pack-local + * indptrs (per-pack nnz always fits in int32 thanks to the memory budget). + */ +template +__global__ void rebase_indptr_kernel(const IdxIn* __restrict__ indptr, + IdxOut* __restrict__ out, int col, + int count) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < count) out[i] = (IdxOut)(indptr[col + i] - indptr[col]); +} + +/** Fused gather + cast-to-float32 + stats accumulation, reading from mapped + * pinned host memory. Block-per-row; threads in the block cooperate on the + * row's nnz. Each nnz is read from host over PCIe exactly once — no + * intermediate native-dtype GPU buffer, no second GPU pass. + * + * h_data / h_indices: device-accessible pointers into mapped pinned host + * memory (cudaHostRegisterMapped). + * d_indptr_full: full-matrix indptr on device. + * d_row_ids: rows to gather (size n_target_rows). + * d_out_indptr: pre-computed compacted indptr, size n_target_rows+1 with + * out_indptr[i+1] - out_indptr[i] equal to the source row's + * nnz. + * + * Slot dispatch: + * d_stats_codes != nullptr → slot = d_stats_codes[r]; otherwise slot = + * fixed_slot (used for the Ref phase where every row maps to the same + * slot). slot ∉ [0, n_groups_stats) skips accumulation. 
+ */ +template +__global__ void csr_gather_cast_accumulate_mapped_kernel( + const InT* __restrict__ h_data, const IndexT* __restrict__ h_indices, + const IndptrT* __restrict__ d_indptr_full, + const int* __restrict__ d_row_ids, const int* __restrict__ d_out_indptr, + const int* __restrict__ d_stats_codes, int fixed_slot, + float* __restrict__ d_out_data_f32, int* __restrict__ d_out_indices, + double* __restrict__ group_sums, double* __restrict__ group_sq_sums, + double* __restrict__ group_nnz, int n_target_rows, int n_cols, + int n_groups_stats, bool compute_sq_sums, bool compute_nnz) { + int r = blockIdx.x; + if (r >= n_target_rows) return; + int src_row = d_row_ids[r]; + IndptrT rs = d_indptr_full[src_row]; + IndptrT re = d_indptr_full[src_row + 1]; + int row_nnz = (int)(re - rs); + int ds = d_out_indptr[r]; + int slot = (d_stats_codes != nullptr) ? d_stats_codes[r] : fixed_slot; + bool accumulate = (slot >= 0 && slot < n_groups_stats); + for (int i = threadIdx.x; i < row_nnz; i += blockDim.x) { + InT v_in = h_data[rs + i]; + int c = (int)h_indices[rs + i]; + double v = (double)v_in; + d_out_data_f32[ds + i] = (float)v_in; + d_out_indices[ds + i] = c; + if (accumulate) { + atomicAdd(&group_sums[(size_t)slot * n_cols + c], v); + if (compute_sq_sums) { + atomicAdd(&group_sq_sums[(size_t)slot * n_cols + c], v * v); + } + if (compute_nnz && v != 0.0) { + atomicAdd(&group_nnz[(size_t)slot * n_cols + c], 1.0); + } + } + } +} + /** Fill linear segment offsets [0, stride, 2*stride, ...] on device. * Runs on the supplied stream so it doesn't serialize multi-stream pipelines. 
*/ diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo.cu index 01722345..d2bb63be 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo.cu @@ -10,1648 +10,12 @@ using namespace nb::literals; -/** - * Build segment offsets for CUB segmented sort of group data within a - * sub-batch. offset[c * n_groups + g] = c * n_all_grp + grp_offsets[g]. - * One thread per entry. - */ -__global__ void build_seg_offsets_kernel( - const int* __restrict__ grp_offsets, // (n_groups + 1,) - int* __restrict__ out, // (sb_cols * n_groups + 1,) - int n_all_grp, int n_groups, int sb_cols) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - int total = sb_cols * n_groups + 1; - if (idx >= total) return; - if (idx == sb_cols * n_groups) { - out[idx] = sb_cols * n_all_grp; - } else { - int c = idx / n_groups; - int g = idx % n_groups; - out[idx] = c * n_all_grp + grp_offsets[g]; - } -} - -/** - * Extract specific rows from CSC into dense F-order, using a row lookup map. - * row_map[original_row] = output_row_index (or -1 to skip). - * One block per column, threads scatter matching nonzeros. - * Output must be pre-zeroed. 
- */ -__global__ void csc_extract_mapped_kernel(const float* __restrict__ data, - const int* __restrict__ indices, - const int* __restrict__ indptr, - const int* __restrict__ row_map, - float* __restrict__ out, int n_target, - int col_start) { - int col_local = blockIdx.x; - int col = col_start + col_local; - - int start = indptr[col]; - int end = indptr[col + 1]; - - for (int p = start + threadIdx.x; p < end; p += blockDim.x) { - int out_row = row_map[indices[p]]; - if (out_row >= 0) { - out[(long long)col_local * n_target + out_row] = data[p]; - } - } -} - -static size_t get_seg_sort_temp_bytes(int n_items, int n_segments) { - size_t bytes = 0; - auto* dk = reinterpret_cast(1); - auto* doff = reinterpret_cast(1); - cub::DeviceSegmentedRadixSort::SortKeys(nullptr, bytes, dk, dk, n_items, - n_segments, doff, doff + 1, 0, 32); - return bytes; -} - -/** - * Tier 1 dispatch: when the largest group fits in shared memory, a fused - * bitonic-sort + binary-search kernel handles the whole group per block. - * Otherwise we fall back to CUB segmented sort plus the pre-sorted rank - * kernel. This struct bundles the sizing knobs derived from the host-side - * group offsets so each streaming impl can drop a 15-line prep block. 
- */ -struct Tier1Config { - int max_grp_size = 0; - bool use_tier1 = false; - int padded_grp_size = 0; - int tier1_tpb = 0; - size_t tier1_smem = 0; -}; - -static Tier1Config make_tier1_config(const int* h_grp_offsets, int n_groups) { - Tier1Config c; - for (int g = 0; g < n_groups; g++) { - int sz = h_grp_offsets[g + 1] - h_grp_offsets[g]; - if (sz > c.max_grp_size) c.max_grp_size = sz; - } - c.use_tier1 = (c.max_grp_size <= TIER1_GROUP_THRESHOLD); - if (c.use_tier1) { - c.padded_grp_size = 1; - while (c.padded_grp_size < c.max_grp_size) c.padded_grp_size <<= 1; - c.tier1_tpb = std::min(c.padded_grp_size, MAX_THREADS_PER_BLOCK); - c.tier1_smem = (size_t)c.padded_grp_size * sizeof(float) + - WARP_REDUCE_BUF * sizeof(double); - } - return c; -} - -/** - * Streaming OVO pipeline. - * - * Takes pre-sorted reference (float32 F-order), unsorted group data (float32 - * F-order with group offsets), and produces rank_sums + tie_corr. - * - * For each sub-batch of columns: - * 1. CUB segmented sort-keys of group data (one segment per group per col) - * 2. 
batched_rank_sums_presorted_kernel (binary search in sorted ref) - */ -static void ovo_streaming_impl(const float* ref_sorted, const float* grp_data, - const int* grp_offsets, double* rank_sums, - double* tie_corr, int n_ref, int n_all_grp, - int n_cols, int n_groups, bool compute_tie_corr, - int sub_batch_cols) { - if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; - - int n_streams = N_STREAMS; - if (n_cols < n_streams * sub_batch_cols) - n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; - - size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols; - int max_n_seg = n_groups * sub_batch_cols; - size_t cub_temp_bytes = 0; - { - auto* fk = reinterpret_cast(1); - auto* doff = reinterpret_cast(1); - cub::DeviceSegmentedRadixSort::SortKeys(nullptr, cub_temp_bytes, fk, fk, - (int)sub_grp_items, max_n_seg, - doff, doff + 1, 0, 32); - } - - std::vector streams(n_streams); - for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); - - // Allocate per-stream buffers via RMM pool - RmmPool pool; - struct StreamBuf { - float* grp_sorted; - int* seg_offsets; - uint8_t* cub_temp; - double* sub_rank_sums; - double* sub_tie_corr; - }; - std::vector bufs(n_streams); - for (int s = 0; s < n_streams; s++) { - bufs[s].grp_sorted = pool.alloc(sub_grp_items); - bufs[s].seg_offsets = pool.alloc(max_n_seg + 1); - bufs[s].cub_temp = pool.alloc(cub_temp_bytes); - bufs[s].sub_rank_sums = - pool.alloc((size_t)n_groups * sub_batch_cols); - bufs[s].sub_tie_corr = - pool.alloc((size_t)n_groups * sub_batch_cols); - } - - // Compute max individual group size for accurate thread count - std::vector h_off(n_groups + 1); - cudaMemcpy(h_off.data(), grp_offsets, (n_groups + 1) * sizeof(int), - cudaMemcpyDeviceToHost); - int max_grp_size = 0; - for (int g = 0; g < n_groups; g++) { - int sz = h_off[g + 1] - h_off[g]; - if (sz > max_grp_size) max_grp_size = sz; - } - int tpb_rank = - round_up_to_warp(std::min(max_grp_size, MAX_THREADS_PER_BLOCK)); - - int col = 0; - int 
batch_idx = 0; - while (col < n_cols) { - int sb_cols = std::min(sub_batch_cols, n_cols - col); - int sb_n_seg = n_groups * sb_cols; - int sb_grp_items = n_all_grp * sb_cols; - int s = batch_idx % n_streams; - auto stream = streams[s]; - auto& buf = bufs[s]; - - // Build segment offsets on device - { - int total = sb_n_seg + 1; - int blk = (total + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; - build_seg_offsets_kernel<<>>( - grp_offsets, buf.seg_offsets, n_all_grp, n_groups, sb_cols); - CUDA_CHECK_LAST_ERROR(build_seg_offsets_kernel); - } - - // Sort group data for this sub-batch - const float* grp_in = grp_data + (long long)col * n_all_grp; - size_t temp = cub_temp_bytes; - cub::DeviceSegmentedRadixSort::SortKeys( - buf.cub_temp, temp, grp_in, buf.grp_sorted, sb_grp_items, sb_n_seg, - buf.seg_offsets, buf.seg_offsets + 1, 0, 32, stream); - - // Rank sums: binary search sorted ref for each group element - const float* ref_sub = ref_sorted + (long long)col * n_ref; - dim3 grid(sb_cols, n_groups); - batched_rank_sums_presorted_kernel<<>>( - ref_sub, buf.grp_sorted, grp_offsets, buf.sub_rank_sums, - buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, - compute_tie_corr); - CUDA_CHECK_LAST_ERROR(batched_rank_sums_presorted_kernel); - - // Scatter sub-batch results to global output - cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), - buf.sub_rank_sums, sb_cols * sizeof(double), - sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToDevice, stream); - if (compute_tie_corr) { - cudaMemcpy2DAsync(tie_corr + col, n_cols * sizeof(double), - buf.sub_tie_corr, sb_cols * sizeof(double), - sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToDevice, stream); - } - - col += sb_cols; - batch_idx++; - } - - for (int s = 0; s < n_streams; s++) { - cudaError_t err = cudaStreamSynchronize(streams[s]); - if (err != cudaSuccess) - throw std::runtime_error( - std::string("CUDA error in wilcoxon streaming: ") + - cudaGetErrorString(err)); - } - for (int s = 0; s < 
n_streams; s++) cudaStreamDestroy(streams[s]); -} - -/** - * CSR-direct OVO streaming pipeline. - * - * One C++ call does everything: extract rows from CSR → sort → rank. - * Per sub-batch of columns: - * 1. Extract ref rows → dense f32 → CUB sort - * 2. Extract grp rows → dense f32 → CUB sort (segmented by group) - * 3. Binary search rank sums - * Only ~(n_ref + n_all_grp) × sub_batch × 4B on GPU at a time. - */ -static void ovo_streaming_csr_impl( - const float* csr_data, const int* csr_indices, const int* csr_indptr, - const int* ref_row_ids, const int* grp_row_ids, const int* grp_offsets, - double* rank_sums, double* tie_corr, int n_ref, int n_all_grp, int n_cols, - int n_groups, bool compute_tie_corr, int sub_batch_cols) { - if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; - - // ---- Tier dispatch: read group offsets to determine max group size ---- - std::vector h_offsets(n_groups + 1); - cudaMemcpy(h_offsets.data(), grp_offsets, (n_groups + 1) * sizeof(int), - cudaMemcpyDeviceToHost); - auto t1 = make_tier1_config(h_offsets.data(), n_groups); - int max_grp_size = t1.max_grp_size; - bool use_tier1 = t1.use_tier1; - int padded_grp_size = t1.padded_grp_size; - int tier1_tpb = t1.tier1_tpb; - size_t tier1_smem = t1.tier1_smem; - - int n_streams = N_STREAMS; - if (n_cols < n_streams * sub_batch_cols) - n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; - - size_t sub_ref_items = (size_t)n_ref * sub_batch_cols; - size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols; - - // CUB temp for ref sort (always needed) + grp sort (Tier 3 only) - size_t cub_ref_bytes = 0; - { - auto* fk = reinterpret_cast(1); - auto* doff = reinterpret_cast(1); - cub::DeviceSegmentedRadixSort::SortKeys( - nullptr, cub_ref_bytes, fk, fk, (int)sub_ref_items, sub_batch_cols, - doff, doff + 1, BEGIN_BIT, END_BIT); - } - size_t cub_temp_bytes = cub_ref_bytes; - - if (!use_tier1) { - size_t cub_grp_bytes = 0; - int max_grp_seg = n_groups * sub_batch_cols; - auto* fk = 
reinterpret_cast(1); - auto* doff = reinterpret_cast(1); - cub::DeviceSegmentedRadixSort::SortKeys( - nullptr, cub_grp_bytes, fk, fk, (int)sub_grp_items, max_grp_seg, - doff, doff + 1, BEGIN_BIT, END_BIT); - cub_temp_bytes = std::max(cub_ref_bytes, cub_grp_bytes); - } - - std::vector streams(n_streams); - for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); - - // Allocate per-stream buffers via RMM pool - RmmPool pool; - struct StreamBuf { - float* ref_dense; - float* ref_sorted; - float* grp_dense; - float* grp_sorted; - int* ref_seg_offsets; - int* grp_seg_offsets; - uint8_t* cub_temp; - double* sub_rank_sums; - double* sub_tie_corr; - }; - std::vector bufs(n_streams); - for (int s = 0; s < n_streams; s++) { - bufs[s].ref_dense = pool.alloc(sub_ref_items); - bufs[s].ref_sorted = pool.alloc(sub_ref_items); - bufs[s].grp_dense = pool.alloc(sub_grp_items); - bufs[s].ref_seg_offsets = pool.alloc(sub_batch_cols + 1); - bufs[s].cub_temp = pool.alloc(cub_temp_bytes); - bufs[s].sub_rank_sums = - pool.alloc((size_t)n_groups * sub_batch_cols); - bufs[s].sub_tie_corr = - pool.alloc((size_t)n_groups * sub_batch_cols); - if (!use_tier1) { - bufs[s].grp_sorted = pool.alloc(sub_grp_items); - int max_seg = n_groups * sub_batch_cols; - bufs[s].grp_seg_offsets = pool.alloc(max_seg + 1); - } else { - bufs[s].grp_sorted = nullptr; - bufs[s].grp_seg_offsets = nullptr; - } - } - - int tpb_extract = round_up_to_warp(std::max(n_ref, n_all_grp)); - int tpb_rank = - round_up_to_warp(std::min(max_grp_size, MAX_THREADS_PER_BLOCK)); - - int col = 0; - int batch_idx = 0; - while (col < n_cols) { - int sb_cols = std::min(sub_batch_cols, n_cols - col); - int sb_ref_items_actual = n_ref * sb_cols; - int sb_grp_items_actual = n_all_grp * sb_cols; - int s = batch_idx % n_streams; - auto stream = streams[s]; - auto& buf = bufs[s]; - - // ---- Extract + sort ref (always CUB) ---- - cudaMemsetAsync(buf.ref_dense, 0, sb_ref_items_actual * sizeof(float), - stream); - { - int blk = (n_ref 
+ tpb_extract - 1) / tpb_extract; - csr_extract_dense_kernel<<>>( - csr_data, csr_indices, csr_indptr, ref_row_ids, buf.ref_dense, - n_ref, col, col + sb_cols); - CUDA_CHECK_LAST_ERROR(csr_extract_dense_kernel); - } - upload_linear_offsets(buf.ref_seg_offsets, sb_cols, n_ref, stream); - { - size_t temp = cub_temp_bytes; - cub::DeviceSegmentedRadixSort::SortKeys( - buf.cub_temp, temp, buf.ref_dense, buf.ref_sorted, - sb_ref_items_actual, sb_cols, buf.ref_seg_offsets, - buf.ref_seg_offsets + 1, BEGIN_BIT, END_BIT, stream); - } - - // ---- Extract grp rows ---- - cudaMemsetAsync(buf.grp_dense, 0, sb_grp_items_actual * sizeof(float), - stream); - { - int blk = (n_all_grp + tpb_extract - 1) / tpb_extract; - csr_extract_dense_kernel<<>>( - csr_data, csr_indices, csr_indptr, grp_row_ids, buf.grp_dense, - n_all_grp, col, col + sb_cols); - CUDA_CHECK_LAST_ERROR(csr_extract_dense_kernel); - } - - if (use_tier1) { - // ---- Tier 1: fused smem sort + binary search ---- - dim3 grid(sb_cols, n_groups); - ovo_fused_sort_rank_kernel<<>>( - buf.ref_sorted, buf.grp_dense, grp_offsets, buf.sub_rank_sums, - buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, - compute_tie_corr, padded_grp_size); - CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); - } else { - // ---- Tier 3: CUB segmented sort + binary search ---- - int sb_grp_seg = n_groups * sb_cols; - { - int total = sb_grp_seg + 1; - int blk = (total + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; - build_seg_offsets_kernel<<>>( - grp_offsets, buf.grp_seg_offsets, n_all_grp, n_groups, - sb_cols); - CUDA_CHECK_LAST_ERROR(build_seg_offsets_kernel); - } - { - size_t temp = cub_temp_bytes; - cub::DeviceSegmentedRadixSort::SortKeys( - buf.cub_temp, temp, buf.grp_dense, buf.grp_sorted, - sb_grp_items_actual, sb_grp_seg, buf.grp_seg_offsets, - buf.grp_seg_offsets + 1, BEGIN_BIT, END_BIT, stream); - } - { - dim3 grid(sb_cols, n_groups); - batched_rank_sums_presorted_kernel<<>>( - buf.ref_sorted, buf.grp_sorted, grp_offsets, - 
buf.sub_rank_sums, buf.sub_tie_corr, n_ref, n_all_grp, - sb_cols, n_groups, compute_tie_corr); - CUDA_CHECK_LAST_ERROR(batched_rank_sums_presorted_kernel); - } - } - - // ---- Scatter to global output ---- - cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), - buf.sub_rank_sums, sb_cols * sizeof(double), - sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToDevice, stream); - if (compute_tie_corr) { - cudaMemcpy2DAsync(tie_corr + col, n_cols * sizeof(double), - buf.sub_tie_corr, sb_cols * sizeof(double), - sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToDevice, stream); - } - - col += sb_cols; - batch_idx++; - } - - for (int s = 0; s < n_streams; s++) { - cudaError_t err = cudaStreamSynchronize(streams[s]); - if (err != cudaSuccess) - throw std::runtime_error( - std::string("CUDA error in wilcoxon streaming: ") + - cudaGetErrorString(err)); - } - for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); -} - -/** - * CSC-direct OVO streaming pipeline. - * - * Like the CSR variant but extracts rows via a row-lookup map, avoiding - * CSC→CSR conversion. row_map_ref[row] = output index in ref block (-1 if - * not a ref row); likewise for row_map_grp. 
- */ -static void ovo_streaming_csc_impl( - const float* csc_data, const int* csc_indices, const int* csc_indptr, - const int* ref_row_map, const int* grp_row_map, const int* grp_offsets, - double* rank_sums, double* tie_corr, int n_ref, int n_all_grp, int n_cols, - int n_groups, bool compute_tie_corr, int sub_batch_cols) { - if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; - - // ---- Tier dispatch ---- - std::vector h_offsets(n_groups + 1); - cudaMemcpy(h_offsets.data(), grp_offsets, (n_groups + 1) * sizeof(int), - cudaMemcpyDeviceToHost); - auto t1 = make_tier1_config(h_offsets.data(), n_groups); - int max_grp_size = t1.max_grp_size; - bool use_tier1 = t1.use_tier1; - int padded_grp_size = t1.padded_grp_size; - int tier1_tpb = t1.tier1_tpb; - size_t tier1_smem = t1.tier1_smem; - - int n_streams = N_STREAMS; - if (n_cols < n_streams * sub_batch_cols) - n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; - - size_t sub_ref_items = (size_t)n_ref * sub_batch_cols; - size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols; - - // CUB temp - size_t cub_ref_bytes = 0; - { - auto* fk = reinterpret_cast(1); - auto* doff = reinterpret_cast(1); - cub::DeviceSegmentedRadixSort::SortKeys( - nullptr, cub_ref_bytes, fk, fk, (int)sub_ref_items, sub_batch_cols, - doff, doff + 1, BEGIN_BIT, END_BIT); - } - size_t cub_temp_bytes = cub_ref_bytes; - if (!use_tier1) { - size_t cub_grp_bytes = 0; - int max_grp_seg = n_groups * sub_batch_cols; - auto* fk = reinterpret_cast(1); - auto* doff = reinterpret_cast(1); - cub::DeviceSegmentedRadixSort::SortKeys( - nullptr, cub_grp_bytes, fk, fk, (int)sub_grp_items, max_grp_seg, - doff, doff + 1, BEGIN_BIT, END_BIT); - cub_temp_bytes = std::max(cub_ref_bytes, cub_grp_bytes); - } - - std::vector streams(n_streams); - for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); - - RmmPool pool; - struct StreamBuf { - float* ref_dense; - float* ref_sorted; - float* grp_dense; - float* grp_sorted; - int* ref_seg_offsets; - 
int* grp_seg_offsets; - uint8_t* cub_temp; - double* sub_rank_sums; - double* sub_tie_corr; - }; - std::vector bufs(n_streams); - for (int s = 0; s < n_streams; s++) { - bufs[s].ref_dense = pool.alloc(sub_ref_items); - bufs[s].ref_sorted = pool.alloc(sub_ref_items); - bufs[s].grp_dense = pool.alloc(sub_grp_items); - bufs[s].ref_seg_offsets = pool.alloc(sub_batch_cols + 1); - bufs[s].cub_temp = pool.alloc(cub_temp_bytes); - bufs[s].sub_rank_sums = - pool.alloc((size_t)n_groups * sub_batch_cols); - bufs[s].sub_tie_corr = - pool.alloc((size_t)n_groups * sub_batch_cols); - if (!use_tier1) { - bufs[s].grp_sorted = pool.alloc(sub_grp_items); - int max_grp_seg = n_groups * sub_batch_cols; - bufs[s].grp_seg_offsets = pool.alloc(max_grp_seg + 1); - } else { - bufs[s].grp_sorted = nullptr; - bufs[s].grp_seg_offsets = nullptr; - } - } - - int tpb_rank = - round_up_to_warp(std::min(max_grp_size, MAX_THREADS_PER_BLOCK)); - - int col = 0; - int batch_idx = 0; - while (col < n_cols) { - int sb_cols = std::min(sub_batch_cols, n_cols - col); - int sb_ref_items_actual = n_ref * sb_cols; - int sb_grp_items_actual = n_all_grp * sb_cols; - int s = batch_idx % n_streams; - auto stream = streams[s]; - auto& buf = bufs[s]; - - // ---- Extract ref from CSC via row_map, then sort ---- - cudaMemsetAsync(buf.ref_dense, 0, sb_ref_items_actual * sizeof(float), - stream); - csc_extract_mapped_kernel<<>>( - csc_data, csc_indices, csc_indptr, ref_row_map, buf.ref_dense, - n_ref, col); - CUDA_CHECK_LAST_ERROR(csc_extract_mapped_kernel); - upload_linear_offsets(buf.ref_seg_offsets, sb_cols, n_ref, stream); - { - size_t temp = cub_temp_bytes; - cub::DeviceSegmentedRadixSort::SortKeys( - buf.cub_temp, temp, buf.ref_dense, buf.ref_sorted, - sb_ref_items_actual, sb_cols, buf.ref_seg_offsets, - buf.ref_seg_offsets + 1, BEGIN_BIT, END_BIT, stream); - } - - // ---- Extract grp from CSC via row_map ---- - cudaMemsetAsync(buf.grp_dense, 0, sb_grp_items_actual * sizeof(float), - stream); - 
csc_extract_mapped_kernel<<>>( - csc_data, csc_indices, csc_indptr, grp_row_map, buf.grp_dense, - n_all_grp, col); - CUDA_CHECK_LAST_ERROR(csc_extract_mapped_kernel); - - if (use_tier1) { - dim3 grid(sb_cols, n_groups); - ovo_fused_sort_rank_kernel<<>>( - buf.ref_sorted, buf.grp_dense, grp_offsets, buf.sub_rank_sums, - buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, - compute_tie_corr, padded_grp_size); - CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); - } else { - int sb_grp_seg = n_groups * sb_cols; - { - int total = sb_grp_seg + 1; - int blk = (total + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; - build_seg_offsets_kernel<<>>( - grp_offsets, buf.grp_seg_offsets, n_all_grp, n_groups, - sb_cols); - CUDA_CHECK_LAST_ERROR(build_seg_offsets_kernel); - } - { - size_t temp = cub_temp_bytes; - cub::DeviceSegmentedRadixSort::SortKeys( - buf.cub_temp, temp, buf.grp_dense, buf.grp_sorted, - sb_grp_items_actual, sb_grp_seg, buf.grp_seg_offsets, - buf.grp_seg_offsets + 1, BEGIN_BIT, END_BIT, stream); - } - { - dim3 grid(sb_cols, n_groups); - batched_rank_sums_presorted_kernel<<>>( - buf.ref_sorted, buf.grp_sorted, grp_offsets, - buf.sub_rank_sums, buf.sub_tie_corr, n_ref, n_all_grp, - sb_cols, n_groups, compute_tie_corr); - CUDA_CHECK_LAST_ERROR(batched_rank_sums_presorted_kernel); - } - } - - // ---- Scatter to global output ---- - cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), - buf.sub_rank_sums, sb_cols * sizeof(double), - sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToDevice, stream); - if (compute_tie_corr) { - cudaMemcpy2DAsync(tie_corr + col, n_cols * sizeof(double), - buf.sub_tie_corr, sb_cols * sizeof(double), - sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToDevice, stream); - } - - col += sb_cols; - batch_idx++; - } - - for (int s = 0; s < n_streams; s++) { - cudaError_t err = cudaStreamSynchronize(streams[s]); - if (err != cudaSuccess) - throw std::runtime_error( - std::string("CUDA error in wilcoxon streaming: ") + - 
cudaGetErrorString(err)); - } - for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); -} - -/** - * Host-streaming CSC OVO pipeline. - * - * CSC arrays live on host. Only the sparse data for each sub-batch of - * columns is transferred to GPU. Row maps + group offsets are uploaded once. - * Results are written back to host per sub-batch. - */ -template -static void ovo_streaming_csc_host_impl( - const InT* h_data, const int* h_indices, const IndptrT* h_indptr, - const int* h_ref_row_map, const int* h_grp_row_map, - const int* h_grp_offsets, const int* h_stats_codes, double* d_rank_sums, - double* d_tie_corr, double* d_group_sums, double* d_group_sq_sums, - double* d_group_nnz, int n_ref, int n_all_grp, int n_rows, int n_cols, - int n_groups, int n_groups_stats, bool compute_tie_corr, - int sub_batch_cols) { - if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; - - // ---- Tier dispatch from host offsets ---- - auto t1 = make_tier1_config(h_grp_offsets, n_groups); - int max_grp_size = t1.max_grp_size; - bool use_tier1 = t1.use_tier1; - int padded_grp_size = t1.padded_grp_size; - int tier1_tpb = t1.tier1_tpb; - size_t tier1_smem = t1.tier1_smem; - - int n_streams = N_STREAMS; - if (n_cols < n_streams * sub_batch_cols) - n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; - - size_t sub_ref_items = (size_t)n_ref * sub_batch_cols; - size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols; - - // CUB temp - size_t cub_ref_bytes = 0; - { - auto* fk = reinterpret_cast(1); - auto* doff = reinterpret_cast(1); - cub::DeviceSegmentedRadixSort::SortKeys( - nullptr, cub_ref_bytes, fk, fk, (int)sub_ref_items, sub_batch_cols, - doff, doff + 1, BEGIN_BIT, END_BIT); - } - size_t cub_temp_bytes = cub_ref_bytes; - if (!use_tier1) { - size_t cub_grp_bytes = 0; - int max_grp_seg = n_groups * sub_batch_cols; - auto* fk = reinterpret_cast(1); - auto* doff = reinterpret_cast(1); - cub::DeviceSegmentedRadixSort::SortKeys( - nullptr, cub_grp_bytes, fk, fk, 
(int)sub_grp_items, max_grp_seg, - doff, doff + 1, BEGIN_BIT, END_BIT); - cub_temp_bytes = std::max(cub_ref_bytes, cub_grp_bytes); - } - - // Max nnz across any sub-batch for sparse transfer buffer sizing - size_t max_nnz = 0; - for (int c = 0; c < n_cols; c += sub_batch_cols) { - int sb = std::min(sub_batch_cols, n_cols - c); - size_t nnz = (size_t)(h_indptr[c + sb] - h_indptr[c]); - if (nnz > max_nnz) max_nnz = nnz; - } - - std::vector streams(n_streams); - for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); - - RmmPool pool; - - // GPU copies of row maps + group offsets + stats codes (uploaded once) - int* d_ref_row_map = pool.alloc(n_rows); - int* d_grp_row_map = pool.alloc(n_rows); - int* d_grp_offsets = pool.alloc(n_groups + 1); - int* d_stats_codes = pool.alloc(n_rows); - cudaMemcpy(d_ref_row_map, h_ref_row_map, n_rows * sizeof(int), - cudaMemcpyHostToDevice); - cudaMemcpy(d_grp_row_map, h_grp_row_map, n_rows * sizeof(int), - cudaMemcpyHostToDevice); - cudaMemcpy(d_grp_offsets, h_grp_offsets, (n_groups + 1) * sizeof(int), - cudaMemcpyHostToDevice); - cudaMemcpy(d_stats_codes, h_stats_codes, n_rows * sizeof(int), - cudaMemcpyHostToDevice); - - struct StreamBuf { - InT* d_sparse_data_orig; - float* d_sparse_data_f32; - int* d_sparse_indices; - int* d_indptr; - float* ref_dense; - float* ref_sorted; - float* grp_dense; - float* grp_sorted; - int* ref_seg_offsets; - int* grp_seg_offsets; - uint8_t* cub_temp; - double* d_rank_sums; - double* d_tie_corr; - double* d_group_sums; - double* d_group_sq_sums; - double* d_group_nnz; - }; - std::vector bufs(n_streams); - for (int s = 0; s < n_streams; s++) { - bufs[s].d_sparse_data_orig = pool.alloc(max_nnz); - bufs[s].d_sparse_data_f32 = pool.alloc(max_nnz); - bufs[s].d_sparse_indices = pool.alloc(max_nnz); - bufs[s].d_indptr = pool.alloc(sub_batch_cols + 1); - bufs[s].ref_dense = pool.alloc(sub_ref_items); - bufs[s].ref_sorted = pool.alloc(sub_ref_items); - bufs[s].grp_dense = pool.alloc(sub_grp_items); 
- bufs[s].ref_seg_offsets = pool.alloc(sub_batch_cols + 1); - bufs[s].cub_temp = pool.alloc(cub_temp_bytes); - bufs[s].d_rank_sums = - pool.alloc((size_t)n_groups * sub_batch_cols); - bufs[s].d_tie_corr = - pool.alloc((size_t)n_groups * sub_batch_cols); - bufs[s].d_group_sums = - pool.alloc((size_t)n_groups_stats * sub_batch_cols); - bufs[s].d_group_sq_sums = - pool.alloc((size_t)n_groups_stats * sub_batch_cols); - bufs[s].d_group_nnz = - pool.alloc((size_t)n_groups_stats * sub_batch_cols); - if (!use_tier1) { - bufs[s].grp_sorted = pool.alloc(sub_grp_items); - int max_grp_seg = n_groups * sub_batch_cols; - bufs[s].grp_seg_offsets = pool.alloc(max_grp_seg + 1); - } else { - bufs[s].grp_sorted = nullptr; - bufs[s].grp_seg_offsets = nullptr; - } - } - - int tpb_rank = - round_up_to_warp(std::min(max_grp_size, MAX_THREADS_PER_BLOCK)); - size_t smem_cast = (size_t)(3 * n_groups_stats) * sizeof(double); - - // Pin only the sparse input arrays; outputs live on the device. - size_t total_nnz = (size_t)h_indptr[n_cols]; - HostRegisterGuard _pin_data(const_cast(h_data), - total_nnz * sizeof(InT)); - HostRegisterGuard _pin_indices(const_cast(h_indices), - total_nnz * sizeof(int)); - - int col = 0; - int batch_idx = 0; - while (col < n_cols) { - int sb_cols = std::min(sub_batch_cols, n_cols - col); - int sb_ref_actual = n_ref * sb_cols; - int sb_grp_actual = n_all_grp * sb_cols; - int s = batch_idx % n_streams; - auto stream = streams[s]; - auto& buf = bufs[s]; - - // ---- H2D: sparse data for this column range (native dtype) ---- - IndptrT ptr_start = h_indptr[col]; - IndptrT ptr_end = h_indptr[col + sb_cols]; - size_t nnz = (size_t)(ptr_end - ptr_start); - cudaMemcpyAsync(buf.d_sparse_data_orig, h_data + ptr_start, - nnz * sizeof(InT), cudaMemcpyHostToDevice, stream); - cudaMemcpyAsync(buf.d_sparse_indices, h_indices + ptr_start, - nnz * sizeof(int), cudaMemcpyHostToDevice, stream); - { - std::vector h_adj(sb_cols + 1); - for (int i = 0; i <= sb_cols; i++) - h_adj[i] = 
(int)(h_indptr[col + i] - ptr_start); - cudaMemcpy(buf.d_indptr, h_adj.data(), (sb_cols + 1) * sizeof(int), - cudaMemcpyHostToDevice); - } - - // ---- Cast to float32 for sort + accumulate stats in float64 ---- - ovr_cast_and_accumulate_sparse_kernel - <<>>( - buf.d_sparse_data_orig, buf.d_sparse_data_f32, - buf.d_sparse_indices, buf.d_indptr, d_stats_codes, - buf.d_group_sums, buf.d_group_sq_sums, buf.d_group_nnz, sb_cols, - n_groups_stats); - CUDA_CHECK_LAST_ERROR(ovr_cast_and_accumulate_sparse_kernel); - - // ---- Extract ref from CSC via row_map, sort ---- - cudaMemsetAsync(buf.ref_dense, 0, sb_ref_actual * sizeof(float), - stream); - csc_extract_mapped_kernel<<>>( - buf.d_sparse_data_f32, buf.d_sparse_indices, buf.d_indptr, - d_ref_row_map, buf.ref_dense, n_ref, 0); - CUDA_CHECK_LAST_ERROR(csc_extract_mapped_kernel); - upload_linear_offsets(buf.ref_seg_offsets, sb_cols, n_ref, stream); - { - size_t temp = cub_temp_bytes; - cub::DeviceSegmentedRadixSort::SortKeys( - buf.cub_temp, temp, buf.ref_dense, buf.ref_sorted, - sb_ref_actual, sb_cols, buf.ref_seg_offsets, - buf.ref_seg_offsets + 1, BEGIN_BIT, END_BIT, stream); - } - - // ---- Extract grp from CSC via row_map ---- - cudaMemsetAsync(buf.grp_dense, 0, sb_grp_actual * sizeof(float), - stream); - csc_extract_mapped_kernel<<>>( - buf.d_sparse_data_f32, buf.d_sparse_indices, buf.d_indptr, - d_grp_row_map, buf.grp_dense, n_all_grp, 0); - CUDA_CHECK_LAST_ERROR(csc_extract_mapped_kernel); - - // ---- Tier dispatch: sort grp + rank ---- - if (use_tier1) { - dim3 grid(sb_cols, n_groups); - ovo_fused_sort_rank_kernel<<>>( - buf.ref_sorted, buf.grp_dense, d_grp_offsets, buf.d_rank_sums, - buf.d_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, - compute_tie_corr, padded_grp_size); - CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); - } else { - int sb_grp_seg = n_groups * sb_cols; - { - int total = sb_grp_seg + 1; - int blk = (total + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; - build_seg_offsets_kernel<<>>( - 
d_grp_offsets, buf.grp_seg_offsets, n_all_grp, n_groups, - sb_cols); - CUDA_CHECK_LAST_ERROR(build_seg_offsets_kernel); - } - { - size_t temp = cub_temp_bytes; - cub::DeviceSegmentedRadixSort::SortKeys( - buf.cub_temp, temp, buf.grp_dense, buf.grp_sorted, - sb_grp_actual, sb_grp_seg, buf.grp_seg_offsets, - buf.grp_seg_offsets + 1, BEGIN_BIT, END_BIT, stream); - } - { - dim3 grid(sb_cols, n_groups); - batched_rank_sums_presorted_kernel<<>>( - buf.ref_sorted, buf.grp_sorted, d_grp_offsets, - buf.d_rank_sums, buf.d_tie_corr, n_ref, n_all_grp, sb_cols, - n_groups, compute_tie_corr); - CUDA_CHECK_LAST_ERROR(batched_rank_sums_presorted_kernel); - } - } - - // ---- D2D: scatter sub-batch results into caller's GPU buffers ---- - cudaMemcpy2DAsync(d_rank_sums + col, n_cols * sizeof(double), - buf.d_rank_sums, sb_cols * sizeof(double), - sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToDevice, stream); - if (compute_tie_corr) { - cudaMemcpy2DAsync(d_tie_corr + col, n_cols * sizeof(double), - buf.d_tie_corr, sb_cols * sizeof(double), - sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToDevice, stream); - } - cudaMemcpy2DAsync(d_group_sums + col, n_cols * sizeof(double), - buf.d_group_sums, sb_cols * sizeof(double), - sb_cols * sizeof(double), n_groups_stats, - cudaMemcpyDeviceToDevice, stream); - cudaMemcpy2DAsync(d_group_sq_sums + col, n_cols * sizeof(double), - buf.d_group_sq_sums, sb_cols * sizeof(double), - sb_cols * sizeof(double), n_groups_stats, - cudaMemcpyDeviceToDevice, stream); - cudaMemcpy2DAsync(d_group_nnz + col, n_cols * sizeof(double), - buf.d_group_nnz, sb_cols * sizeof(double), - sb_cols * sizeof(double), n_groups_stats, - cudaMemcpyDeviceToDevice, stream); - - col += sb_cols; - batch_idx++; - } - - for (int s = 0; s < n_streams; s++) { - cudaError_t err = cudaStreamSynchronize(streams[s]); - if (err != cudaSuccess) - throw std::runtime_error( - std::string("CUDA error in wilcoxon streaming: ") + - cudaGetErrorString(err)); - } - - for (int s = 
0; s < n_streams; s++) cudaStreamDestroy(streams[s]); -} - -/** - * Host CSR OVO pipeline — preload reference, stream perturbations. - * - * Two-phase approach: - * Phase 1: Transfer CSR to GPU, extract ref rows for ALL columns, sort once. - * Phase 2: For each column sub-batch, extract only grp rows, sort, rank - * against the pre-sorted reference. - * - * The reference is sorted once (not per sub-batch), saving ~50% of the - * per-sub-batch extraction + sort work. - */ -template -static void ovo_streaming_csr_host_impl( - const InT* h_data, const int* h_indices, const IndptrT* h_indptr, - const int* h_ref_row_ids, const int* h_grp_row_ids, - const int* h_grp_offsets, const int* h_stats_codes, double* d_rank_sums, - double* d_tie_corr, double* d_group_sums, double* d_group_sq_sums, - double* d_group_nnz, int n_ref, int n_all_grp, int n_rows, int n_cols, - int n_groups, int n_groups_stats, size_t nnz, bool compute_tie_corr, - int sub_batch_cols) { - if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; - - // ---- Tier dispatch from host offsets ---- - auto t1 = make_tier1_config(h_grp_offsets, n_groups); - int max_grp_size = t1.max_grp_size; - bool use_tier1 = t1.use_tier1; - int padded_grp_size = t1.padded_grp_size; - int tier1_tpb = t1.tier1_tpb; - size_t tier1_smem = t1.tier1_smem; - - int n_streams = N_STREAMS; - if (n_cols < n_streams * sub_batch_cols) - n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; - - size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols; - - // CUB temp — sized for the larger of ref (full) or grp (sub-batch) - size_t ref_total = (size_t)n_ref * n_cols; - size_t cub_ref_bytes = 0; - { - auto* fk = reinterpret_cast(1); - auto* doff = reinterpret_cast(1); - cub::DeviceSegmentedRadixSort::SortKeys(nullptr, cub_ref_bytes, fk, fk, - (int)ref_total, n_cols, doff, - doff + 1, BEGIN_BIT, END_BIT); - } - size_t cub_temp_bytes = cub_ref_bytes; - if (!use_tier1) { - size_t cub_grp_bytes = 0; - int max_grp_seg = n_groups * 
sub_batch_cols; - auto* fk = reinterpret_cast(1); - auto* doff = reinterpret_cast(1); - cub::DeviceSegmentedRadixSort::SortKeys( - nullptr, cub_grp_bytes, fk, fk, (int)sub_grp_items, max_grp_seg, - doff, doff + 1, BEGIN_BIT, END_BIT); - cub_temp_bytes = std::max(cub_ref_bytes, cub_grp_bytes); - } - - std::vector streams(n_streams); - for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); - - RmmPool pool; - - // ---- Phase 1: Transfer CSR (native dtype), cast + accumulate stats ---- - InT* d_data_orig = pool.alloc(nnz); - float* d_data = pool.alloc(nnz); - int* d_indices = pool.alloc(nnz); - int* d_indptr = pool.alloc(n_rows + 1); - int* d_ref_row_ids = pool.alloc(n_ref); - int* d_grp_row_ids = pool.alloc(n_all_grp); - int* d_grp_offsets = pool.alloc(n_groups + 1); - int* d_stats_codes = pool.alloc(n_rows); - - HostRegisterGuard _pin_data(const_cast(h_data), nnz * sizeof(InT)); - HostRegisterGuard _pin_indices(const_cast(h_indices), - nnz * sizeof(int)); - cudaMemcpyAsync(d_data_orig, h_data, nnz * sizeof(InT), - cudaMemcpyHostToDevice, streams[0]); - cudaMemcpyAsync(d_indices, h_indices, nnz * sizeof(int), - cudaMemcpyHostToDevice, streams[0]); - { - std::vector h_indptr32(n_rows + 1); - for (int i = 0; i <= n_rows; i++) h_indptr32[i] = (int)h_indptr[i]; - cudaMemcpy(d_indptr, h_indptr32.data(), (n_rows + 1) * sizeof(int), - cudaMemcpyHostToDevice); - } - cudaMemcpy(d_ref_row_ids, h_ref_row_ids, n_ref * sizeof(int), - cudaMemcpyHostToDevice); - cudaMemcpy(d_grp_row_ids, h_grp_row_ids, n_all_grp * sizeof(int), - cudaMemcpyHostToDevice); - cudaMemcpy(d_grp_offsets, h_grp_offsets, (n_groups + 1) * sizeof(int), - cudaMemcpyHostToDevice); - cudaMemcpy(d_stats_codes, h_stats_codes, n_rows * sizeof(int), - cudaMemcpyHostToDevice); - - // Zero caller's stats buffers before the atomicAdd-based kernel. 
- cudaMemsetAsync(d_group_sums, 0, - (size_t)n_groups_stats * n_cols * sizeof(double), - streams[0]); - cudaMemsetAsync(d_group_sq_sums, 0, - (size_t)n_groups_stats * n_cols * sizeof(double), - streams[0]); - cudaMemsetAsync(d_group_nnz, 0, - (size_t)n_groups_stats * n_cols * sizeof(double), - streams[0]); - { - int blk = (n_rows + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; - cast_and_accumulate_csr_kernel - <<>>( - d_data_orig, d_data, d_indices, d_indptr, d_stats_codes, - d_group_sums, d_group_sq_sums, d_group_nnz, n_rows, n_cols, - n_groups_stats); - CUDA_CHECK_LAST_ERROR(cast_and_accumulate_csr_kernel); - } - cudaStreamSynchronize(streams[0]); - - // Extract ref for ALL columns, sort once - float* ref_dense = pool.alloc(ref_total); - float* ref_sorted = pool.alloc(ref_total); - cudaMemset(ref_dense, 0, ref_total * sizeof(float)); - { - int tpb = round_up_to_warp(n_ref); - int blk = (n_ref + tpb - 1) / tpb; - csr_extract_dense_kernel<<>>(d_data, d_indices, d_indptr, - d_ref_row_ids, ref_dense, n_ref, - 0, n_cols); - CUDA_CHECK_LAST_ERROR(csr_extract_dense_kernel); - } - { - int* ref_seg = pool.alloc(n_cols + 1); - upload_linear_offsets(ref_seg, n_cols, n_ref, nullptr); - uint8_t* cub_tmp = pool.alloc(cub_ref_bytes); - size_t temp = cub_ref_bytes; - cub::DeviceSegmentedRadixSort::SortKeys( - cub_tmp, temp, ref_dense, ref_sorted, (int)ref_total, n_cols, - ref_seg, ref_seg + 1, BEGIN_BIT, END_BIT); - } - cudaDeviceSynchronize(); - - // ---- Phase 2: Stream grp sub-batches, rank against pre-sorted ref ---- - struct StreamBuf { - float* grp_dense; - float* grp_sorted; - int* grp_seg_offsets; - uint8_t* cub_temp; - double* d_rank_sums; - double* d_tie_corr; - }; - std::vector bufs(n_streams); - for (int s = 0; s < n_streams; s++) { - bufs[s].grp_dense = pool.alloc(sub_grp_items); - bufs[s].cub_temp = pool.alloc(cub_temp_bytes); - bufs[s].d_rank_sums = - pool.alloc((size_t)n_groups * sub_batch_cols); - bufs[s].d_tie_corr = - pool.alloc((size_t)n_groups * sub_batch_cols); 
- if (!use_tier1) { - bufs[s].grp_sorted = pool.alloc(sub_grp_items); - int max_grp_seg = n_groups * sub_batch_cols; - bufs[s].grp_seg_offsets = pool.alloc(max_grp_seg + 1); - } else { - bufs[s].grp_sorted = nullptr; - bufs[s].grp_seg_offsets = nullptr; - } - } - - int tpb_extract = round_up_to_warp(n_all_grp); - int tpb_rank = - round_up_to_warp(std::min(max_grp_size, MAX_THREADS_PER_BLOCK)); - - int col = 0; - int batch_idx = 0; - while (col < n_cols) { - int sb_cols = std::min(sub_batch_cols, n_cols - col); - int sb_grp_actual = n_all_grp * sb_cols; - int s = batch_idx % n_streams; - auto stream = streams[s]; - auto& buf = bufs[s]; - - // Extract grp only (ref already sorted) - cudaMemsetAsync(buf.grp_dense, 0, sb_grp_actual * sizeof(float), - stream); - { - int blk = (n_all_grp + tpb_extract - 1) / tpb_extract; - csr_extract_dense_kernel<<>>( - d_data, d_indices, d_indptr, d_grp_row_ids, buf.grp_dense, - n_all_grp, col, col + sb_cols); - CUDA_CHECK_LAST_ERROR(csr_extract_dense_kernel); - } - - // Rank against pre-sorted ref (just slice into ref_sorted) - const float* ref_sub = ref_sorted + (long long)col * n_ref; - if (use_tier1) { - dim3 grid(sb_cols, n_groups); - ovo_fused_sort_rank_kernel<<>>( - ref_sub, buf.grp_dense, d_grp_offsets, buf.d_rank_sums, - buf.d_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, - compute_tie_corr, padded_grp_size); - CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); - } else { - int sb_grp_seg = n_groups * sb_cols; - { - int total = sb_grp_seg + 1; - int blk = (total + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; - build_seg_offsets_kernel<<>>( - d_grp_offsets, buf.grp_seg_offsets, n_all_grp, n_groups, - sb_cols); - CUDA_CHECK_LAST_ERROR(build_seg_offsets_kernel); - } - { - size_t temp = cub_temp_bytes; - cub::DeviceSegmentedRadixSort::SortKeys( - buf.cub_temp, temp, buf.grp_dense, buf.grp_sorted, - sb_grp_actual, sb_grp_seg, buf.grp_seg_offsets, - buf.grp_seg_offsets + 1, BEGIN_BIT, END_BIT, stream); - } - { - dim3 grid(sb_cols, 
n_groups); - batched_rank_sums_presorted_kernel<<>>( - ref_sub, buf.grp_sorted, d_grp_offsets, buf.d_rank_sums, - buf.d_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, - compute_tie_corr); - CUDA_CHECK_LAST_ERROR(batched_rank_sums_presorted_kernel); - } - } - - // D2D: scatter rank_sums / tie_corr into caller's GPU buffers - cudaMemcpy2DAsync(d_rank_sums + col, n_cols * sizeof(double), - buf.d_rank_sums, sb_cols * sizeof(double), - sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToDevice, stream); - if (compute_tie_corr) { - cudaMemcpy2DAsync(d_tie_corr + col, n_cols * sizeof(double), - buf.d_tie_corr, sb_cols * sizeof(double), - sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToDevice, stream); - } - - col += sb_cols; - batch_idx++; - } - - for (int s = 0; s < n_streams; s++) { - cudaError_t err = cudaStreamSynchronize(streams[s]); - if (err != cudaSuccess) - throw std::runtime_error( - std::string("CUDA error in wilcoxon streaming: ") + - cudaGetErrorString(err)); - } - - for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); -} - -/** - * Gather specific rows from a dense F-order block into a smaller dense block. - * Grid: (n_cols,), Block: 256. - * row_ids[i] = original row index → output row i. - */ -__global__ void dense_gather_rows_kernel(const float* __restrict__ in, - const int* __restrict__ row_ids, - float* __restrict__ out, int n_rows_in, - int n_target, int n_cols) { - int col = blockIdx.x; - if (col >= n_cols) return; - const float* in_col = in + (long long)col * n_rows_in; - float* out_col = out + (long long)col * n_target; - for (int i = threadIdx.x; i < n_target; i += blockDim.x) { - out_col[i] = in_col[row_ids[i]]; - } -} - -/** - * Host-streaming dense OVO pipeline. - * - * Dense F-order float32 lives on host. Sub-batches of columns are H2D - * transferred, then ref/grp rows are gathered, sorted, and ranked. - * Results D2H per sub-batch. 
- */ -template -static void ovo_streaming_dense_host_impl( - const InT* h_block, const int* h_ref_row_ids, const int* h_grp_row_ids, - const int* h_grp_offsets, const int* h_stats_codes, double* d_rank_sums, - double* d_tie_corr, double* d_group_sums, double* d_group_sq_sums, - double* d_group_nnz, int n_ref, int n_all_grp, int n_rows, int n_cols, - int n_groups, int n_groups_stats, bool compute_tie_corr, - int sub_batch_cols) { - if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; - - // ---- Tier dispatch from host offsets ---- - auto t1 = make_tier1_config(h_grp_offsets, n_groups); - int max_grp_size = t1.max_grp_size; - bool use_tier1 = t1.use_tier1; - int padded_grp_size = t1.padded_grp_size; - int tier1_tpb = t1.tier1_tpb; - size_t tier1_smem = t1.tier1_smem; - - int n_streams = N_STREAMS; - if (n_cols < n_streams * sub_batch_cols) - n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; - - size_t sub_dense = (size_t)n_rows * sub_batch_cols; - size_t sub_ref_items = (size_t)n_ref * sub_batch_cols; - size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols; - - // CUB temp - size_t cub_ref_bytes = 0; - { - auto* fk = reinterpret_cast(1); - auto* doff = reinterpret_cast(1); - cub::DeviceSegmentedRadixSort::SortKeys( - nullptr, cub_ref_bytes, fk, fk, (int)sub_ref_items, sub_batch_cols, - doff, doff + 1, BEGIN_BIT, END_BIT); - } - size_t cub_temp_bytes = cub_ref_bytes; - if (!use_tier1) { - size_t cub_grp_bytes = 0; - int max_grp_seg = n_groups * sub_batch_cols; - auto* fk = reinterpret_cast(1); - auto* doff = reinterpret_cast(1); - cub::DeviceSegmentedRadixSort::SortKeys( - nullptr, cub_grp_bytes, fk, fk, (int)sub_grp_items, max_grp_seg, - doff, doff + 1, BEGIN_BIT, END_BIT); - cub_temp_bytes = std::max(cub_ref_bytes, cub_grp_bytes); - } - - std::vector streams(n_streams); - for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); - - RmmPool pool; - - // GPU copies of row_ids + group offsets + stats codes (uploaded once) - int* 
d_ref_row_ids = pool.alloc(n_ref); - int* d_grp_row_ids = pool.alloc(n_all_grp); - int* d_grp_offsets = pool.alloc(n_groups + 1); - int* d_stats_codes = pool.alloc(n_rows); - cudaMemcpy(d_ref_row_ids, h_ref_row_ids, n_ref * sizeof(int), - cudaMemcpyHostToDevice); - cudaMemcpy(d_grp_row_ids, h_grp_row_ids, n_all_grp * sizeof(int), - cudaMemcpyHostToDevice); - cudaMemcpy(d_grp_offsets, h_grp_offsets, (n_groups + 1) * sizeof(int), - cudaMemcpyHostToDevice); - cudaMemcpy(d_stats_codes, h_stats_codes, n_rows * sizeof(int), - cudaMemcpyHostToDevice); - - struct StreamBuf { - InT* d_block_orig; - float* d_block_f32; - float* ref_dense; - float* ref_sorted; - float* grp_dense; - float* grp_sorted; - int* ref_seg_offsets; - int* grp_seg_offsets; - uint8_t* cub_temp; - double* d_rank_sums; - double* d_tie_corr; - double* d_group_sums; - double* d_group_sq_sums; - double* d_group_nnz; - }; - std::vector bufs(n_streams); - for (int s = 0; s < n_streams; s++) { - bufs[s].d_block_orig = pool.alloc(sub_dense); - bufs[s].d_block_f32 = pool.alloc(sub_dense); - bufs[s].ref_dense = pool.alloc(sub_ref_items); - bufs[s].ref_sorted = pool.alloc(sub_ref_items); - bufs[s].grp_dense = pool.alloc(sub_grp_items); - bufs[s].ref_seg_offsets = pool.alloc(sub_batch_cols + 1); - bufs[s].cub_temp = pool.alloc(cub_temp_bytes); - bufs[s].d_rank_sums = - pool.alloc((size_t)n_groups * sub_batch_cols); - bufs[s].d_tie_corr = - pool.alloc((size_t)n_groups * sub_batch_cols); - bufs[s].d_group_sums = - pool.alloc((size_t)n_groups_stats * sub_batch_cols); - bufs[s].d_group_sq_sums = - pool.alloc((size_t)n_groups_stats * sub_batch_cols); - bufs[s].d_group_nnz = - pool.alloc((size_t)n_groups_stats * sub_batch_cols); - if (!use_tier1) { - bufs[s].grp_sorted = pool.alloc(sub_grp_items); - int max_grp_seg = n_groups * sub_batch_cols; - bufs[s].grp_seg_offsets = pool.alloc(max_grp_seg + 1); - } else { - bufs[s].grp_sorted = nullptr; - bufs[s].grp_seg_offsets = nullptr; - } - } - - int tpb_rank = - 
round_up_to_warp(std::min(max_grp_size, MAX_THREADS_PER_BLOCK)); - size_t smem_cast = (size_t)(3 * n_groups_stats) * sizeof(double); - - // Pin only the host input; outputs live on the device. - HostRegisterGuard _pin_block(const_cast(h_block), - (size_t)n_rows * n_cols * sizeof(InT)); - - int col = 0; - int batch_idx = 0; - while (col < n_cols) { - int sb_cols = std::min(sub_batch_cols, n_cols - col); - int sb_dense = n_rows * sb_cols; - int sb_ref_actual = n_ref * sb_cols; - int sb_grp_actual = n_all_grp * sb_cols; - int s = batch_idx % n_streams; - auto stream = streams[s]; - auto& buf = bufs[s]; - - // ---- H2D: dense column sub-batch (F-order, native dtype) ---- - cudaMemcpyAsync(buf.d_block_orig, h_block + (long long)col * n_rows, - sb_dense * sizeof(InT), cudaMemcpyHostToDevice, stream); - - // ---- Cast to float32 for sort + accumulate stats in float64 ---- - ovr_cast_and_accumulate_dense_kernel - <<>>( - buf.d_block_orig, buf.d_block_f32, d_stats_codes, - buf.d_group_sums, buf.d_group_sq_sums, buf.d_group_nnz, n_rows, - sb_cols, n_groups_stats); - CUDA_CHECK_LAST_ERROR(ovr_cast_and_accumulate_dense_kernel); - - // ---- Gather ref rows, sort ---- - dense_gather_rows_kernel<<>>( - buf.d_block_f32, d_ref_row_ids, buf.ref_dense, n_rows, n_ref, - sb_cols); - CUDA_CHECK_LAST_ERROR(dense_gather_rows_kernel); - upload_linear_offsets(buf.ref_seg_offsets, sb_cols, n_ref, stream); - { - size_t temp = cub_temp_bytes; - cub::DeviceSegmentedRadixSort::SortKeys( - buf.cub_temp, temp, buf.ref_dense, buf.ref_sorted, - sb_ref_actual, sb_cols, buf.ref_seg_offsets, - buf.ref_seg_offsets + 1, BEGIN_BIT, END_BIT, stream); - } - - // ---- Gather grp rows ---- - dense_gather_rows_kernel<<>>( - buf.d_block_f32, d_grp_row_ids, buf.grp_dense, n_rows, n_all_grp, - sb_cols); - CUDA_CHECK_LAST_ERROR(dense_gather_rows_kernel); - - // ---- Tier dispatch: sort grp + rank ---- - if (use_tier1) { - dim3 grid(sb_cols, n_groups); - ovo_fused_sort_rank_kernel<<>>( - buf.ref_sorted, 
buf.grp_dense, d_grp_offsets, buf.d_rank_sums, - buf.d_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, - compute_tie_corr, padded_grp_size); - CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); - } else { - int sb_grp_seg = n_groups * sb_cols; - { - int total = sb_grp_seg + 1; - int blk = (total + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; - build_seg_offsets_kernel<<>>( - d_grp_offsets, buf.grp_seg_offsets, n_all_grp, n_groups, - sb_cols); - CUDA_CHECK_LAST_ERROR(build_seg_offsets_kernel); - } - { - size_t temp = cub_temp_bytes; - cub::DeviceSegmentedRadixSort::SortKeys( - buf.cub_temp, temp, buf.grp_dense, buf.grp_sorted, - sb_grp_actual, sb_grp_seg, buf.grp_seg_offsets, - buf.grp_seg_offsets + 1, BEGIN_BIT, END_BIT, stream); - } - { - dim3 grid(sb_cols, n_groups); - batched_rank_sums_presorted_kernel<<>>( - buf.ref_sorted, buf.grp_sorted, d_grp_offsets, - buf.d_rank_sums, buf.d_tie_corr, n_ref, n_all_grp, sb_cols, - n_groups, compute_tie_corr); - CUDA_CHECK_LAST_ERROR(batched_rank_sums_presorted_kernel); - } - } - - // ---- D2D: scatter sub-batch results into caller's GPU buffers ---- - cudaMemcpy2DAsync(d_rank_sums + col, n_cols * sizeof(double), - buf.d_rank_sums, sb_cols * sizeof(double), - sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToDevice, stream); - if (compute_tie_corr) { - cudaMemcpy2DAsync(d_tie_corr + col, n_cols * sizeof(double), - buf.d_tie_corr, sb_cols * sizeof(double), - sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToDevice, stream); - } - cudaMemcpy2DAsync(d_group_sums + col, n_cols * sizeof(double), - buf.d_group_sums, sb_cols * sizeof(double), - sb_cols * sizeof(double), n_groups_stats, - cudaMemcpyDeviceToDevice, stream); - cudaMemcpy2DAsync(d_group_sq_sums + col, n_cols * sizeof(double), - buf.d_group_sq_sums, sb_cols * sizeof(double), - sb_cols * sizeof(double), n_groups_stats, - cudaMemcpyDeviceToDevice, stream); - cudaMemcpy2DAsync(d_group_nnz + col, n_cols * sizeof(double), - buf.d_group_nnz, sb_cols * sizeof(double), - 
sb_cols * sizeof(double), n_groups_stats, - cudaMemcpyDeviceToDevice, stream); - - col += sb_cols; - batch_idx++; - } - - for (int s = 0; s < n_streams; s++) { - cudaError_t err = cudaStreamSynchronize(streams[s]); - if (err != cudaSuccess) - throw std::runtime_error( - std::string("CUDA error in wilcoxon streaming: ") + - cudaGetErrorString(err)); - } - - for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); -} - -// ============================================================================ -// Nanobind module -// ============================================================================ - -template -void register_bindings(nb::module_& m) { - m.doc() = "CUDA kernels for Wilcoxon rank-sum test (OVO)"; - - // ---- Utility bindings (CUB sort, CSR extraction) ---- - - m.def("get_seg_sort_temp_bytes", &get_seg_sort_temp_bytes, "n_items"_a, - "n_segments"_a); - - m.def( - "segmented_sort", - [](gpu_array_c keys_in, - gpu_array_c keys_out, - gpu_array_c offsets, - gpu_array_c cub_temp, int n_items, int n_segments, - std::uintptr_t stream) { - size_t temp_bytes = cub_temp.size(); - cub::DeviceSegmentedRadixSort::SortKeys( - cub_temp.data(), temp_bytes, keys_in.data(), keys_out.data(), - n_items, n_segments, offsets.data(), offsets.data() + 1, 0, 32, - (cudaStream_t)stream); - CUDA_CHECK_LAST_ERROR(DeviceSegmentedRadixSort); - }, - "keys_in"_a, "keys_out"_a, "offsets"_a, "cub_temp"_a, nb::kw_only(), - "n_items"_a, "n_segments"_a, "stream"_a = 0); - - m.def( - "csr_extract_dense", - [](gpu_array_c data, - gpu_array_c indices, - gpu_array_c indptr, - gpu_array_c row_ids, - gpu_array_f out, int n_target, int col_start, - int col_stop, std::uintptr_t stream) { - int tpb = round_up_to_warp(n_target); - int blocks = (n_target + tpb - 1) / tpb; - csr_extract_dense_kernel<<>>( - data.data(), indices.data(), indptr.data(), row_ids.data(), - out.data(), n_target, col_start, col_stop); - CUDA_CHECK_LAST_ERROR(csr_extract_dense_kernel); - }, - "data"_a, "indices"_a, 
"indptr"_a, "row_ids"_a, "out"_a, nb::kw_only(), - "n_target"_a, "col_start"_a, "col_stop"_a, "stream"_a = 0); - - m.def( - "csr_extract_dense_f32", - [](gpu_array_c data, - gpu_array_c indices, - gpu_array_c indptr, - gpu_array_c row_ids, - gpu_array_f out, int n_target, int col_start, - int col_stop, std::uintptr_t stream) { - int tpb = round_up_to_warp(n_target); - int blocks = (n_target + tpb - 1) / tpb; - csr_extract_dense_kernel<<>>( - data.data(), indices.data(), indptr.data(), row_ids.data(), - out.data(), n_target, col_start, col_stop); - CUDA_CHECK_LAST_ERROR(csr_extract_dense_kernel); - }, - "data"_a, "indices"_a, "indptr"_a, "row_ids"_a, "out"_a, nb::kw_only(), - "n_target"_a, "col_start"_a, "col_stop"_a, "stream"_a = 0); - - // ---- Streaming pipelines ---- - - m.def( - "ovo_streaming_csr", - [](gpu_array_c csr_data, - gpu_array_c csr_indices, - gpu_array_c csr_indptr, - gpu_array_c ref_row_ids, - gpu_array_c grp_row_ids, - gpu_array_c grp_offsets, - gpu_array_c rank_sums, - gpu_array_c tie_corr, int n_ref, int n_all_grp, - int n_cols, int n_groups, bool compute_tie_corr, - int sub_batch_cols) { - ovo_streaming_csr_impl( - csr_data.data(), csr_indices.data(), csr_indptr.data(), - ref_row_ids.data(), grp_row_ids.data(), grp_offsets.data(), - rank_sums.data(), tie_corr.data(), n_ref, n_all_grp, n_cols, - n_groups, compute_tie_corr, sub_batch_cols); - }, - "csr_data"_a, "csr_indices"_a, "csr_indptr"_a, "ref_row_ids"_a, - "grp_row_ids"_a, "grp_offsets"_a, "rank_sums"_a, "tie_corr"_a, - nb::kw_only(), "n_ref"_a, "n_all_grp"_a, "n_cols"_a, "n_groups"_a, - "compute_tie_corr"_a, "sub_batch_cols"_a = SUB_BATCH_COLS); - - m.def( - "ovo_streaming_csc", - [](gpu_array_c csc_data, - gpu_array_c csc_indices, - gpu_array_c csc_indptr, - gpu_array_c ref_row_map, - gpu_array_c grp_row_map, - gpu_array_c grp_offsets, - gpu_array_c rank_sums, - gpu_array_c tie_corr, int n_ref, int n_all_grp, - int n_cols, int n_groups, bool compute_tie_corr, - int sub_batch_cols) { - 
ovo_streaming_csc_impl( - csc_data.data(), csc_indices.data(), csc_indptr.data(), - ref_row_map.data(), grp_row_map.data(), grp_offsets.data(), - rank_sums.data(), tie_corr.data(), n_ref, n_all_grp, n_cols, - n_groups, compute_tie_corr, sub_batch_cols); - }, - "csc_data"_a, "csc_indices"_a, "csc_indptr"_a, "ref_row_map"_a, - "grp_row_map"_a, "grp_offsets"_a, "rank_sums"_a, "tie_corr"_a, - nb::kw_only(), "n_ref"_a, "n_all_grp"_a, "n_cols"_a, "n_groups"_a, - "compute_tie_corr"_a, "sub_batch_cols"_a = SUB_BATCH_COLS); - - m.def( - "ovo_streaming", - [](gpu_array_f ref_sorted, - gpu_array_f grp_data, - gpu_array_c grp_offsets, - gpu_array_c rank_sums, - gpu_array_c tie_corr, int n_ref, int n_all_grp, - int n_cols, int n_groups, bool compute_tie_corr, - int sub_batch_cols) { - ovo_streaming_impl(ref_sorted.data(), grp_data.data(), - grp_offsets.data(), rank_sums.data(), - tie_corr.data(), n_ref, n_all_grp, n_cols, - n_groups, compute_tie_corr, sub_batch_cols); - }, - "ref_sorted"_a, "grp_data"_a, "grp_offsets"_a, "rank_sums"_a, - "tie_corr"_a, nb::kw_only(), "n_ref"_a, "n_all_grp"_a, "n_cols"_a, - "n_groups"_a, "compute_tie_corr"_a, - "sub_batch_cols"_a = SUB_BATCH_COLS); - - // ---- Host-streaming pipelines (host inputs, device outputs) ---- - -#define RSC_OVO_CSC_HOST_BINDING(NAME, InT, IndptrT) \ - m.def( \ - NAME, \ - [](host_array h_data, host_array h_indices, \ - host_array h_indptr, \ - host_array h_ref_row_map, \ - host_array h_grp_row_map, \ - host_array h_grp_offsets, \ - host_array h_stats_codes, \ - gpu_array_c d_rank_sums, \ - gpu_array_c d_tie_corr, \ - gpu_array_c d_group_sums, \ - gpu_array_c d_group_sq_sums, \ - gpu_array_c d_group_nnz, int n_ref, int n_all_grp, \ - int n_rows, int n_cols, int n_groups, int n_groups_stats, \ - bool compute_tie_corr, int sub_batch_cols) { \ - ovo_streaming_csc_host_impl( \ - h_data.data(), h_indices.data(), h_indptr.data(), \ - h_ref_row_map.data(), h_grp_row_map.data(), \ - h_grp_offsets.data(), h_stats_codes.data(), \ 
- d_rank_sums.data(), d_tie_corr.data(), d_group_sums.data(), \ - d_group_sq_sums.data(), d_group_nnz.data(), n_ref, n_all_grp, \ - n_rows, n_cols, n_groups, n_groups_stats, compute_tie_corr, \ - sub_batch_cols); \ - }, \ - "h_data"_a, "h_indices"_a, "h_indptr"_a, "h_ref_row_map"_a, \ - "h_grp_row_map"_a, "h_grp_offsets"_a, "h_stats_codes"_a, \ - "d_rank_sums"_a, "d_tie_corr"_a, "d_group_sums"_a, \ - "d_group_sq_sums"_a, "d_group_nnz"_a, nb::kw_only(), "n_ref"_a, \ - "n_all_grp"_a, "n_rows"_a, "n_cols"_a, "n_groups"_a, \ - "n_groups_stats"_a, "compute_tie_corr"_a, \ - "sub_batch_cols"_a = SUB_BATCH_COLS) - - RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host", float, int); - RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_i64", float, int64_t); - RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_f64", double, int); - RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_f64_i64", double, int64_t); -#undef RSC_OVO_CSC_HOST_BINDING - -#define RSC_OVO_CSR_HOST_BINDING(NAME, InT, IndptrT) \ - m.def( \ - NAME, \ - [](host_array h_data, host_array h_indices, \ - host_array h_indptr, \ - host_array h_ref_row_ids, \ - host_array h_grp_row_ids, \ - host_array h_grp_offsets, \ - host_array h_stats_codes, \ - gpu_array_c d_rank_sums, \ - gpu_array_c d_tie_corr, \ - gpu_array_c d_group_sums, \ - gpu_array_c d_group_sq_sums, \ - gpu_array_c d_group_nnz, int n_ref, int n_all_grp, \ - int n_rows, int n_cols, int n_groups, int n_groups_stats, \ - size_t nnz, bool compute_tie_corr, int sub_batch_cols) { \ - ovo_streaming_csr_host_impl( \ - h_data.data(), h_indices.data(), h_indptr.data(), \ - h_ref_row_ids.data(), h_grp_row_ids.data(), \ - h_grp_offsets.data(), h_stats_codes.data(), \ - d_rank_sums.data(), d_tie_corr.data(), d_group_sums.data(), \ - d_group_sq_sums.data(), d_group_nnz.data(), n_ref, n_all_grp, \ - n_rows, n_cols, n_groups, n_groups_stats, nnz, \ - compute_tie_corr, sub_batch_cols); \ - }, \ - "h_data"_a, "h_indices"_a, "h_indptr"_a, "h_ref_row_ids"_a, \ - 
"h_grp_row_ids"_a, "h_grp_offsets"_a, "h_stats_codes"_a, \ - "d_rank_sums"_a, "d_tie_corr"_a, "d_group_sums"_a, \ - "d_group_sq_sums"_a, "d_group_nnz"_a, nb::kw_only(), "n_ref"_a, \ - "n_all_grp"_a, "n_rows"_a, "n_cols"_a, "n_groups"_a, \ - "n_groups_stats"_a, "nnz"_a, "compute_tie_corr"_a, \ - "sub_batch_cols"_a = SUB_BATCH_COLS) - - RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host", float, int); - RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_i64", float, int64_t); - RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_f64", double, int); - RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_f64_i64", double, int64_t); -#undef RSC_OVO_CSR_HOST_BINDING - -#define RSC_OVO_DENSE_HOST_BINDING(NAME, InT) \ - m.def( \ - NAME, \ - [](host_array_2d h_block, \ - host_array h_ref_row_ids, \ - host_array h_grp_row_ids, \ - host_array h_grp_offsets, \ - host_array h_stats_codes, \ - gpu_array_c d_rank_sums, \ - gpu_array_c d_tie_corr, \ - gpu_array_c d_group_sums, \ - gpu_array_c d_group_sq_sums, \ - gpu_array_c d_group_nnz, int n_ref, int n_all_grp, \ - int n_rows, int n_cols, int n_groups, int n_groups_stats, \ - bool compute_tie_corr, int sub_batch_cols) { \ - ovo_streaming_dense_host_impl( \ - h_block.data(), h_ref_row_ids.data(), h_grp_row_ids.data(), \ - h_grp_offsets.data(), h_stats_codes.data(), \ - d_rank_sums.data(), d_tie_corr.data(), d_group_sums.data(), \ - d_group_sq_sums.data(), d_group_nnz.data(), n_ref, n_all_grp, \ - n_rows, n_cols, n_groups, n_groups_stats, compute_tie_corr, \ - sub_batch_cols); \ - }, \ - "h_block"_a, "h_ref_row_ids"_a, "h_grp_row_ids"_a, "h_grp_offsets"_a, \ - "h_stats_codes"_a, "d_rank_sums"_a, "d_tie_corr"_a, "d_group_sums"_a, \ - "d_group_sq_sums"_a, "d_group_nnz"_a, nb::kw_only(), "n_ref"_a, \ - "n_all_grp"_a, "n_rows"_a, "n_cols"_a, "n_groups"_a, \ - "n_groups_stats"_a, "compute_tie_corr"_a, \ - "sub_batch_cols"_a = SUB_BATCH_COLS) - - RSC_OVO_DENSE_HOST_BINDING("ovo_streaming_dense_host", float); - 
RSC_OVO_DENSE_HOST_BINDING("ovo_streaming_dense_host_f64", double); -#undef RSC_OVO_DENSE_HOST_BINDING -} +#include "wilcoxon_ovo_kernels.cuh" +#include "wilcoxon_ovo_device_dense.cuh" +#include "wilcoxon_ovo_device_sparse.cuh" +#include "wilcoxon_ovo_host_sparse.cuh" +#include "wilcoxon_ovo_host_dense.cuh" +#include "wilcoxon_ovo_bindings.cuh" NB_MODULE(_wilcoxon_ovo_cuda, m) { REGISTER_GPU_BINDINGS(register_bindings, m); diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_bindings.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_bindings.cuh new file mode 100644 index 00000000..a96e2a4e --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_bindings.cuh @@ -0,0 +1,264 @@ +#pragma once + +template +void register_bindings(nb::module_& m) { + m.doc() = "CUDA kernels for Wilcoxon rank-sum test (OVO)"; + + // ---- Utility bindings (CUB sort, CSR extraction) ---- + + m.def("get_seg_sort_temp_bytes", &get_seg_sort_temp_bytes, "n_items"_a, + "n_segments"_a); + + m.def( + "segmented_sort", + [](gpu_array_c keys_in, + gpu_array_c keys_out, + gpu_array_c offsets, + gpu_array_c cub_temp, int n_items, int n_segments, + std::uintptr_t stream) { + size_t temp_bytes = cub_temp.size(); + cub::DeviceSegmentedRadixSort::SortKeys( + cub_temp.data(), temp_bytes, keys_in.data(), keys_out.data(), + n_items, n_segments, offsets.data(), offsets.data() + 1, 0, 32, + (cudaStream_t)stream); + CUDA_CHECK_LAST_ERROR(DeviceSegmentedRadixSort); + }, + "keys_in"_a, "keys_out"_a, "offsets"_a, "cub_temp"_a, nb::kw_only(), + "n_items"_a, "n_segments"_a, "stream"_a = 0); + + m.def( + "csr_extract_dense", + [](gpu_array_c data, + gpu_array_c indices, + gpu_array_c indptr, + gpu_array_c row_ids, + gpu_array_f out, int n_target, int col_start, + int col_stop, std::uintptr_t stream) { + int tpb = round_up_to_warp(n_target); + int blocks = (n_target + tpb - 1) / tpb; + csr_extract_dense_kernel<<>>( + data.data(), indices.data(), indptr.data(), row_ids.data(), + 
out.data(), n_target, col_start, col_stop); + CUDA_CHECK_LAST_ERROR(csr_extract_dense_kernel); + }, + "data"_a, "indices"_a, "indptr"_a, "row_ids"_a, "out"_a, nb::kw_only(), + "n_target"_a, "col_start"_a, "col_stop"_a, "stream"_a = 0); + + m.def( + "csr_extract_dense_f32", + [](gpu_array_c data, + gpu_array_c indices, + gpu_array_c indptr, + gpu_array_c row_ids, + gpu_array_f out, int n_target, int col_start, + int col_stop, std::uintptr_t stream) { + int tpb = round_up_to_warp(n_target); + int blocks = (n_target + tpb - 1) / tpb; + csr_extract_dense_kernel<<>>( + data.data(), indices.data(), indptr.data(), row_ids.data(), + out.data(), n_target, col_start, col_stop); + CUDA_CHECK_LAST_ERROR(csr_extract_dense_kernel); + }, + "data"_a, "indices"_a, "indptr"_a, "row_ids"_a, "out"_a, nb::kw_only(), + "n_target"_a, "col_start"_a, "col_stop"_a, "stream"_a = 0); + + // ---- Streaming pipelines ---- + + m.def( + "ovo_streaming_csr", + [](gpu_array_c csr_data, + gpu_array_c csr_indices, + gpu_array_c csr_indptr, + gpu_array_c ref_row_ids, + gpu_array_c grp_row_ids, + gpu_array_c grp_offsets, + gpu_array_c rank_sums, + gpu_array_c tie_corr, int n_ref, int n_all_grp, + int n_cols, int n_groups, bool compute_tie_corr, + int sub_batch_cols) { + ovo_streaming_csr_impl( + csr_data.data(), csr_indices.data(), csr_indptr.data(), + ref_row_ids.data(), grp_row_ids.data(), grp_offsets.data(), + rank_sums.data(), tie_corr.data(), n_ref, n_all_grp, n_cols, + n_groups, compute_tie_corr, sub_batch_cols); + }, + "csr_data"_a, "csr_indices"_a, "csr_indptr"_a, "ref_row_ids"_a, + "grp_row_ids"_a, "grp_offsets"_a, "rank_sums"_a, "tie_corr"_a, + nb::kw_only(), "n_ref"_a, "n_all_grp"_a, "n_cols"_a, "n_groups"_a, + "compute_tie_corr"_a, "sub_batch_cols"_a = SUB_BATCH_COLS); + + m.def( + "ovo_streaming_csc", + [](gpu_array_c csc_data, + gpu_array_c csc_indices, + gpu_array_c csc_indptr, + gpu_array_c ref_row_map, + gpu_array_c grp_row_map, + gpu_array_c grp_offsets, + gpu_array_c rank_sums, + 
gpu_array_c tie_corr, int n_ref, int n_all_grp, + int n_cols, int n_groups, bool compute_tie_corr, + int sub_batch_cols) { + ovo_streaming_csc_impl( + csc_data.data(), csc_indices.data(), csc_indptr.data(), + ref_row_map.data(), grp_row_map.data(), grp_offsets.data(), + rank_sums.data(), tie_corr.data(), n_ref, n_all_grp, n_cols, + n_groups, compute_tie_corr, sub_batch_cols); + }, + "csc_data"_a, "csc_indices"_a, "csc_indptr"_a, "ref_row_map"_a, + "grp_row_map"_a, "grp_offsets"_a, "rank_sums"_a, "tie_corr"_a, + nb::kw_only(), "n_ref"_a, "n_all_grp"_a, "n_cols"_a, "n_groups"_a, + "compute_tie_corr"_a, "sub_batch_cols"_a = SUB_BATCH_COLS); + + m.def( + "ovo_streaming", + [](gpu_array_f ref_sorted, + gpu_array_f grp_data, + gpu_array_c grp_offsets, + gpu_array_c rank_sums, + gpu_array_c tie_corr, int n_ref, int n_all_grp, + int n_cols, int n_groups, bool compute_tie_corr, + int sub_batch_cols) { + ovo_streaming_impl(ref_sorted.data(), grp_data.data(), + grp_offsets.data(), rank_sums.data(), + tie_corr.data(), n_ref, n_all_grp, n_cols, + n_groups, compute_tie_corr, sub_batch_cols); + }, + "ref_sorted"_a, "grp_data"_a, "grp_offsets"_a, "rank_sums"_a, + "tie_corr"_a, nb::kw_only(), "n_ref"_a, "n_all_grp"_a, "n_cols"_a, + "n_groups"_a, "compute_tie_corr"_a, + "sub_batch_cols"_a = SUB_BATCH_COLS); + + // ---- Host-streaming pipelines (host inputs, device outputs) ---- + +#define RSC_OVO_CSC_HOST_BINDING(NAME, InT, IndexT, IndptrT) \ + m.def( \ + NAME, \ + [](host_array h_data, host_array h_indices, \ + host_array h_indptr, \ + host_array h_ref_row_map, \ + host_array h_grp_row_map, \ + host_array h_grp_offsets, \ + host_array h_stats_codes, \ + gpu_array_c d_rank_sums, \ + gpu_array_c d_tie_corr, \ + gpu_array_c d_group_sums, \ + gpu_array_c d_group_sq_sums, \ + gpu_array_c d_group_nnz, int n_ref, int n_all_grp, \ + int n_rows, int n_cols, int n_groups, int n_groups_stats, \ + bool compute_tie_corr, bool compute_sq_sums, bool compute_nnz, \ + int sub_batch_cols) { \ + 
ovo_streaming_csc_host_impl( \ + h_data.data(), h_indices.data(), h_indptr.data(), \ + h_ref_row_map.data(), h_grp_row_map.data(), \ + h_grp_offsets.data(), h_stats_codes.data(), \ + d_rank_sums.data(), d_tie_corr.data(), d_group_sums.data(), \ + d_group_sq_sums.data(), d_group_nnz.data(), n_ref, n_all_grp, \ + n_rows, n_cols, n_groups, n_groups_stats, compute_tie_corr, \ + compute_sq_sums, compute_nnz, sub_batch_cols); \ + }, \ + "h_data"_a, "h_indices"_a, "h_indptr"_a, "h_ref_row_map"_a, \ + "h_grp_row_map"_a, "h_grp_offsets"_a, "h_stats_codes"_a, \ + "d_rank_sums"_a, "d_tie_corr"_a, "d_group_sums"_a, \ + "d_group_sq_sums"_a, "d_group_nnz"_a, nb::kw_only(), "n_ref"_a, \ + "n_all_grp"_a, "n_rows"_a, "n_cols"_a, "n_groups"_a, \ + "n_groups_stats"_a, "compute_tie_corr"_a, "compute_sq_sums"_a = true, \ + "compute_nnz"_a = true, "sub_batch_cols"_a = SUB_BATCH_COLS) + + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host", float, int, int); + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_i64", float, int, int64_t); + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_idx64", float, int64_t, + int); + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_idx64_i64", float, int64_t, + int64_t); + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_f64", double, int, int); + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_f64_i64", double, int, + int64_t); + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_f64_idx64", double, + int64_t, int); + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_f64_idx64_i64", double, + int64_t, int64_t); +#undef RSC_OVO_CSC_HOST_BINDING + +#define RSC_OVO_CSR_HOST_BINDING(NAME, InT, IndexT, IndptrT) \ + m.def( \ + NAME, \ + [](host_array h_data, host_array h_indices, \ + host_array h_indptr, \ + host_array h_ref_row_ids, \ + host_array h_grp_row_ids, \ + host_array h_grp_offsets, \ + gpu_array_c d_rank_sums, \ + gpu_array_c d_tie_corr, \ + gpu_array_c d_group_sums, \ + gpu_array_c d_group_sq_sums, \ + gpu_array_c d_group_nnz, int n_full_rows, \ + 
int n_ref, int n_all_grp, int n_cols, int n_test, \ + int n_groups_stats, bool compute_tie_corr, bool compute_sq_sums, \ + bool compute_nnz, int sub_batch_cols) { \ + ovo_streaming_csr_host_impl( \ + h_data.data(), h_indices.data(), h_indptr.data(), n_full_rows, \ + h_ref_row_ids.data(), n_ref, h_grp_row_ids.data(), \ + h_grp_offsets.data(), n_all_grp, n_test, d_rank_sums.data(), \ + d_tie_corr.data(), d_group_sums.data(), \ + d_group_sq_sums.data(), d_group_nnz.data(), n_cols, \ + n_groups_stats, compute_tie_corr, compute_sq_sums, \ + compute_nnz, sub_batch_cols); \ + }, \ + "h_data"_a, "h_indices"_a, "h_indptr"_a, "h_ref_row_ids"_a, \ + "h_grp_row_ids"_a, "h_grp_offsets"_a, "d_rank_sums"_a, "d_tie_corr"_a, \ + "d_group_sums"_a, "d_group_sq_sums"_a, "d_group_nnz"_a, nb::kw_only(), \ + "n_full_rows"_a, "n_ref"_a, "n_all_grp"_a, "n_cols"_a, "n_test"_a, \ + "n_groups_stats"_a, "compute_tie_corr"_a, "compute_sq_sums"_a = true, \ + "compute_nnz"_a = true, "sub_batch_cols"_a = SUB_BATCH_COLS) + + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host", float, int, int); + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_i64", float, int, int64_t); + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_idx64", float, int64_t, + int); + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_idx64_i64", float, int64_t, + int64_t); + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_f64", double, int, int); + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_f64_i64", double, int, + int64_t); + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_f64_idx64", double, + int64_t, int); + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_f64_idx64_i64", double, + int64_t, int64_t); +#undef RSC_OVO_CSR_HOST_BINDING + +#define RSC_OVO_DENSE_HOST_BINDING(NAME, InT) \ + m.def( \ + NAME, \ + [](host_array_2d h_block, \ + host_array h_ref_row_ids, \ + host_array h_grp_row_ids, \ + host_array h_grp_offsets, \ + host_array h_stats_codes, \ + gpu_array_c d_rank_sums, \ + gpu_array_c d_tie_corr, \ + 
gpu_array_c d_group_sums, \ + gpu_array_c d_group_sq_sums, \ + gpu_array_c d_group_nnz, int n_ref, int n_all_grp, \ + int n_rows, int n_cols, int n_groups, int n_groups_stats, \ + bool compute_tie_corr, bool compute_sq_sums, bool compute_nnz, \ + int sub_batch_cols) { \ + ovo_streaming_dense_host_impl( \ + h_block.data(), h_ref_row_ids.data(), h_grp_row_ids.data(), \ + h_grp_offsets.data(), h_stats_codes.data(), \ + d_rank_sums.data(), d_tie_corr.data(), d_group_sums.data(), \ + d_group_sq_sums.data(), d_group_nnz.data(), n_ref, n_all_grp, \ + n_rows, n_cols, n_groups, n_groups_stats, compute_tie_corr, \ + compute_sq_sums, compute_nnz, sub_batch_cols); \ + }, \ + "h_block"_a, "h_ref_row_ids"_a, "h_grp_row_ids"_a, "h_grp_offsets"_a, \ + "h_stats_codes"_a, "d_rank_sums"_a, "d_tie_corr"_a, "d_group_sums"_a, \ + "d_group_sq_sums"_a, "d_group_nnz"_a, nb::kw_only(), "n_ref"_a, \ + "n_all_grp"_a, "n_rows"_a, "n_cols"_a, "n_groups"_a, \ + "n_groups_stats"_a, "compute_tie_corr"_a, "compute_sq_sums"_a = true, \ + "compute_nnz"_a = true, "sub_batch_cols"_a = SUB_BATCH_COLS) + + RSC_OVO_DENSE_HOST_BINDING("ovo_streaming_dense_host", float); + RSC_OVO_DENSE_HOST_BINDING("ovo_streaming_dense_host_f64", double); +#undef RSC_OVO_DENSE_HOST_BINDING +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_dense.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_dense.cuh new file mode 100644 index 00000000..261f95b4 --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_dense.cuh @@ -0,0 +1,190 @@ +#pragma once + +/** + * Streaming OVO pipeline. + * + * Takes pre-sorted reference (float32 F-order), unsorted group data (float32 + * F-order with group offsets), and produces rank_sums + tie_corr. + * + * For each sub-batch of columns: + * 1. CUB segmented sort-keys of group data (one segment per group per col) + * 2. 
batched_rank_sums_presorted_kernel (binary search in sorted ref) + */ +static void ovo_streaming_impl(const float* ref_sorted, const float* grp_data, + const int* grp_offsets, double* rank_sums, + double* tie_corr, int n_ref, int n_all_grp, + int n_cols, int n_groups, bool compute_tie_corr, + int sub_batch_cols) { + if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; + + // ---- Tier dispatch ---- + std::vector h_off(n_groups + 1); + cudaMemcpy(h_off.data(), grp_offsets, (n_groups + 1) * sizeof(int), + cudaMemcpyDeviceToHost); + auto t1 = make_tier1_config(h_off.data(), n_groups); + int max_grp_size = t1.max_grp_size; + bool use_tier1 = t1.any_above_t2 && t1.use_tier1; + bool needs_tier3 = t1.any_above_t2 && !use_tier1; + std::vector h_sort_group_ids; + int n_sort_groups = n_groups; + if (needs_tier3) { + h_sort_group_ids = + make_sort_group_ids(h_off.data(), n_groups, TIER2_GROUP_THRESHOLD); + n_sort_groups = (int)h_sort_group_ids.size(); + } + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) + n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + + size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols; + size_t cub_temp_bytes = 0; + if (needs_tier3) { + int max_n_seg = n_sort_groups * sub_batch_cols; + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, cub_temp_bytes, fk, fk, (int)sub_grp_items, max_n_seg, + doff, doff + 1, BEGIN_BIT, END_BIT); + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + // Allocate per-stream buffers via RMM pool + RmmPool pool; + int* d_sort_group_ids = nullptr; + if (needs_tier3) { + d_sort_group_ids = pool.alloc(h_sort_group_ids.size()); + cudaMemcpy(d_sort_group_ids, h_sort_group_ids.data(), + h_sort_group_ids.size() * sizeof(int), + cudaMemcpyHostToDevice); + } + struct StreamBuf { + float* grp_sorted; + int* seg_offsets; + int* seg_ends; + uint8_t* cub_temp; + 
double* ref_tie_sums; + double* sub_rank_sums; + double* sub_tie_corr; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + if (needs_tier3) { + bufs[s].grp_sorted = pool.alloc(sub_grp_items); + int max_n_seg = n_sort_groups * sub_batch_cols; + bufs[s].seg_offsets = pool.alloc(max_n_seg); + bufs[s].seg_ends = pool.alloc(max_n_seg); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + } else { + bufs[s].grp_sorted = nullptr; + bufs[s].seg_offsets = nullptr; + bufs[s].seg_ends = nullptr; + bufs[s].cub_temp = nullptr; + } + bufs[s].ref_tie_sums = (t1.any_tier2 && compute_tie_corr) + ? pool.alloc(sub_batch_cols) + : nullptr; + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = + pool.alloc((size_t)n_groups * sub_batch_cols); + } + + int tpb_rank = + round_up_to_warp(std::min(max_grp_size, MAX_THREADS_PER_BLOCK)); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int sb_grp_items = n_all_grp * sb_cols; + int s = batch_idx % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + + const float* grp_in = grp_data + (long long)col * n_all_grp; + const float* ref_sub = ref_sorted + (long long)col * n_ref; + + int skip_le = 0; + if (t1.use_tier0) { + launch_tier0(ref_sub, grp_in, grp_offsets, buf.sub_rank_sums, + buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, + compute_tie_corr, stream); + if (t1.any_above_t0) skip_le = TIER0_GROUP_THRESHOLD; + } + if (t1.any_tier2) { + if (compute_tie_corr) { + launch_ref_tie_sums(ref_sub, buf.ref_tie_sums, n_ref, sb_cols, + stream); + } + launch_tier2_medium(ref_sub, grp_in, grp_offsets, buf.ref_tie_sums, + buf.sub_rank_sums, buf.sub_tie_corr, n_ref, + n_all_grp, sb_cols, n_groups, compute_tie_corr, + TIER0_GROUP_THRESHOLD, stream); + } + + int upper_skip_le = t1.any_above_t2 ? 
TIER2_GROUP_THRESHOLD : skip_le; + if (t1.any_above_t2 && use_tier1) { + dim3 grid(sb_cols, n_groups); + ovo_fused_sort_rank_kernel<<>>( + ref_sub, grp_in, grp_offsets, buf.sub_rank_sums, + buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, + compute_tie_corr, t1.padded_grp_size, upper_skip_le); + CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); + } else if (needs_tier3) { + int sb_n_seg = n_sort_groups * sb_cols; + { + int blk = (sb_n_seg + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + build_tier3_seg_begin_end_offsets_kernel<<>>( + grp_offsets, d_sort_group_ids, buf.seg_offsets, + buf.seg_ends, n_all_grp, n_sort_groups, sb_cols); + CUDA_CHECK_LAST_ERROR(build_tier3_seg_begin_end_offsets_kernel); + } + { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.cub_temp, temp, grp_in, buf.grp_sorted, sb_grp_items, + sb_n_seg, buf.seg_offsets, buf.seg_ends, BEGIN_BIT, END_BIT, + stream); + CUDA_CHECK_LAST_ERROR(DeviceSegmentedRadixSort); + } + { + dim3 grid(sb_cols, n_groups); + batched_rank_sums_presorted_kernel<<>>( + ref_sub, buf.grp_sorted, grp_offsets, buf.sub_rank_sums, + buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, + compute_tie_corr, upper_skip_le); + CUDA_CHECK_LAST_ERROR(batched_rank_sums_presorted_kernel); + } + } + + // Scatter sub-batch results to global output + cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpy2DAsync(tie_corr + col, n_cols * sizeof(double), + buf.sub_tie_corr, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + + col += sb_cols; + batch_idx++; + } + + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in wilcoxon streaming: ") + + cudaGetErrorString(err)); + 
} + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh new file mode 100644 index 00000000..9f6d89e2 --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh @@ -0,0 +1,474 @@ +#pragma once + +/** + * CSR-direct OVO streaming pipeline. + * + * One C++ call does everything: extract rows from CSR → sort → rank. + * Per sub-batch of columns: + * 1. Extract ref rows → dense f32 → CUB sort + * 2. Extract grp rows → dense f32 → CUB sort (segmented by group) + * 3. Binary search rank sums + * Only ~(n_ref + n_all_grp) × sub_batch × 4B on GPU at a time. + */ +static void ovo_streaming_csr_impl( + const float* csr_data, const int* csr_indices, const int* csr_indptr, + const int* ref_row_ids, const int* grp_row_ids, const int* grp_offsets, + double* rank_sums, double* tie_corr, int n_ref, int n_all_grp, int n_cols, + int n_groups, bool compute_tie_corr, int sub_batch_cols) { + if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; + + // ---- Tier dispatch: read group offsets to determine max group size ---- + std::vector h_offsets(n_groups + 1); + cudaMemcpy(h_offsets.data(), grp_offsets, (n_groups + 1) * sizeof(int), + cudaMemcpyDeviceToHost); + auto t1 = make_tier1_config(h_offsets.data(), n_groups); + int max_grp_size = t1.max_grp_size; + bool use_tier1 = t1.any_above_t2 && t1.use_tier1; + bool needs_tier3 = t1.any_above_t2 && !use_tier1; + int padded_grp_size = t1.padded_grp_size; + int tier1_tpb = t1.tier1_tpb; + size_t tier1_smem = t1.tier1_smem; + std::vector h_sort_group_ids; + int n_sort_groups = n_groups; + if (needs_tier3) { + h_sort_group_ids = make_sort_group_ids(h_offsets.data(), n_groups, + TIER2_GROUP_THRESHOLD); + n_sort_groups = (int)h_sort_group_ids.size(); + } + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) + n_streams = (n_cols 
+ sub_batch_cols - 1) / sub_batch_cols; + + size_t sub_ref_items = (size_t)n_ref * sub_batch_cols; + size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols; + + // CUB temp for ref sort (always needed) + grp sort (Tier 3 only) + size_t cub_ref_bytes = 0; + { + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, cub_ref_bytes, fk, fk, (int)sub_ref_items, sub_batch_cols, + doff, doff + 1, BEGIN_BIT, END_BIT); + } + size_t cub_temp_bytes = cub_ref_bytes; + + if (needs_tier3) { + size_t cub_grp_bytes = 0; + int max_grp_seg = n_sort_groups * sub_batch_cols; + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, cub_grp_bytes, fk, fk, (int)sub_grp_items, max_grp_seg, + doff, doff + 1, BEGIN_BIT, END_BIT); + cub_temp_bytes = std::max(cub_ref_bytes, cub_grp_bytes); + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + // Allocate per-stream buffers via RMM pool + RmmPool pool; + int* d_sort_group_ids = nullptr; + if (needs_tier3) { + d_sort_group_ids = pool.alloc(h_sort_group_ids.size()); + cudaMemcpy(d_sort_group_ids, h_sort_group_ids.data(), + h_sort_group_ids.size() * sizeof(int), + cudaMemcpyHostToDevice); + } + struct StreamBuf { + float* ref_dense; + float* ref_sorted; + float* grp_dense; + float* grp_sorted; + int* ref_seg_offsets; + int* grp_seg_offsets; + int* grp_seg_ends; + uint8_t* cub_temp; + double* ref_tie_sums; + double* sub_rank_sums; + double* sub_tie_corr; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].ref_dense = pool.alloc(sub_ref_items); + bufs[s].ref_sorted = pool.alloc(sub_ref_items); + bufs[s].grp_dense = pool.alloc(sub_grp_items); + bufs[s].ref_seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].ref_tie_sums = (t1.any_tier2 && compute_tie_corr) + ? 
pool.alloc(sub_batch_cols) + : nullptr; + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = + pool.alloc((size_t)n_groups * sub_batch_cols); + if (needs_tier3) { + bufs[s].grp_sorted = pool.alloc(sub_grp_items); + int max_seg = n_sort_groups * sub_batch_cols; + bufs[s].grp_seg_offsets = pool.alloc(max_seg); + bufs[s].grp_seg_ends = pool.alloc(max_seg); + } else { + bufs[s].grp_sorted = nullptr; + bufs[s].grp_seg_offsets = nullptr; + bufs[s].grp_seg_ends = nullptr; + } + } + + int tpb_extract = round_up_to_warp(std::max(n_ref, n_all_grp)); + int tpb_rank = + round_up_to_warp(std::min(max_grp_size, MAX_THREADS_PER_BLOCK)); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int sb_ref_items_actual = n_ref * sb_cols; + int sb_grp_items_actual = n_all_grp * sb_cols; + int s = batch_idx % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + + // ---- Extract + sort ref (always CUB) ---- + cudaMemsetAsync(buf.ref_dense, 0, sb_ref_items_actual * sizeof(float), + stream); + { + int blk = (n_ref + tpb_extract - 1) / tpb_extract; + csr_extract_dense_kernel<<>>( + csr_data, csr_indices, csr_indptr, ref_row_ids, buf.ref_dense, + n_ref, col, col + sb_cols); + CUDA_CHECK_LAST_ERROR(csr_extract_dense_kernel); + } + upload_linear_offsets(buf.ref_seg_offsets, sb_cols, n_ref, stream); + { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.cub_temp, temp, buf.ref_dense, buf.ref_sorted, + sb_ref_items_actual, sb_cols, buf.ref_seg_offsets, + buf.ref_seg_offsets + 1, BEGIN_BIT, END_BIT, stream); + } + + // ---- Extract grp rows ---- + cudaMemsetAsync(buf.grp_dense, 0, sb_grp_items_actual * sizeof(float), + stream); + { + int blk = (n_all_grp + tpb_extract - 1) / tpb_extract; + csr_extract_dense_kernel<<>>( + csr_data, csr_indices, csr_indptr, grp_row_ids, buf.grp_dense, + n_all_grp, col, col + sb_cols); + 
CUDA_CHECK_LAST_ERROR(csr_extract_dense_kernel); + } + + // Tier 0 handles groups ≤ TIER0_GROUP_THRESHOLD; Tier 1/3 handle + // the rest. Since each group owns its own rank_sums / tie_corr + // row, the two kernels' writes interlace without conflict. + int skip_le = 0; + if (t1.use_tier0) { + launch_tier0(buf.ref_sorted, buf.grp_dense, grp_offsets, + buf.sub_rank_sums, buf.sub_tie_corr, n_ref, n_all_grp, + sb_cols, n_groups, compute_tie_corr, stream); + if (t1.any_above_t0) skip_le = TIER0_GROUP_THRESHOLD; + } + if (t1.any_tier2) { + if (compute_tie_corr) { + launch_ref_tie_sums(buf.ref_sorted, buf.ref_tie_sums, n_ref, + sb_cols, stream); + } + launch_tier2_medium( + buf.ref_sorted, buf.grp_dense, grp_offsets, buf.ref_tie_sums, + buf.sub_rank_sums, buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, + n_groups, compute_tie_corr, TIER0_GROUP_THRESHOLD, stream); + } + + int upper_skip_le = t1.any_above_t2 ? TIER2_GROUP_THRESHOLD : skip_le; + if (t1.any_above_t2 && use_tier1) { + // ---- Tier 1: fused smem sort + binary search ---- + dim3 grid(sb_cols, n_groups); + ovo_fused_sort_rank_kernel<<>>( + buf.ref_sorted, buf.grp_dense, grp_offsets, buf.sub_rank_sums, + buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, + compute_tie_corr, padded_grp_size, upper_skip_le); + CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); + } else if (needs_tier3) { + // ---- Tier 3: CUB segmented sort + binary search ---- + int sb_grp_seg = n_sort_groups * sb_cols; + { + int blk = (sb_grp_seg + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + build_tier3_seg_begin_end_offsets_kernel<<>>( + grp_offsets, d_sort_group_ids, buf.grp_seg_offsets, + buf.grp_seg_ends, n_all_grp, n_sort_groups, sb_cols); + CUDA_CHECK_LAST_ERROR(build_tier3_seg_begin_end_offsets_kernel); + } + { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.cub_temp, temp, buf.grp_dense, buf.grp_sorted, + sb_grp_items_actual, sb_grp_seg, buf.grp_seg_offsets, + buf.grp_seg_ends, BEGIN_BIT, END_BIT, stream); + 
} + { + dim3 grid(sb_cols, n_groups); + batched_rank_sums_presorted_kernel<<>>( + buf.ref_sorted, buf.grp_sorted, grp_offsets, + buf.sub_rank_sums, buf.sub_tie_corr, n_ref, n_all_grp, + sb_cols, n_groups, compute_tie_corr, upper_skip_le); + CUDA_CHECK_LAST_ERROR(batched_rank_sums_presorted_kernel); + } + } + + // ---- Scatter to global output ---- + cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpy2DAsync(tie_corr + col, n_cols * sizeof(double), + buf.sub_tie_corr, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + + col += sb_cols; + batch_idx++; + } + + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in wilcoxon streaming: ") + + cudaGetErrorString(err)); + } + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} + +/** + * CSC-direct OVO streaming pipeline. + * + * Like the CSR variant but extracts rows via a row-lookup map, avoiding + * CSC→CSR conversion. row_map_ref[row] = output index in ref block (-1 if + * not a ref row); likewise for row_map_grp. 
+ */ +static void ovo_streaming_csc_impl( + const float* csc_data, const int* csc_indices, const int* csc_indptr, + const int* ref_row_map, const int* grp_row_map, const int* grp_offsets, + double* rank_sums, double* tie_corr, int n_ref, int n_all_grp, int n_cols, + int n_groups, bool compute_tie_corr, int sub_batch_cols) { + if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; + + // ---- Tier dispatch ---- + std::vector h_offsets(n_groups + 1); + cudaMemcpy(h_offsets.data(), grp_offsets, (n_groups + 1) * sizeof(int), + cudaMemcpyDeviceToHost); + auto t1 = make_tier1_config(h_offsets.data(), n_groups); + int max_grp_size = t1.max_grp_size; + bool use_tier1 = t1.any_above_t2 && t1.use_tier1; + bool needs_tier3 = t1.any_above_t2 && !use_tier1; + int padded_grp_size = t1.padded_grp_size; + int tier1_tpb = t1.tier1_tpb; + size_t tier1_smem = t1.tier1_smem; + std::vector h_sort_group_ids; + int n_sort_groups = n_groups; + if (needs_tier3) { + h_sort_group_ids = make_sort_group_ids(h_offsets.data(), n_groups, + TIER2_GROUP_THRESHOLD); + n_sort_groups = (int)h_sort_group_ids.size(); + } + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) + n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + + size_t sub_ref_items = (size_t)n_ref * sub_batch_cols; + size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols; + + // CUB temp + size_t cub_ref_bytes = 0; + { + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, cub_ref_bytes, fk, fk, (int)sub_ref_items, sub_batch_cols, + doff, doff + 1, BEGIN_BIT, END_BIT); + } + size_t cub_temp_bytes = cub_ref_bytes; + if (needs_tier3) { + size_t cub_grp_bytes = 0; + int max_grp_seg = n_sort_groups * sub_batch_cols; + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, cub_grp_bytes, fk, fk, (int)sub_grp_items, max_grp_seg, + doff, doff + 1, BEGIN_BIT, END_BIT); + 
cub_temp_bytes = std::max(cub_ref_bytes, cub_grp_bytes); + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + RmmPool pool; + int* d_sort_group_ids = nullptr; + if (needs_tier3) { + d_sort_group_ids = pool.alloc(h_sort_group_ids.size()); + cudaMemcpy(d_sort_group_ids, h_sort_group_ids.data(), + h_sort_group_ids.size() * sizeof(int), + cudaMemcpyHostToDevice); + } + struct StreamBuf { + float* ref_dense; + float* ref_sorted; + float* grp_dense; + float* grp_sorted; + int* ref_seg_offsets; + int* grp_seg_offsets; + int* grp_seg_ends; + uint8_t* cub_temp; + double* ref_tie_sums; + double* sub_rank_sums; + double* sub_tie_corr; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].ref_dense = pool.alloc(sub_ref_items); + bufs[s].ref_sorted = pool.alloc(sub_ref_items); + bufs[s].grp_dense = pool.alloc(sub_grp_items); + bufs[s].ref_seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].ref_tie_sums = (t1.any_tier2 && compute_tie_corr) + ? 
pool.alloc(sub_batch_cols) + : nullptr; + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = + pool.alloc((size_t)n_groups * sub_batch_cols); + if (needs_tier3) { + bufs[s].grp_sorted = pool.alloc(sub_grp_items); + int max_grp_seg = n_sort_groups * sub_batch_cols; + bufs[s].grp_seg_offsets = pool.alloc(max_grp_seg); + bufs[s].grp_seg_ends = pool.alloc(max_grp_seg); + } else { + bufs[s].grp_sorted = nullptr; + bufs[s].grp_seg_offsets = nullptr; + bufs[s].grp_seg_ends = nullptr; + } + } + + int tpb_rank = + round_up_to_warp(std::min(max_grp_size, MAX_THREADS_PER_BLOCK)); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int sb_ref_items_actual = n_ref * sb_cols; + int sb_grp_items_actual = n_all_grp * sb_cols; + int s = batch_idx % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + + // ---- Extract ref from CSC via row_map, then sort ---- + cudaMemsetAsync(buf.ref_dense, 0, sb_ref_items_actual * sizeof(float), + stream); + csc_extract_mapped_kernel<<>>( + csc_data, csc_indices, csc_indptr, ref_row_map, buf.ref_dense, + n_ref, col); + CUDA_CHECK_LAST_ERROR(csc_extract_mapped_kernel); + upload_linear_offsets(buf.ref_seg_offsets, sb_cols, n_ref, stream); + { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.cub_temp, temp, buf.ref_dense, buf.ref_sorted, + sb_ref_items_actual, sb_cols, buf.ref_seg_offsets, + buf.ref_seg_offsets + 1, BEGIN_BIT, END_BIT, stream); + } + + // ---- Extract grp from CSC via row_map ---- + cudaMemsetAsync(buf.grp_dense, 0, sb_grp_items_actual * sizeof(float), + stream); + csc_extract_mapped_kernel<<>>( + csc_data, csc_indices, csc_indptr, grp_row_map, buf.grp_dense, + n_all_grp, col); + CUDA_CHECK_LAST_ERROR(csc_extract_mapped_kernel); + + int skip_le = 0; + if (t1.use_tier0) { + launch_tier0(buf.ref_sorted, buf.grp_dense, grp_offsets, + buf.sub_rank_sums, buf.sub_tie_corr, n_ref, 
n_all_grp, + sb_cols, n_groups, compute_tie_corr, stream); + if (t1.any_above_t0) skip_le = TIER0_GROUP_THRESHOLD; + } + if (t1.any_tier2) { + if (compute_tie_corr) { + launch_ref_tie_sums(buf.ref_sorted, buf.ref_tie_sums, n_ref, + sb_cols, stream); + } + launch_tier2_medium( + buf.ref_sorted, buf.grp_dense, grp_offsets, buf.ref_tie_sums, + buf.sub_rank_sums, buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, + n_groups, compute_tie_corr, TIER0_GROUP_THRESHOLD, stream); + } + + int upper_skip_le = t1.any_above_t2 ? TIER2_GROUP_THRESHOLD : skip_le; + if (t1.any_above_t2 && use_tier1) { + dim3 grid(sb_cols, n_groups); + ovo_fused_sort_rank_kernel<<>>( + buf.ref_sorted, buf.grp_dense, grp_offsets, buf.sub_rank_sums, + buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, + compute_tie_corr, padded_grp_size, upper_skip_le); + CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); + } else if (needs_tier3) { + int sb_grp_seg = n_sort_groups * sb_cols; + { + int blk = (sb_grp_seg + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + build_tier3_seg_begin_end_offsets_kernel<<>>( + grp_offsets, d_sort_group_ids, buf.grp_seg_offsets, + buf.grp_seg_ends, n_all_grp, n_sort_groups, sb_cols); + CUDA_CHECK_LAST_ERROR(build_tier3_seg_begin_end_offsets_kernel); + } + { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.cub_temp, temp, buf.grp_dense, buf.grp_sorted, + sb_grp_items_actual, sb_grp_seg, buf.grp_seg_offsets, + buf.grp_seg_ends, BEGIN_BIT, END_BIT, stream); + } + { + dim3 grid(sb_cols, n_groups); + batched_rank_sums_presorted_kernel<<>>( + buf.ref_sorted, buf.grp_sorted, grp_offsets, + buf.sub_rank_sums, buf.sub_tie_corr, n_ref, n_all_grp, + sb_cols, n_groups, compute_tie_corr, upper_skip_le); + CUDA_CHECK_LAST_ERROR(batched_rank_sums_presorted_kernel); + } + } + + // ---- Scatter to global output ---- + cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + 
+ cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpy2DAsync(tie_corr + col, n_cols * sizeof(double), + buf.sub_tie_corr, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + + col += sb_cols; + batch_idx++; + } + + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in wilcoxon streaming: ") + + cudaGetErrorString(err)); + } + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_dense.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_dense.cuh new file mode 100644 index 00000000..37d23ab4 --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_dense.cuh @@ -0,0 +1,312 @@ +#pragma once + +/** + * Gather specific rows from a dense F-order block into a smaller dense block. + * Grid: (n_cols,), Block: 256. + * row_ids[i] = original row index → output row i. + */ +__global__ void dense_gather_rows_kernel(const float* __restrict__ in, + const int* __restrict__ row_ids, + float* __restrict__ out, int n_rows_in, + int n_target, int n_cols) { + int col = blockIdx.x; + if (col >= n_cols) return; + const float* in_col = in + (long long)col * n_rows_in; + float* out_col = out + (long long)col * n_target; + for (int i = threadIdx.x; i < n_target; i += blockDim.x) { + out_col[i] = in_col[row_ids[i]]; + } +} + +/** + * Host-streaming dense OVO pipeline. + * + * Dense F-order input (native dtype) lives on host. Sub-batches of columns + * are H2D transferred and cast to float32; ref/grp rows are then gathered, + * sorted, and ranked. Results are scattered D2D into the caller's buffers.
+ */ +template +static void ovo_streaming_dense_host_impl( + const InT* h_block, const int* h_ref_row_ids, const int* h_grp_row_ids, + const int* h_grp_offsets, const int* h_stats_codes, double* d_rank_sums, + double* d_tie_corr, double* d_group_sums, double* d_group_sq_sums, + double* d_group_nnz, int n_ref, int n_all_grp, int n_rows, int n_cols, + int n_groups, int n_groups_stats, bool compute_tie_corr, + bool compute_sq_sums, bool compute_nnz, int sub_batch_cols) { + if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; + + // ---- Tier dispatch from host offsets ---- + auto t1 = make_tier1_config(h_grp_offsets, n_groups); + int max_grp_size = t1.max_grp_size; + bool use_tier1 = t1.any_above_t2 && t1.use_tier1; + bool needs_tier3 = t1.any_above_t2 && !use_tier1; + int padded_grp_size = t1.padded_grp_size; + int tier1_tpb = t1.tier1_tpb; + size_t tier1_smem = t1.tier1_smem; + std::vector h_sort_group_ids; + int n_sort_groups = n_groups; + if (needs_tier3) { + h_sort_group_ids = + make_sort_group_ids(h_grp_offsets, n_groups, TIER2_GROUP_THRESHOLD); + n_sort_groups = (int)h_sort_group_ids.size(); + } + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) + n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + + size_t sub_dense = (size_t)n_rows * sub_batch_cols; + size_t sub_ref_items = (size_t)n_ref * sub_batch_cols; + size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols; + + // CUB temp + size_t cub_ref_bytes = 0; + { + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, cub_ref_bytes, fk, fk, (int)sub_ref_items, sub_batch_cols, + doff, doff + 1, BEGIN_BIT, END_BIT); + } + size_t cub_temp_bytes = cub_ref_bytes; + if (needs_tier3) { + size_t cub_grp_bytes = 0; + int max_grp_seg = n_sort_groups * sub_batch_cols; + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, cub_grp_bytes, fk, fk, 
(int)sub_grp_items, max_grp_seg, + doff, doff + 1, BEGIN_BIT, END_BIT); + cub_temp_bytes = std::max(cub_ref_bytes, cub_grp_bytes); + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + RmmPool pool; + + // GPU copies of row_ids + group offsets + stats codes (uploaded once) + int* d_ref_row_ids = pool.alloc(n_ref); + int* d_grp_row_ids = pool.alloc(n_all_grp); + int* d_grp_offsets = pool.alloc(n_groups + 1); + int* d_stats_codes = pool.alloc(n_rows); + int* d_sort_group_ids = nullptr; + cudaMemcpy(d_ref_row_ids, h_ref_row_ids, n_ref * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_grp_row_ids, h_grp_row_ids, n_all_grp * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_grp_offsets, h_grp_offsets, (n_groups + 1) * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_stats_codes, h_stats_codes, n_rows * sizeof(int), + cudaMemcpyHostToDevice); + if (needs_tier3) { + d_sort_group_ids = pool.alloc(h_sort_group_ids.size()); + cudaMemcpy(d_sort_group_ids, h_sort_group_ids.data(), + h_sort_group_ids.size() * sizeof(int), + cudaMemcpyHostToDevice); + } + + struct StreamBuf { + InT* d_block_orig; + float* d_block_f32; + float* ref_dense; + float* ref_sorted; + float* grp_dense; + float* grp_sorted; + int* ref_seg_offsets; + int* grp_seg_offsets; + int* grp_seg_ends; + uint8_t* cub_temp; + double* ref_tie_sums; + double* d_rank_sums; + double* d_tie_corr; + double* d_group_sums; + double* d_group_sq_sums; + double* d_group_nnz; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].d_block_orig = pool.alloc(sub_dense); + bufs[s].d_block_f32 = pool.alloc(sub_dense); + bufs[s].ref_dense = pool.alloc(sub_ref_items); + bufs[s].ref_sorted = pool.alloc(sub_ref_items); + bufs[s].grp_dense = pool.alloc(sub_grp_items); + bufs[s].ref_seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].ref_tie_sums = (t1.any_tier2 && 
compute_tie_corr) + ? pool.alloc(sub_batch_cols) + : nullptr; + bufs[s].d_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].d_tie_corr = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].d_group_sums = + pool.alloc((size_t)n_groups_stats * sub_batch_cols); + bufs[s].d_group_sq_sums = pool.alloc( + compute_sq_sums ? (size_t)n_groups_stats * sub_batch_cols : 1); + bufs[s].d_group_nnz = pool.alloc( + compute_nnz ? (size_t)n_groups_stats * sub_batch_cols : 1); + if (needs_tier3) { + bufs[s].grp_sorted = pool.alloc(sub_grp_items); + int max_grp_seg = n_sort_groups * sub_batch_cols; + bufs[s].grp_seg_offsets = pool.alloc(max_grp_seg); + bufs[s].grp_seg_ends = pool.alloc(max_grp_seg); + } else { + bufs[s].grp_sorted = nullptr; + bufs[s].grp_seg_offsets = nullptr; + bufs[s].grp_seg_ends = nullptr; + } + } + + int tpb_rank = + round_up_to_warp(std::min(max_grp_size, MAX_THREADS_PER_BLOCK)); + size_t smem_cast = (size_t)(3 * n_groups_stats) * sizeof(double); + + // Pin only the host input; outputs live on the device. 
+ HostRegisterGuard _pin_block(const_cast(h_block), + (size_t)n_rows * n_cols * sizeof(InT)); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int sb_dense = n_rows * sb_cols; + int sb_ref_actual = n_ref * sb_cols; + int sb_grp_actual = n_all_grp * sb_cols; + int s = batch_idx % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + + // ---- H2D: dense column sub-batch (F-order, native dtype) ---- + cudaMemcpyAsync(buf.d_block_orig, h_block + (long long)col * n_rows, + sb_dense * sizeof(InT), cudaMemcpyHostToDevice, stream); + + // ---- Cast to float32 for sort + accumulate stats in float64 ---- + ovr_cast_and_accumulate_dense_kernel + <<>>( + buf.d_block_orig, buf.d_block_f32, d_stats_codes, + buf.d_group_sums, buf.d_group_sq_sums, buf.d_group_nnz, n_rows, + sb_cols, n_groups_stats, compute_sq_sums, compute_nnz); + CUDA_CHECK_LAST_ERROR(ovr_cast_and_accumulate_dense_kernel); + + // ---- Gather ref rows, sort ---- + dense_gather_rows_kernel<<>>( + buf.d_block_f32, d_ref_row_ids, buf.ref_dense, n_rows, n_ref, + sb_cols); + CUDA_CHECK_LAST_ERROR(dense_gather_rows_kernel); + upload_linear_offsets(buf.ref_seg_offsets, sb_cols, n_ref, stream); + { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.cub_temp, temp, buf.ref_dense, buf.ref_sorted, + sb_ref_actual, sb_cols, buf.ref_seg_offsets, + buf.ref_seg_offsets + 1, BEGIN_BIT, END_BIT, stream); + } + + // ---- Gather grp rows ---- + dense_gather_rows_kernel<<>>( + buf.d_block_f32, d_grp_row_ids, buf.grp_dense, n_rows, n_all_grp, + sb_cols); + CUDA_CHECK_LAST_ERROR(dense_gather_rows_kernel); + + // ---- Tier dispatch: sort grp + rank ---- + int skip_le = 0; + if (t1.use_tier0) { + launch_tier0(buf.ref_sorted, buf.grp_dense, d_grp_offsets, + buf.d_rank_sums, buf.d_tie_corr, n_ref, n_all_grp, + sb_cols, n_groups, compute_tie_corr, stream); + if (t1.any_above_t0) skip_le = TIER0_GROUP_THRESHOLD; + } + if 
(t1.any_tier2) { + if (compute_tie_corr) { + launch_ref_tie_sums(buf.ref_sorted, buf.ref_tie_sums, n_ref, + sb_cols, stream); + } + launch_tier2_medium( + buf.ref_sorted, buf.grp_dense, d_grp_offsets, buf.ref_tie_sums, + buf.d_rank_sums, buf.d_tie_corr, n_ref, n_all_grp, sb_cols, + n_groups, compute_tie_corr, TIER0_GROUP_THRESHOLD, stream); + } + + int upper_skip_le = t1.any_above_t2 ? TIER2_GROUP_THRESHOLD : skip_le; + if (t1.any_above_t2 && use_tier1) { + dim3 grid(sb_cols, n_groups); + ovo_fused_sort_rank_kernel<<>>( + buf.ref_sorted, buf.grp_dense, d_grp_offsets, buf.d_rank_sums, + buf.d_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, + compute_tie_corr, padded_grp_size, upper_skip_le); + CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); + } else if (needs_tier3) { + int sb_grp_seg = n_sort_groups * sb_cols; + { + int blk = (sb_grp_seg + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + build_tier3_seg_begin_end_offsets_kernel<<>>( + d_grp_offsets, d_sort_group_ids, buf.grp_seg_offsets, + buf.grp_seg_ends, n_all_grp, n_sort_groups, sb_cols); + CUDA_CHECK_LAST_ERROR(build_tier3_seg_begin_end_offsets_kernel); + } + { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.cub_temp, temp, buf.grp_dense, buf.grp_sorted, + sb_grp_actual, sb_grp_seg, buf.grp_seg_offsets, + buf.grp_seg_ends, BEGIN_BIT, END_BIT, stream); + } + { + dim3 grid(sb_cols, n_groups); + batched_rank_sums_presorted_kernel<<>>( + buf.ref_sorted, buf.grp_sorted, d_grp_offsets, + buf.d_rank_sums, buf.d_tie_corr, n_ref, n_all_grp, sb_cols, + n_groups, compute_tie_corr, upper_skip_le); + CUDA_CHECK_LAST_ERROR(batched_rank_sums_presorted_kernel); + } + } + + // ---- D2D: scatter sub-batch results into caller's GPU buffers ---- + cudaMemcpy2DAsync(d_rank_sums + col, n_cols * sizeof(double), + buf.d_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpy2DAsync(d_tie_corr + col, n_cols * 
sizeof(double), + buf.d_tie_corr, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + cudaMemcpy2DAsync(d_group_sums + col, n_cols * sizeof(double), + buf.d_group_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups_stats, + cudaMemcpyDeviceToDevice, stream); + if (compute_sq_sums) { + cudaMemcpy2DAsync(d_group_sq_sums + col, n_cols * sizeof(double), + buf.d_group_sq_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups_stats, + cudaMemcpyDeviceToDevice, stream); + } + if (compute_nnz) { + cudaMemcpy2DAsync(d_group_nnz + col, n_cols * sizeof(double), + buf.d_group_nnz, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups_stats, + cudaMemcpyDeviceToDevice, stream); + } + + col += sb_cols; + batch_idx++; + } + + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in wilcoxon streaming: ") + + cudaGetErrorString(err)); + } + + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} + +// ============================================================================ +// Nanobind module +// ============================================================================ diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh new file mode 100644 index 00000000..f2b4e7ae --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh @@ -0,0 +1,837 @@ +#pragma once + +/** + * Host-streaming CSC OVO pipeline. + * + * CSC arrays live on host. Only the sparse data for each sub-batch of + * columns is transferred to GPU. Row maps + group offsets are uploaded once. + * Results are scattered D2D into the caller's GPU buffers per sub-batch.
+ */ +template +static void ovo_streaming_csc_host_impl( + const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, + const int* h_ref_row_map, const int* h_grp_row_map, + const int* h_grp_offsets, const int* h_stats_codes, double* d_rank_sums, + double* d_tie_corr, double* d_group_sums, double* d_group_sq_sums, + double* d_group_nnz, int n_ref, int n_all_grp, int n_rows, int n_cols, + int n_groups, int n_groups_stats, bool compute_tie_corr, + bool compute_sq_sums, bool compute_nnz, int sub_batch_cols) { + if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; + + // ---- Tier dispatch from host offsets ---- + auto t1 = make_tier1_config(h_grp_offsets, n_groups); + int max_grp_size = t1.max_grp_size; + bool use_tier1 = t1.any_above_t2 && t1.use_tier1; + bool needs_tier3 = t1.any_above_t2 && !use_tier1; + int padded_grp_size = t1.padded_grp_size; + int tier1_tpb = t1.tier1_tpb; + size_t tier1_smem = t1.tier1_smem; + std::vector h_sort_group_ids; + int n_sort_groups = n_groups; + if (needs_tier3) { + h_sort_group_ids = + make_sort_group_ids(h_grp_offsets, n_groups, TIER2_GROUP_THRESHOLD); + n_sort_groups = (int)h_sort_group_ids.size(); + } + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) + n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + + size_t sub_ref_items = (size_t)n_ref * sub_batch_cols; + size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols; + + // CUB temp + size_t cub_ref_bytes = 0; + { + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, cub_ref_bytes, fk, fk, (int)sub_ref_items, sub_batch_cols, + doff, doff + 1, BEGIN_BIT, END_BIT); + } + size_t cub_temp_bytes = cub_ref_bytes; + if (needs_tier3) { + size_t cub_grp_bytes = 0; + int max_grp_seg = n_sort_groups * sub_batch_cols; + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, cub_grp_bytes, fk, fk, 
(int)sub_grp_items, max_grp_seg, + doff, doff + 1, BEGIN_BIT, END_BIT); + cub_temp_bytes = std::max(cub_ref_bytes, cub_grp_bytes); + } + + // Max nnz across any sub-batch for sparse transfer buffer sizing + size_t max_nnz = 0; + for (int c = 0; c < n_cols; c += sub_batch_cols) { + int sb = std::min(sub_batch_cols, n_cols - c); + size_t nnz = (size_t)(h_indptr[c + sb] - h_indptr[c]); + if (nnz > max_nnz) max_nnz = nnz; + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + RmmPool pool; + + // GPU copies of row maps + group offsets + stats codes (uploaded once) + int* d_ref_row_map = pool.alloc(n_rows); + int* d_grp_row_map = pool.alloc(n_rows); + int* d_grp_offsets = pool.alloc(n_groups + 1); + int* d_stats_codes = pool.alloc(n_rows); + int* d_sort_group_ids = nullptr; + cudaMemcpy(d_ref_row_map, h_ref_row_map, n_rows * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_grp_row_map, h_grp_row_map, n_rows * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_grp_offsets, h_grp_offsets, (n_groups + 1) * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_stats_codes, h_stats_codes, n_rows * sizeof(int), + cudaMemcpyHostToDevice); + if (needs_tier3) { + d_sort_group_ids = pool.alloc(h_sort_group_ids.size()); + cudaMemcpy(d_sort_group_ids, h_sort_group_ids.data(), + h_sort_group_ids.size() * sizeof(int), + cudaMemcpyHostToDevice); + } + + struct StreamBuf { + InT* d_sparse_data_orig; + float* d_sparse_data_f32; + IndexT* d_sparse_indices; + int* d_indptr; + float* ref_dense; + float* ref_sorted; + float* grp_dense; + float* grp_sorted; + int* ref_seg_offsets; + int* grp_seg_offsets; + int* grp_seg_ends; + uint8_t* cub_temp; + double* ref_tie_sums; + double* d_rank_sums; + double* d_tie_corr; + double* d_group_sums; + double* d_group_sq_sums; + double* d_group_nnz; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].d_sparse_data_orig = pool.alloc(max_nnz); + 
bufs[s].d_sparse_data_f32 = pool.alloc(max_nnz); + bufs[s].d_sparse_indices = pool.alloc(max_nnz); + bufs[s].d_indptr = pool.alloc(sub_batch_cols + 1); + bufs[s].ref_dense = pool.alloc(sub_ref_items); + bufs[s].ref_sorted = pool.alloc(sub_ref_items); + bufs[s].grp_dense = pool.alloc(sub_grp_items); + bufs[s].ref_seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].ref_tie_sums = (t1.any_tier2 && compute_tie_corr) + ? pool.alloc(sub_batch_cols) + : nullptr; + bufs[s].d_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].d_tie_corr = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].d_group_sums = + pool.alloc((size_t)n_groups_stats * sub_batch_cols); + bufs[s].d_group_sq_sums = pool.alloc( + compute_sq_sums ? (size_t)n_groups_stats * sub_batch_cols : 1); + bufs[s].d_group_nnz = pool.alloc( + compute_nnz ? (size_t)n_groups_stats * sub_batch_cols : 1); + if (needs_tier3) { + bufs[s].grp_sorted = pool.alloc(sub_grp_items); + int max_grp_seg = n_sort_groups * sub_batch_cols; + bufs[s].grp_seg_offsets = pool.alloc(max_grp_seg); + bufs[s].grp_seg_ends = pool.alloc(max_grp_seg); + } else { + bufs[s].grp_sorted = nullptr; + bufs[s].grp_seg_offsets = nullptr; + bufs[s].grp_seg_ends = nullptr; + } + } + + int tpb_rank = + round_up_to_warp(std::min(max_grp_size, MAX_THREADS_PER_BLOCK)); + size_t smem_cast = (size_t)(3 * n_groups_stats) * sizeof(double); + + // Pin only the sparse input arrays; outputs live on the device. 
+ size_t total_nnz = (size_t)h_indptr[n_cols]; + HostRegisterGuard _pin_data(const_cast(h_data), + total_nnz * sizeof(InT)); + HostRegisterGuard _pin_indices(const_cast(h_indices), + total_nnz * sizeof(IndexT)); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int sb_ref_actual = n_ref * sb_cols; + int sb_grp_actual = n_all_grp * sb_cols; + int s = batch_idx % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + + // ---- H2D: sparse data for this column range (native dtype) ---- + IndptrT ptr_start = h_indptr[col]; + IndptrT ptr_end = h_indptr[col + sb_cols]; + size_t nnz = (size_t)(ptr_end - ptr_start); + cudaMemcpyAsync(buf.d_sparse_data_orig, h_data + ptr_start, + nnz * sizeof(InT), cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(buf.d_sparse_indices, h_indices + ptr_start, + nnz * sizeof(IndexT), cudaMemcpyHostToDevice, stream); + { + std::vector h_adj(sb_cols + 1); + for (int i = 0; i <= sb_cols; i++) + h_adj[i] = (int)(h_indptr[col + i] - ptr_start); + cudaMemcpy(buf.d_indptr, h_adj.data(), (sb_cols + 1) * sizeof(int), + cudaMemcpyHostToDevice); + } + + // ---- Cast to float32 for sort + accumulate stats in float64 ---- + ovr_cast_and_accumulate_sparse_kernel + <<>>( + buf.d_sparse_data_orig, buf.d_sparse_data_f32, + buf.d_sparse_indices, buf.d_indptr, d_stats_codes, + buf.d_group_sums, buf.d_group_sq_sums, buf.d_group_nnz, sb_cols, + n_groups_stats, compute_sq_sums, compute_nnz); + CUDA_CHECK_LAST_ERROR(ovr_cast_and_accumulate_sparse_kernel); + + // ---- Extract ref from CSC via row_map, sort ---- + cudaMemsetAsync(buf.ref_dense, 0, sb_ref_actual * sizeof(float), + stream); + csc_extract_mapped_kernel<<>>( + buf.d_sparse_data_f32, buf.d_sparse_indices, buf.d_indptr, + d_ref_row_map, buf.ref_dense, n_ref, 0); + CUDA_CHECK_LAST_ERROR(csc_extract_mapped_kernel); + upload_linear_offsets(buf.ref_seg_offsets, sb_cols, n_ref, stream); + { + size_t temp = cub_temp_bytes; + 
cub::DeviceSegmentedRadixSort::SortKeys( + buf.cub_temp, temp, buf.ref_dense, buf.ref_sorted, + sb_ref_actual, sb_cols, buf.ref_seg_offsets, + buf.ref_seg_offsets + 1, BEGIN_BIT, END_BIT, stream); + } + + // ---- Extract grp from CSC via row_map ---- + cudaMemsetAsync(buf.grp_dense, 0, sb_grp_actual * sizeof(float), + stream); + csc_extract_mapped_kernel<<>>( + buf.d_sparse_data_f32, buf.d_sparse_indices, buf.d_indptr, + d_grp_row_map, buf.grp_dense, n_all_grp, 0); + CUDA_CHECK_LAST_ERROR(csc_extract_mapped_kernel); + + // ---- Tier dispatch: sort grp + rank ---- + int skip_le = 0; + if (t1.use_tier0) { + launch_tier0(buf.ref_sorted, buf.grp_dense, d_grp_offsets, + buf.d_rank_sums, buf.d_tie_corr, n_ref, n_all_grp, + sb_cols, n_groups, compute_tie_corr, stream); + if (t1.any_above_t0) skip_le = TIER0_GROUP_THRESHOLD; + } + if (t1.any_tier2) { + if (compute_tie_corr) { + launch_ref_tie_sums(buf.ref_sorted, buf.ref_tie_sums, n_ref, + sb_cols, stream); + } + launch_tier2_medium( + buf.ref_sorted, buf.grp_dense, d_grp_offsets, buf.ref_tie_sums, + buf.d_rank_sums, buf.d_tie_corr, n_ref, n_all_grp, sb_cols, + n_groups, compute_tie_corr, TIER0_GROUP_THRESHOLD, stream); + } + + int upper_skip_le = t1.any_above_t2 ? 
TIER2_GROUP_THRESHOLD : skip_le; + if (t1.any_above_t2 && use_tier1) { + dim3 grid(sb_cols, n_groups); + ovo_fused_sort_rank_kernel<<>>( + buf.ref_sorted, buf.grp_dense, d_grp_offsets, buf.d_rank_sums, + buf.d_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, + compute_tie_corr, padded_grp_size, upper_skip_le); + CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); + } else if (needs_tier3) { + int sb_grp_seg = n_sort_groups * sb_cols; + { + int blk = (sb_grp_seg + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + build_tier3_seg_begin_end_offsets_kernel<<>>( + d_grp_offsets, d_sort_group_ids, buf.grp_seg_offsets, + buf.grp_seg_ends, n_all_grp, n_sort_groups, sb_cols); + CUDA_CHECK_LAST_ERROR(build_tier3_seg_begin_end_offsets_kernel); + } + { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.cub_temp, temp, buf.grp_dense, buf.grp_sorted, + sb_grp_actual, sb_grp_seg, buf.grp_seg_offsets, + buf.grp_seg_ends, BEGIN_BIT, END_BIT, stream); + } + { + dim3 grid(sb_cols, n_groups); + batched_rank_sums_presorted_kernel<<>>( + buf.ref_sorted, buf.grp_sorted, d_grp_offsets, + buf.d_rank_sums, buf.d_tie_corr, n_ref, n_all_grp, sb_cols, + n_groups, compute_tie_corr, upper_skip_le); + CUDA_CHECK_LAST_ERROR(batched_rank_sums_presorted_kernel); + } + } + + // ---- D2D: scatter sub-batch results into caller's GPU buffers ---- + cudaMemcpy2DAsync(d_rank_sums + col, n_cols * sizeof(double), + buf.d_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpy2DAsync(d_tie_corr + col, n_cols * sizeof(double), + buf.d_tie_corr, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + cudaMemcpy2DAsync(d_group_sums + col, n_cols * sizeof(double), + buf.d_group_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups_stats, + cudaMemcpyDeviceToDevice, stream); + if (compute_sq_sums) { + cudaMemcpy2DAsync(d_group_sq_sums + 
col, n_cols * sizeof(double), + buf.d_group_sq_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups_stats, + cudaMemcpyDeviceToDevice, stream); + } + if (compute_nnz) { + cudaMemcpy2DAsync(d_group_nnz + col, n_cols * sizeof(double), + buf.d_group_nnz, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups_stats, + cudaMemcpyDeviceToDevice, stream); + } + + col += sb_cols; + batch_idx++; + } + + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in wilcoxon streaming: ") + + cudaGetErrorString(err)); + } + + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} + +/** + * Host CSR OVO pipeline — zero-copy mapped full-CSR with GPU-side row gather. + * + * Setup: pin the full host CSR with cudaHostRegisterMapped, upload the full + * indptr (small) + row_ids + pre-computed compacted indptrs. Each pack + * gathers only its rows over PCIe via a UVA kernel — the full matrix is never + * transferred to GPU. + * + * Phase 1 (Ref): fused gather + cast + stats over ref rows; segmented sort + * to d_ref_sorted (cached for the whole run). + * Phase 2 (per pack, round-robin across N_STREAMS): + * 1. rebase per-pack output indptr from the pre-uploaded global compacted + * indptr. + * 2. rebase per-pack group offsets + build per-row stats codes. + * 3. csr_gather_cast_accumulate_mapped_kernel — one PCIe pass, writes + * compacted f32 data + indices and accumulates per-group stats. + * 4. Per sub-batch: extract dense → sort → rank vs ref_sorted → scatter. + * + * Memory: d_ref_sorted (n_ref × n_cols × 4B) + N_STREAMS pack buffers sized + * for max_pack_rows × sb_cols (dense) and max_pack_nnz (compacted CSR). + * Full CSR stays on host (pinned-mapped). 
+ */ +template +static void ovo_streaming_csr_host_impl( + const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, + int n_full_rows, const int* h_ref_row_ids, int n_ref, + const int* h_grp_row_ids, const int* h_grp_offsets, int n_all_grp, + int n_test, double* d_rank_sums, double* d_tie_corr, double* d_group_sums, + double* d_group_sq_sums, double* d_group_nnz, int n_cols, + int n_groups_stats, bool compute_tie_corr, bool compute_sq_sums, + bool compute_nnz, int sub_batch_cols) { + if (n_cols == 0 || n_ref == 0 || n_test == 0 || n_all_grp == 0) return; + + // ---- Pre-compute compacted indptrs on host (O(n_ref + n_all_grp)) ---- + // Use IndptrT for the global compacted indptr because the grp side can + // exceed 2^31 nnz on very large / dense matrices. Ref always fits in + // int32 since n_ref × n_cols ≪ 2B; keeping int32 there matches the + // downstream CUB segmented-sort temp sizing. + std::vector h_ref_indptr_compact(n_ref + 1); + h_ref_indptr_compact[0] = 0; + for (int i = 0; i < n_ref; i++) { + int r = h_ref_row_ids[i]; + int nnz_i = (int)(h_indptr[r + 1] - h_indptr[r]); + h_ref_indptr_compact[i + 1] = h_ref_indptr_compact[i] + nnz_i; + } + int ref_nnz = h_ref_indptr_compact[n_ref]; + + // grp: compacted indptr over concatenated test-group rows (IndptrT). 
+ std::vector h_grp_indptr_compact(n_all_grp + 1); + h_grp_indptr_compact[0] = 0; + for (int i = 0; i < n_all_grp; i++) { + int r = h_grp_row_ids[i]; + IndptrT nnz_i = h_indptr[r + 1] - h_indptr[r]; + h_grp_indptr_compact[i + 1] = h_grp_indptr_compact[i] + nnz_i; + } + + // ---- Build packs (same rule as grp_impl, but uses compacted indptr) ---- + struct Pack { + int first; + int end; + int n_rows; + size_t nnz; + int sb_cols; + }; + std::vector packs; + int max_pack_rows = 0; + size_t max_pack_nnz = 0; + int max_pack_K = 0; + int max_pack_items = 0; + int max_pack_sb_cols = sub_batch_cols; + { + int target_packs = N_STREAMS; + int target_rows = (n_all_grp + target_packs - 1) / target_packs; + if (target_rows < 1) target_rows = 1; + size_t budget_cap_rows = + GROUP_DENSE_BUDGET_ITEMS / (size_t)sub_batch_cols; + if ((size_t)target_rows > budget_cap_rows) + target_rows = (int)budget_cap_rows; + + int cur_first = 0; + int cur_rows = 0; + size_t cur_nnz = 0; + for (int g = 0; g < n_test; g++) { + int n_g = h_grp_offsets[g + 1] - h_grp_offsets[g]; + size_t nnz_g = (size_t)(h_grp_indptr_compact[h_grp_offsets[g + 1]] - + h_grp_indptr_compact[h_grp_offsets[g]]); + int new_rows = cur_rows + n_g; + bool can_add = (cur_rows == 0) || (new_rows <= target_rows); + if (!can_add) { + size_t sb_size = + std::min((size_t)n_cols, + GROUP_DENSE_BUDGET_ITEMS / (size_t)cur_rows); + if (sb_size < (size_t)sub_batch_cols) sb_size = sub_batch_cols; + packs.push_back( + {cur_first, g, cur_rows, cur_nnz, (int)sb_size}); + cur_first = g; + cur_rows = n_g; + cur_nnz = nnz_g; + } else { + cur_rows = new_rows; + cur_nnz += nnz_g; + } + } + if (cur_rows > 0) { + size_t sb_size = std::min( + (size_t)n_cols, GROUP_DENSE_BUDGET_ITEMS / (size_t)cur_rows); + if (sb_size < (size_t)sub_batch_cols) sb_size = sub_batch_cols; + packs.push_back( + {cur_first, n_test, cur_rows, cur_nnz, (int)sb_size}); + } + } + for (const Pack& pk : packs) { + int K = pk.end - pk.first; + if (pk.n_rows > max_pack_rows) 
max_pack_rows = pk.n_rows; + if (pk.nnz > max_pack_nnz) max_pack_nnz = pk.nnz; + if (K > max_pack_K) max_pack_K = K; + int pack_items = pk.n_rows * pk.sb_cols; + if (pack_items > max_pack_items) max_pack_items = pack_items; + if (pk.sb_cols > max_pack_sb_cols) max_pack_sb_cols = pk.sb_cols; + } + int max_group_rows = max_pack_rows; + size_t max_sub_items = (size_t)max_pack_items; + if (max_pack_rows == 0) return; + + RmmPool pool; + + // Zero stats outputs. + cudaMemsetAsync(d_group_sums, 0, + (size_t)n_groups_stats * n_cols * sizeof(double)); + if (compute_sq_sums) { + cudaMemsetAsync(d_group_sq_sums, 0, + (size_t)n_groups_stats * n_cols * sizeof(double)); + } + if (compute_nnz) { + cudaMemsetAsync(d_group_nnz, 0, + (size_t)n_groups_stats * n_cols * sizeof(double)); + } + + // ---- Pin full host data + indices as MAPPED (zero-copy accessible) ---- + size_t full_nnz = (size_t)h_indptr[n_full_rows]; + HostRegisterGuard _pin_data(const_cast(h_data), + full_nnz * sizeof(InT), cudaHostRegisterMapped); + HostRegisterGuard _pin_indices(const_cast(h_indices), + full_nnz * sizeof(IndexT), + cudaHostRegisterMapped); + + // Get device-accessible pointers (UVA makes these equal to host ptrs on + // Linux x86-64, but the API is the safe/portable way). + InT* d_data_zc = nullptr; + IndexT* d_indices_zc = nullptr; + { + cudaError_t e1 = cudaHostGetDevicePointer((void**)&d_data_zc, + const_cast(h_data), 0); + cudaError_t e2 = cudaHostGetDevicePointer( + (void**)&d_indices_zc, const_cast(h_indices), 0); + if (e1 != cudaSuccess || e2 != cudaSuccess) { + throw std::runtime_error( + std::string("cudaHostGetDevicePointer failed: ") + + cudaGetErrorString(e1 != cudaSuccess ? 
e1 : e2)); + } + } + + // ---- Upload full indptr (keep native IndptrT — can exceed int32) ---- + IndptrT* d_indptr_full = pool.alloc(n_full_rows + 1); + cudaMemcpy(d_indptr_full, h_indptr, (n_full_rows + 1) * sizeof(IndptrT), + cudaMemcpyHostToDevice); + + // ---- Upload row_ids + compacted indptrs + group boundaries ---- + int* d_ref_row_ids = pool.alloc(n_ref); + int* d_grp_row_ids = pool.alloc(n_all_grp); + IndptrT* d_grp_indptr_compact = pool.alloc(n_all_grp + 1); + int* d_grp_offsets_full = pool.alloc(n_test + 1); + cudaMemcpy(d_ref_row_ids, h_ref_row_ids, n_ref * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_grp_row_ids, h_grp_row_ids, n_all_grp * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_grp_indptr_compact, h_grp_indptr_compact.data(), + (n_all_grp + 1) * sizeof(IndptrT), cudaMemcpyHostToDevice); + cudaMemcpy(d_grp_offsets_full, h_grp_offsets, (n_test + 1) * sizeof(int), + cudaMemcpyHostToDevice); + + // ---- Phase 1: Ref setup (scoped scratch, ref_sorted persists) ---- + float* d_ref_sorted = pool.alloc((size_t)n_ref * n_cols); + { + rmm::device_buffer ref_data_f32_buf(ref_nnz * sizeof(float), + rmm::cuda_stream_default, pool.mr); + rmm::device_buffer ref_indices_buf(ref_nnz * sizeof(int), + rmm::cuda_stream_default, pool.mr); + rmm::device_buffer ref_indptr_buf((n_ref + 1) * sizeof(int), + rmm::cuda_stream_default, pool.mr); + rmm::device_buffer ref_dense_buf((size_t)n_ref * n_cols * sizeof(float), + rmm::cuda_stream_default, pool.mr); + rmm::device_buffer ref_seg_buf((n_cols + 1) * sizeof(int), + rmm::cuda_stream_default, pool.mr); + + float* d_ref_data_f32 = (float*)ref_data_f32_buf.data(); + int* d_ref_indices = (int*)ref_indices_buf.data(); + int* d_ref_indptr = (int*)ref_indptr_buf.data(); + float* d_ref_dense = (float*)ref_dense_buf.data(); + int* d_ref_seg = (int*)ref_seg_buf.data(); + + // Upload ref compacted indptr + cudaMemcpy(d_ref_indptr, h_ref_indptr_compact.data(), + (n_ref + 1) * sizeof(int), cudaMemcpyHostToDevice); 
+ + // Fused gather + cast + stats for ref (fixed slot = n_test). One + // pass over PCIe, no intermediate native-dtype GPU buffer. + if (n_ref > 0 && ref_nnz > 0) { + csr_gather_cast_accumulate_mapped_kernel + <<>>( + d_data_zc, d_indices_zc, d_indptr_full, d_ref_row_ids, + d_ref_indptr, /*d_stats_codes=*/nullptr, + /*fixed_slot=*/n_test, d_ref_data_f32, d_ref_indices, + d_group_sums, d_group_sq_sums, d_group_nnz, n_ref, n_cols, + n_groups_stats, compute_sq_sums, compute_nnz); + CUDA_CHECK_LAST_ERROR(csr_gather_cast_accumulate_mapped_kernel); + } + + // Extract ref dense (F-order) — identity row-ids + cudaMemsetAsync(d_ref_dense, 0, (size_t)n_ref * n_cols * sizeof(float)); + { + rmm::device_buffer ref_row_ids_buf( + n_ref * sizeof(int), rmm::cuda_stream_default, pool.mr); + int* d_ref_identity = (int*)ref_row_ids_buf.data(); + fill_linear_offsets_kernel<<<(n_ref + UTIL_BLOCK_SIZE - 1) / + UTIL_BLOCK_SIZE, + UTIL_BLOCK_SIZE>>>(d_ref_identity, + n_ref - 1, 1); + CUDA_CHECK_LAST_ERROR(fill_linear_offsets_kernel); + int tpb = round_up_to_warp(n_ref); + int blk = (n_ref + tpb - 1) / tpb; + csr_extract_dense_kernel + <<>>(d_ref_data_f32, d_ref_indices, d_ref_indptr, + d_ref_identity, d_ref_dense, n_ref, 0, n_cols); + CUDA_CHECK_LAST_ERROR(csr_extract_dense_kernel); + } + + // Segmented sort ref_dense by column → ref_sorted + size_t ref_cub_bytes = 0; + { + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, ref_cub_bytes, fk, fk, (int)((size_t)n_ref * n_cols), + n_cols, doff, doff + 1, BEGIN_BIT, END_BIT); + } + rmm::device_buffer cub_temp_buf(ref_cub_bytes, rmm::cuda_stream_default, + pool.mr); + upload_linear_offsets(d_ref_seg, n_cols, n_ref, 0); + size_t temp = ref_cub_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + cub_temp_buf.data(), temp, d_ref_dense, d_ref_sorted, + (int)((size_t)n_ref * n_cols), n_cols, d_ref_seg, d_ref_seg + 1, + BEGIN_BIT, END_BIT); + cudaDeviceSynchronize(); + } // ref 
scratch drops here + + // Identity row-ids [0..max_group_rows-1] — shared across all packs. + int* d_identity_row_ids = pool.alloc(max_group_rows); + fill_linear_offsets_kernel<<<(max_group_rows + UTIL_BLOCK_SIZE - 1) / + UTIL_BLOCK_SIZE, + UTIL_BLOCK_SIZE>>>(d_identity_row_ids, + max_group_rows - 1, 1); + CUDA_CHECK_LAST_ERROR(fill_linear_offsets_kernel); + + // ---- Phase 2: Per-pack streaming ---- + auto t1 = make_tier1_config(h_grp_offsets, n_test); + bool may_need_cub = (t1.max_grp_size > TIER1_GROUP_THRESHOLD); + + constexpr int MAX_GROUP_STREAMS = 4; + int n_streams = MAX_GROUP_STREAMS; + if (n_test < n_streams) n_streams = n_test; + if (n_streams < 1) n_streams = 1; + if ((int)packs.size() < n_streams) n_streams = (int)packs.size(); + if (n_streams < 1) n_streams = 1; + + size_t cub_grp_bytes = 0; + if (may_need_cub && max_sub_items > 0) { + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + int max_segments = max_pack_K * max_pack_sb_cols; + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, cub_grp_bytes, fk, fk, (int)max_sub_items, max_segments, + doff, doff + 1, BEGIN_BIT, END_BIT); + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + struct StreamBuf { + float* d_grp_data_f32; + int* d_grp_indices; + int* d_grp_indptr; + int* d_pack_grp_offsets; + int* d_pack_stats_codes; + float* d_grp_dense; + float* d_grp_sorted; + double* d_ref_tie_sums; + int* d_sort_group_ids; + int* d_grp_seg_offsets; + int* d_grp_seg_ends; + uint8_t* cub_temp; + double* d_rank_sums; + double* d_tie_corr; + }; + std::vector bufs(n_streams); + int max_pack_kernel_seg = max_pack_K * max_pack_sb_cols; + for (int s = 0; s < n_streams; s++) { + bufs[s].d_grp_data_f32 = pool.alloc(max_pack_nnz); + bufs[s].d_grp_indices = pool.alloc(max_pack_nnz); + bufs[s].d_grp_indptr = pool.alloc(max_pack_rows + 1); + bufs[s].d_pack_grp_offsets = pool.alloc(max_pack_K + 1); + bufs[s].d_pack_stats_codes = 
pool.alloc(max_pack_rows); + bufs[s].d_grp_dense = pool.alloc(max_sub_items); + bufs[s].d_ref_tie_sums = pool.alloc(max_pack_sb_cols); + bufs[s].d_rank_sums = + pool.alloc((size_t)max_pack_K * max_pack_sb_cols); + bufs[s].d_tie_corr = + pool.alloc((size_t)max_pack_K * max_pack_sb_cols); + if (may_need_cub) { + bufs[s].d_grp_sorted = pool.alloc(max_sub_items); + bufs[s].d_sort_group_ids = pool.alloc(max_pack_K); + bufs[s].d_grp_seg_offsets = pool.alloc(max_pack_kernel_seg); + bufs[s].d_grp_seg_ends = pool.alloc(max_pack_kernel_seg); + bufs[s].cub_temp = pool.alloc(cub_grp_bytes); + } else { + bufs[s].d_grp_sorted = nullptr; + bufs[s].d_sort_group_ids = nullptr; + bufs[s].d_grp_seg_offsets = nullptr; + bufs[s].d_grp_seg_ends = nullptr; + bufs[s].cub_temp = nullptr; + } + } + + cudaDeviceSynchronize(); // ensure Phase 1 done before Phase 2 streams + + for (int p = 0; p < (int)packs.size(); p++) { + const Pack& pack = packs[p]; + int K = pack.end - pack.first; + if (K == 0 || pack.n_rows == 0) continue; + Tier1Config pack_t1 = make_tier1_config(h_grp_offsets + pack.first, K); + int pack_tpb_rank = round_up_to_warp( + std::min(pack_t1.max_grp_size, MAX_THREADS_PER_BLOCK)); + bool pack_has_above_t2 = pack_t1.max_grp_size > TIER2_GROUP_THRESHOLD; + int pack_tier3_skip_le = + pack_has_above_t2 ? 
TIER2_GROUP_THRESHOLD : TIER0_GROUP_THRESHOLD; + std::vector h_sort_group_ids; + int pack_n_sort_groups = K; + if (pack_t1.any_above_t0 && !pack_t1.use_tier1) { + h_sort_group_ids = make_sort_group_ids(h_grp_offsets + pack.first, + K, pack_tier3_skip_le); + pack_n_sort_groups = (int)h_sort_group_ids.size(); + } + + int s = p % n_streams; + cudaStream_t stream = streams[s]; + auto& buf = bufs[s]; + + if (pack_t1.any_above_t0 && !pack_t1.use_tier1) { + cudaMemcpyAsync(buf.d_sort_group_ids, h_sort_group_ids.data(), + h_sort_group_ids.size() * sizeof(int), + cudaMemcpyHostToDevice, stream); + } + + int row_start = h_grp_offsets[pack.first]; + int pack_rows = pack.n_rows; + int pack_sb = pack.sb_cols; + + // Rebase pack's output indptr from pre-uploaded global compacted indptr + // (IndptrT → int32: pack nnz is bounded by GROUP_DENSE_BUDGET so fits). + { + int count = pack_rows + 1; + int blk = (count + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + rebase_indptr_kernel + <<>>( + d_grp_indptr_compact, buf.d_grp_indptr, row_start, count); + CUDA_CHECK_LAST_ERROR(rebase_indptr_kernel); + } + + // Build per-pack group offsets on GPU (on this stream) — needed to + // compute stats codes before the fused gather kernel can run. + { + int count = K + 1; + int blk = (count + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + rebase_indptr_kernel<<>>( + d_grp_offsets_full, buf.d_pack_grp_offsets, pack.first, count); + CUDA_CHECK_LAST_ERROR(rebase_indptr_kernel); + } + + // Fill per-row stats codes for this pack + { + int blk = (pack_rows + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + fill_pack_stats_codes_kernel<<>>( + buf.d_pack_grp_offsets, buf.d_pack_stats_codes, K, pack.first); + CUDA_CHECK_LAST_ERROR(fill_pack_stats_codes_kernel); + } + + // Fused gather + cast + stats for the pack. One pass over PCIe + // (reads mapped host via UVA), no intermediate native-dtype GPU + // buffer, writes f32 + indices + atomics. 
+ if (pack.nnz > 0) { + csr_gather_cast_accumulate_mapped_kernel + <<>>( + d_data_zc, d_indices_zc, d_indptr_full, + d_grp_row_ids + row_start, buf.d_grp_indptr, + buf.d_pack_stats_codes, /*fixed_slot=*/-1, + buf.d_grp_data_f32, buf.d_grp_indices, d_group_sums, + d_group_sq_sums, d_group_nnz, pack_rows, n_cols, + n_groups_stats, compute_sq_sums, compute_nnz); + CUDA_CHECK_LAST_ERROR(csr_gather_cast_accumulate_mapped_kernel); + } + + // Per col sub-batch + int col = 0; + while (col < n_cols) { + int sb_cols = std::min(pack_sb, n_cols - col); + int sb_items = pack_rows * sb_cols; + + cudaMemsetAsync(buf.d_grp_dense, 0, sb_items * sizeof(float), + stream); + { + int tpb = round_up_to_warp(pack_rows); + int blk = (pack_rows + tpb - 1) / tpb; + csr_extract_dense_kernel<<>>( + buf.d_grp_data_f32, buf.d_grp_indices, buf.d_grp_indptr, + d_identity_row_ids, buf.d_grp_dense, pack_rows, col, + col + sb_cols); + CUDA_CHECK_LAST_ERROR(csr_extract_dense_kernel); + } + + const float* ref_sub = d_ref_sorted + (size_t)col * n_ref; + + int skip_le = 0; + if (pack_t1.use_tier0) { + launch_tier0(ref_sub, buf.d_grp_dense, buf.d_pack_grp_offsets, + buf.d_rank_sums, buf.d_tie_corr, n_ref, pack_rows, + sb_cols, K, compute_tie_corr, stream); + if (pack_t1.any_above_t0) skip_le = TIER0_GROUP_THRESHOLD; + } + if (pack_t1.any_above_t0) { + if (compute_tie_corr) { + launch_ref_tie_sums(ref_sub, buf.d_ref_tie_sums, n_ref, + sb_cols, stream); + } + launch_tier2_medium( + ref_sub, buf.d_grp_dense, buf.d_pack_grp_offsets, + buf.d_ref_tie_sums, buf.d_rank_sums, buf.d_tie_corr, n_ref, + pack_rows, sb_cols, K, compute_tie_corr, skip_le, stream); + } + + int upper_skip_le = + pack_has_above_t2 ? 
TIER2_GROUP_THRESHOLD : skip_le; + if (pack_has_above_t2 && pack_t1.use_tier1) { + dim3 grid(sb_cols, K); + ovo_fused_sort_rank_kernel<<>>( + ref_sub, buf.d_grp_dense, buf.d_pack_grp_offsets, + buf.d_rank_sums, buf.d_tie_corr, n_ref, pack_rows, sb_cols, + K, compute_tie_corr, pack_t1.padded_grp_size, + upper_skip_le); + CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); + } else if (pack_has_above_t2) { + int n_seg = pack_n_sort_groups * sb_cols; + { + int blk = (n_seg + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + build_tier3_seg_begin_end_offsets_kernel<<< + blk, UTIL_BLOCK_SIZE, 0, stream>>>( + buf.d_pack_grp_offsets, buf.d_sort_group_ids, + buf.d_grp_seg_offsets, buf.d_grp_seg_ends, pack_rows, + pack_n_sort_groups, sb_cols); + CUDA_CHECK_LAST_ERROR( + build_tier3_seg_begin_end_offsets_kernel); + } + { + size_t temp = cub_grp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.cub_temp, temp, buf.d_grp_dense, buf.d_grp_sorted, + sb_items, n_seg, buf.d_grp_seg_offsets, + buf.d_grp_seg_ends, BEGIN_BIT, END_BIT, stream); + } + dim3 grid(sb_cols, K); + batched_rank_sums_presorted_kernel<<>>( + ref_sub, buf.d_grp_sorted, buf.d_pack_grp_offsets, + buf.d_rank_sums, buf.d_tie_corr, n_ref, pack_rows, sb_cols, + K, compute_tie_corr, upper_skip_le); + CUDA_CHECK_LAST_ERROR(batched_rank_sums_presorted_kernel); + } + + cudaMemcpy2DAsync(d_rank_sums + (size_t)pack.first * n_cols + col, + n_cols * sizeof(double), buf.d_rank_sums, + sb_cols * sizeof(double), + sb_cols * sizeof(double), K, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpy2DAsync( + d_tie_corr + (size_t)pack.first * n_cols + col, + n_cols * sizeof(double), buf.d_tie_corr, + sb_cols * sizeof(double), sb_cols * sizeof(double), K, + cudaMemcpyDeviceToDevice, stream); + } + + col += sb_cols; + } + } + + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in ovo csr host 
streaming: ") + + cudaGetErrorString(err)); + } + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh new file mode 100644 index 00000000..d508b9c2 --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh @@ -0,0 +1,165 @@ +#pragma once + +/** + * Build CUB segmented-sort ranges only for groups that Tier 3 will rank. + * Group ids are relative to grp_offsets, and ranges still point into the + * original dense group layout so the presorted rank kernel can read from the + * normal per-group positions. + */ +__global__ void build_tier3_seg_begin_end_offsets_kernel( + const int* __restrict__ grp_offsets, const int* __restrict__ group_ids, + int* __restrict__ begins, int* __restrict__ ends, int n_all_grp, + int n_sort_groups, int sb_cols) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int total = sb_cols * n_sort_groups; + if (idx >= total) return; + + int c = idx / n_sort_groups; + int local = idx % n_sort_groups; + int g = group_ids[local]; + int base = c * n_all_grp; + begins[idx] = base + grp_offsets[g]; + ends[idx] = base + grp_offsets[g + 1]; +} + +/** + * Extract specific rows from CSC into dense F-order, using a row lookup map. + * row_map[original_row] = output_row_index (or -1 to skip). + * One block per column, threads scatter matching nonzeros. + * Output must be pre-zeroed. 
+ */ +template +__global__ void csc_extract_mapped_kernel(const float* __restrict__ data, + const IndexT* __restrict__ indices, + const int* __restrict__ indptr, + const int* __restrict__ row_map, + float* __restrict__ out, int n_target, + int col_start) { + int col_local = blockIdx.x; + int col = col_start + col_local; + + int start = indptr[col]; + int end = indptr[col + 1]; + + for (int p = start + threadIdx.x; p < end; p += blockDim.x) { + int out_row = row_map[(int)indices[p]]; + if (out_row >= 0) { + out[(long long)col_local * n_target + out_row] = data[p]; + } + } +} + +static size_t get_seg_sort_temp_bytes(int n_items, int n_segments) { + size_t bytes = 0; + auto* dk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys(nullptr, bytes, dk, dk, n_items, + n_segments, doff, doff + 1, 0, 32); + return bytes; +} + +/** + * Tier 1 dispatch: when the largest group fits in shared memory, a fused + * bitonic-sort + binary-search kernel handles the whole group per block. + * Otherwise we fall back to CUB segmented sort plus the pre-sorted rank + * kernel. This struct bundles the sizing knobs derived from the host-side + * group offsets so each streaming impl can drop a 15-line prep block. 
+ */ +struct Tier1Config { + int max_grp_size = 0; + int min_grp_size = 0; + bool use_tier0 = + false; // any group fits in one warp (≤ TIER0_GROUP_THRESHOLD) + bool use_tier1 = + false; // any group needs > tier0 but fits in tier1 smem sort + bool any_above_t0 = + false; // at least one group exceeds TIER0_GROUP_THRESHOLD + bool any_tier2 = false; // any group needs Tier 2: (T0, T2] + bool any_above_t2 = + false; // at least one group exceeds TIER2_GROUP_THRESHOLD + int padded_grp_size = 0; + int tier1_tpb = 0; + size_t tier1_smem = 0; +}; + +static Tier1Config make_tier1_config(const int* h_grp_offsets, int n_groups) { + Tier1Config c; + c.min_grp_size = INT_MAX; + for (int g = 0; g < n_groups; g++) { + int sz = h_grp_offsets[g + 1] - h_grp_offsets[g]; + if (sz > c.max_grp_size) c.max_grp_size = sz; + if (sz < c.min_grp_size) c.min_grp_size = sz; + if (sz > TIER0_GROUP_THRESHOLD && sz <= TIER2_GROUP_THRESHOLD) { + c.any_tier2 = true; + } + if (sz > TIER2_GROUP_THRESHOLD) c.any_above_t2 = true; + } + if (n_groups == 0) c.min_grp_size = 0; + + // use_tier0: Tier 0 kernel is worth running (at least one group small + // enough to benefit from the warp path). + c.use_tier0 = (c.min_grp_size <= TIER0_GROUP_THRESHOLD); + // any_above_t0: at least one group needs a non-Tier-0 kernel. + c.any_above_t0 = (c.max_grp_size > TIER0_GROUP_THRESHOLD); + // use_tier1: the fused smem-sort fast path (for groups > T0 but ≤ T1). 
+ c.use_tier1 = c.any_above_t0 && (c.max_grp_size <= TIER1_GROUP_THRESHOLD); + if (c.use_tier1) { + c.padded_grp_size = 1; + while (c.padded_grp_size < c.max_grp_size) c.padded_grp_size <<= 1; + c.tier1_tpb = std::min(c.padded_grp_size, MAX_THREADS_PER_BLOCK); + c.tier1_smem = (size_t)c.padded_grp_size * sizeof(float) + + WARP_REDUCE_BUF * sizeof(double); + } + return c; +} + +static std::vector make_sort_group_ids(const int* h_grp_offsets, + int n_groups, int skip_n_grp_le) { + std::vector ids; + ids.reserve(n_groups); + for (int g = 0; g < n_groups; ++g) { + int sz = h_grp_offsets[g + 1] - h_grp_offsets[g]; + if (skip_n_grp_le > 0 && sz <= skip_n_grp_le) continue; + ids.push_back(g); + } + return ids; +} + +// Tier 0 kernel launcher: 8 warps × 32 threads per block, one (col, group) +// pair per warp. grid.y covers ceil(K/8) pair rows. +static inline void launch_tier0(const float* ref_sorted, const float* grp_dense, + const int* grp_offsets, double* rank_sums, + double* tie_corr, int n_ref, int n_all_grp, + int sb_cols, int K, bool compute_tie_corr, + cudaStream_t stream) { + constexpr int WARPS_PER_BLOCK = 8; + dim3 grid(sb_cols, (K + WARPS_PER_BLOCK - 1) / WARPS_PER_BLOCK); + ovo_warp_sort_rank_kernel<<>>( + ref_sorted, grp_dense, grp_offsets, rank_sums, tie_corr, n_ref, + n_all_grp, sb_cols, K, compute_tie_corr); + CUDA_CHECK_LAST_ERROR(ovo_warp_sort_rank_kernel); +} + +static inline void launch_ref_tie_sums(const float* ref_sorted, + double* ref_tie_sums, int n_ref, + int sb_cols, cudaStream_t stream) { + ref_tie_sum_kernel<<>>( + ref_sorted, ref_tie_sums, n_ref, sb_cols); + CUDA_CHECK_LAST_ERROR(ref_tie_sum_kernel); +} + +static inline void launch_tier2_medium( + const float* ref_sorted, const float* grp_dense, const int* grp_offsets, + const double* ref_tie_sums, double* rank_sums, double* tie_corr, int n_ref, + int n_all_grp, int sb_cols, int K, bool compute_tie_corr, int skip_n_grp_le, + cudaStream_t stream) { + constexpr int tpb = 256; + size_t smem = 
(size_t)TIER2_GROUP_THRESHOLD * sizeof(float) + + WARP_REDUCE_BUF * sizeof(double); + dim3 grid(sb_cols, K); + ovo_medium_unsorted_rank_kernel<<>>( + ref_sorted, grp_dense, grp_offsets, ref_tie_sums, rank_sums, tie_corr, + n_ref, n_all_grp, sb_cols, K, compute_tie_corr, skip_n_grp_le, + TIER2_GROUP_THRESHOLD); + CUDA_CHECK_LAST_ERROR(ovo_medium_unsorted_rank_kernel); +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr.cu index 6c261c15..e5661126 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr.cu @@ -9,1128 +9,10 @@ using namespace nb::literals; -/** Rebase a slice of indptr: out[i] = indptr[col + i] - indptr[col]. - * Grid-strided: supports arbitrary `count` (no single-block thread limit). */ -__global__ void rebase_indptr_kernel(const int* __restrict__ indptr, - int* __restrict__ out, int col, - int count) { - int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i < count) out[i] = indptr[col + i] - indptr[col]; -} - -/** Subtract a constant from an int array in-place. */ -__global__ void subtract_scalar_kernel(int* __restrict__ data, int base, - int count) { - int i = threadIdx.x; - if (i < count) data[i] -= base; -} - -/** Count nonzeros per column from CSR. One thread per row. */ -__global__ void csr_col_histogram_kernel(const int* __restrict__ indices, - const int* __restrict__ indptr, - int* __restrict__ col_counts, - int n_rows, int n_cols) { - int row = blockIdx.x * blockDim.x + threadIdx.x; - if (row >= n_rows) return; - int rs = indptr[row]; - int re = indptr[row + 1]; - for (int p = rs; p < re; ++p) { - int c = indices[p]; - if (c < n_cols) atomicAdd(&col_counts[c], 1); - } -} - -/** - * Scatter CSR nonzeros into CSC layout for columns [col_start, col_stop). - * write_pos[c - col_start] must be initialized to the prefix-sum offset - * for column c. Each thread atomically claims a unique destination slot. 
- */ -__global__ void csr_scatter_to_csc_kernel( - const float* __restrict__ data, const int* __restrict__ indices, - const int* __restrict__ indptr, int* __restrict__ write_pos, - float* __restrict__ csc_vals, int* __restrict__ csc_row_idx, int n_rows, - int col_start, int col_stop) { - int row = blockIdx.x * blockDim.x + threadIdx.x; - if (row >= n_rows) return; - int rs = indptr[row]; - int re = indptr[row + 1]; - // Binary search for col_start (overflow-safe midpoint) - int lo = rs, hi = re; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (indices[m] < col_start) - lo = m + 1; - else - hi = m; - } - for (int p = lo; p < re; ++p) { - int c = indices[p]; - if (c >= col_stop) break; - int dest = atomicAdd(&write_pos[c - col_start], 1); - csc_vals[dest] = data[p]; - csc_row_idx[dest] = row; - } -} - -/** - * Decide whether to use shared or global memory for OVR rank accumulators. - * Returns the smem size to request and sets use_gmem accordingly. - */ -static int query_max_smem_per_block() { - static int cached = -1; - if (cached < 0) { - int device; - cudaGetDevice(&device); - cudaDeviceGetAttribute(&cached, cudaDevAttrMaxSharedMemoryPerBlock, - device); - } - return cached; -} - -static size_t ovr_smem_config(int n_groups, bool& use_gmem) { - size_t need = (size_t)(n_groups + 32) * sizeof(double); - if ((int)need <= query_max_smem_per_block()) { - use_gmem = false; - return need; - } - // Fall back to global memory accumulators; only need warp buf in smem - use_gmem = true; - return 32 * sizeof(double); -} - -/** - * Decide smem-vs-gmem for the sparse OVR rank kernel. Two accumulator - * arrays (grp_sums + grp_nz_count) of size n_groups each plus warp buf. 
- */ -static size_t sparse_ovr_smem_config(int n_groups, bool& use_gmem) { - size_t need = (size_t)(2 * n_groups + 32) * sizeof(double); - if ((int)need <= query_max_smem_per_block()) { - use_gmem = false; - return need; - } - use_gmem = true; - return 32 * sizeof(double); -} - -/** - * Fill sort values with row indices [0,1,...,n_rows-1] per column. - * Grid: (n_cols,), block: 256 threads. - */ -__global__ void fill_row_indices_kernel(int* __restrict__ vals, int n_rows, - int n_cols) { - int col = blockIdx.x; - if (col >= n_cols) return; - int* out = vals + (long long)col * n_rows; - for (int i = threadIdx.x; i < n_rows; i += blockDim.x) { - out[i] = i; - } -} - -/** - * Streaming OVR pipeline. - * - * Takes a dense F-order float32 block (n_rows, n_cols) + int32 group_codes, - * splits columns into sub-batches across multiple CUDA streams, and for each: - * 1. CUB SortPairs (float32 keys + int32 row indices) - * 2. Fused rank_sums_from_sorted_kernel - * - * Output: rank_sums (n_groups, n_cols) + tie_corr (n_cols), both float64. 
- */ -static void ovr_streaming_impl(const float* block, const int* group_codes, - double* rank_sums, double* tie_corr, int n_rows, - int n_cols, int n_groups, bool compute_tie_corr, - int sub_batch_cols) { - if (n_rows == 0 || n_cols == 0) return; - - int n_streams = N_STREAMS; - if (n_cols < n_streams * sub_batch_cols) - n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; - - size_t sub_items = (size_t)n_rows * sub_batch_cols; - size_t cub_temp_bytes = 0; - { - auto* fk = reinterpret_cast(1); - auto* iv = reinterpret_cast(1); - cub::DeviceSegmentedRadixSort::SortPairs( - nullptr, cub_temp_bytes, fk, fk, iv, iv, (int)sub_items, - sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); - } - - // Create streams - std::vector streams(n_streams); - for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); - - // Allocate per-stream buffers via RMM pool - RmmPool pool; - struct StreamBuf { - float* keys_out; - int* vals_in; - int* vals_out; - int* seg_offsets; - uint8_t* cub_temp; - double* sub_rank_sums; - double* sub_tie_corr; - }; - std::vector bufs(n_streams); - for (int s = 0; s < n_streams; s++) { - bufs[s].keys_out = pool.alloc(sub_items); - bufs[s].vals_in = pool.alloc(sub_items); - bufs[s].vals_out = pool.alloc(sub_items); - bufs[s].seg_offsets = pool.alloc(sub_batch_cols + 1); - bufs[s].cub_temp = pool.alloc(cub_temp_bytes); - bufs[s].sub_rank_sums = - pool.alloc((size_t)n_groups * sub_batch_cols); - bufs[s].sub_tie_corr = pool.alloc(sub_batch_cols); - } - - int tpb_rank = round_up_to_warp(n_rows); - bool use_gmem = false; - size_t smem_rank = ovr_smem_config(n_groups, use_gmem); - - // Process sub-batches round-robin across streams - int col = 0; - int batch_idx = 0; - while (col < n_cols) { - int sb_cols = std::min(sub_batch_cols, n_cols - col); - int sb_items = n_rows * sb_cols; - int s = batch_idx % n_streams; - auto stream = streams[s]; - auto& buf = bufs[s]; - - // Fill segment offsets + row indices - upload_linear_offsets(buf.seg_offsets, 
sb_cols, n_rows, stream); - fill_row_indices_kernel<<>>( - buf.vals_in, n_rows, sb_cols); - CUDA_CHECK_LAST_ERROR(fill_row_indices_kernel); - - // Sort: keys = block columns [col, col+sb_cols), already F-order - const float* keys_in = block + (long long)col * n_rows; - size_t temp = cub_temp_bytes; - cub::DeviceSegmentedRadixSort::SortPairs( - buf.cub_temp, temp, keys_in, buf.keys_out, buf.vals_in, - buf.vals_out, sb_items, sb_cols, buf.seg_offsets, - buf.seg_offsets + 1, BEGIN_BIT, END_BIT, stream); - - // Fused rank sums into sub-batch buffer - if (use_gmem) { - cudaMemsetAsync(buf.sub_rank_sums, 0, - (size_t)n_groups * sb_cols * sizeof(double), - stream); - } - rank_sums_from_sorted_kernel<<>>( - buf.keys_out, buf.vals_out, group_codes, buf.sub_rank_sums, - buf.sub_tie_corr, n_rows, sb_cols, n_groups, compute_tie_corr, - use_gmem); - CUDA_CHECK_LAST_ERROR(rank_sums_from_sorted_kernel); - - // Copy sub-batch results to global output (row-major scatter) - // rank_sums is (n_groups, n_cols) row-major: group g, col c → - // [g*n_cols+c] sub output is (n_groups, sb_cols): group g, local col lc - // → [g*sb_cols+lc] - cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), - buf.sub_rank_sums, sb_cols * sizeof(double), - sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToDevice, stream); - if (compute_tie_corr) { - cudaMemcpyAsync(tie_corr + col, buf.sub_tie_corr, - sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, - stream); - } - - col += sb_cols; - batch_idx++; - } - - // Sync all streams - for (int s = 0; s < n_streams; s++) { - cudaError_t err = cudaStreamSynchronize(streams[s]); - if (err != cudaSuccess) - throw std::runtime_error( - std::string("CUDA error in wilcoxon streaming: ") + - cudaGetErrorString(err)); - } - - for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); -} - -/** - * Sparse-aware host-streaming CSC OVR pipeline. 
- * - * Like ovr_streaming_csc_host_impl but sorts only stored nonzeros per column - * instead of extracting dense blocks. GPU memory is O(max_batch_nnz) instead - * of O(sub_batch * n_rows), and sort work is proportional to nnz, not n_rows. - */ -template -static void ovr_sparse_csc_host_streaming_impl( - const InT* h_data, const int* h_indices, const IndptrT* h_indptr, - const int* h_group_codes, const double* h_group_sizes, double* d_rank_sums, - double* d_tie_corr, double* d_group_sums, double* d_group_sq_sums, - double* d_group_nnz, int n_rows, int n_cols, int n_groups, - bool compute_tie_corr, int sub_batch_cols) { - if (n_rows == 0 || n_cols == 0) return; - - int n_streams = N_STREAMS; - if (n_cols < n_streams * sub_batch_cols) - n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; - - // Find max nnz across any sub-batch - size_t max_nnz = 0; - for (int col = 0; col < n_cols; col += sub_batch_cols) { - int sb_cols = std::min(sub_batch_cols, n_cols - col); - size_t nnz = (size_t)(h_indptr[col + sb_cols] - h_indptr[col]); - if (nnz > max_nnz) max_nnz = nnz; - } - - // CUB temp size for max_nnz items - size_t cub_temp_bytes = 0; - if (max_nnz > 0) { - auto* fk = reinterpret_cast(1); - auto* iv = reinterpret_cast(1); - cub::DeviceSegmentedRadixSort::SortPairs( - nullptr, cub_temp_bytes, fk, fk, iv, iv, (int)max_nnz, - sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); - } - - std::vector streams(n_streams); - for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); - - RmmPool pool; - int* d_group_codes = pool.alloc(n_rows); - double* d_group_sizes = pool.alloc(n_groups); - struct StreamBuf { - InT* d_sparse_data_orig; - float* d_sparse_data_f32; - int* d_sparse_indices; - int* d_seg_offsets; - float* keys_out; - int* vals_out; - uint8_t* cub_temp; - double* d_rank_sums; - double* d_tie_corr; - double* d_group_sums; - double* d_group_sq_sums; - double* d_group_nnz; - double* d_nz_scratch; // gmem-only; non-null when rank_use_gmem - }; - 
std::vector bufs(n_streams); - for (int s = 0; s < n_streams; s++) { - bufs[s].d_sparse_data_orig = pool.alloc(max_nnz); - bufs[s].d_sparse_data_f32 = pool.alloc(max_nnz); - bufs[s].d_sparse_indices = pool.alloc(max_nnz); - bufs[s].d_seg_offsets = pool.alloc(sub_batch_cols + 1); - bufs[s].keys_out = pool.alloc(max_nnz); - bufs[s].vals_out = pool.alloc(max_nnz); - bufs[s].cub_temp = pool.alloc(cub_temp_bytes); - bufs[s].d_rank_sums = - pool.alloc((size_t)n_groups * sub_batch_cols); - bufs[s].d_tie_corr = pool.alloc(sub_batch_cols); - bufs[s].d_group_sums = - pool.alloc((size_t)n_groups * sub_batch_cols); - bufs[s].d_group_sq_sums = - pool.alloc((size_t)n_groups * sub_batch_cols); - bufs[s].d_group_nnz = - pool.alloc((size_t)n_groups * sub_batch_cols); - } - - // Transfer group codes + sizes once - cudaMemcpy(d_group_codes, h_group_codes, n_rows * sizeof(int), - cudaMemcpyHostToDevice); - cudaMemcpy(d_group_sizes, h_group_sizes, n_groups * sizeof(double), - cudaMemcpyHostToDevice); - - // Pre-compute rebased per-batch offsets and upload once (avoids per-batch - // H2D copy from a transient host buffer). 
- int n_batches = (n_cols + sub_batch_cols - 1) / sub_batch_cols; - std::vector h_all_offsets((size_t)n_batches * (sub_batch_cols + 1), 0); - for (int b = 0; b < n_batches; b++) { - int col_start = b * sub_batch_cols; - int sb = std::min(sub_batch_cols, n_cols - col_start); - IndptrT ptr_start = h_indptr[col_start]; - int* off = &h_all_offsets[(size_t)b * (sub_batch_cols + 1)]; - for (int i = 0; i <= sb; i++) - off[i] = (int)(h_indptr[col_start + i] - ptr_start); - } - int* d_all_offsets = - pool.alloc((size_t)n_batches * (sub_batch_cols + 1)); - cudaMemcpy(d_all_offsets, h_all_offsets.data(), - h_all_offsets.size() * sizeof(int), cudaMemcpyHostToDevice); - - int tpb = UTIL_BLOCK_SIZE; - bool rank_use_gmem = false; - size_t smem_bytes = sparse_ovr_smem_config(n_groups, rank_use_gmem); - size_t smem_cast = (size_t)(3 * n_groups) * sizeof(double); - - // In gmem mode the sparse rank kernel accumulates into rank_sums directly - // and needs a per-stream nz_count scratch buffer sized (n_groups, sb_cols). - for (int s = 0; s < n_streams; s++) { - if (rank_use_gmem) { - bufs[s].d_nz_scratch = - pool.alloc((size_t)n_groups * sub_batch_cols); - } else { - bufs[s].d_nz_scratch = nullptr; - } - } - - // Pin only the host input arrays; outputs live on the device. 
- size_t total_nnz = (size_t)h_indptr[n_cols]; - HostRegisterGuard _pin_data(const_cast(h_data), - total_nnz * sizeof(InT)); - HostRegisterGuard _pin_indices(const_cast(h_indices), - total_nnz * sizeof(int)); - - cudaDeviceSynchronize(); - - int col = 0; - int batch_idx = 0; - while (col < n_cols) { - int sb_cols = std::min(sub_batch_cols, n_cols - col); - int s = batch_idx % n_streams; - auto stream = streams[s]; - auto& buf = bufs[s]; - - IndptrT ptr_start = h_indptr[col]; - IndptrT ptr_end = h_indptr[col + sb_cols]; - int batch_nnz = (int)(ptr_end - ptr_start); - - // H2D: transfer sparse data for this column range (native dtype) - if (batch_nnz > 0) { - cudaMemcpyAsync(buf.d_sparse_data_orig, h_data + ptr_start, - (size_t)batch_nnz * sizeof(InT), - cudaMemcpyHostToDevice, stream); - cudaMemcpyAsync(buf.d_sparse_indices, h_indices + ptr_start, - (size_t)batch_nnz * sizeof(int), - cudaMemcpyHostToDevice, stream); - } - - // D2D: copy this batch's rebased offsets from the pre-uploaded buffer - int* src = d_all_offsets + (size_t)batch_idx * (sub_batch_cols + 1); - cudaMemcpyAsync(buf.d_seg_offsets, src, (sb_cols + 1) * sizeof(int), - cudaMemcpyDeviceToDevice, stream); - - // Cast to float32 for sort + accumulate stats in float64 - ovr_cast_and_accumulate_sparse_kernel - <<>>( - buf.d_sparse_data_orig, buf.d_sparse_data_f32, - buf.d_sparse_indices, buf.d_seg_offsets, d_group_codes, - buf.d_group_sums, buf.d_group_sq_sums, buf.d_group_nnz, sb_cols, - n_groups); - CUDA_CHECK_LAST_ERROR(ovr_cast_and_accumulate_sparse_kernel); - - // CUB sort only stored nonzeros (float32 keys) - if (batch_nnz > 0) { - size_t temp = cub_temp_bytes; - cub::DeviceSegmentedRadixSort::SortPairs( - buf.cub_temp, temp, buf.d_sparse_data_f32, buf.keys_out, - buf.d_sparse_indices, buf.vals_out, batch_nnz, sb_cols, - buf.d_seg_offsets, buf.d_seg_offsets + 1, BEGIN_BIT, END_BIT, - stream); - } - - // Sparse rank kernel (stats already captured above) - if (rank_use_gmem) { - 
cudaMemsetAsync(buf.d_rank_sums, 0, - (size_t)n_groups * sb_cols * sizeof(double), - stream); - cudaMemsetAsync(buf.d_nz_scratch, 0, - (size_t)n_groups * sb_cols * sizeof(double), - stream); - } - rank_sums_sparse_ovr_kernel<<>>( - buf.keys_out, buf.vals_out, buf.d_seg_offsets, d_group_codes, - d_group_sizes, buf.d_rank_sums, buf.d_tie_corr, buf.d_nz_scratch, - n_rows, sb_cols, n_groups, compute_tie_corr, rank_use_gmem); - CUDA_CHECK_LAST_ERROR(rank_sums_sparse_ovr_kernel); - - // D2D: scatter sub-batch results into caller's GPU buffers - cudaMemcpy2DAsync(d_rank_sums + col, n_cols * sizeof(double), - buf.d_rank_sums, sb_cols * sizeof(double), - sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToDevice, stream); - if (compute_tie_corr) { - cudaMemcpyAsync(d_tie_corr + col, buf.d_tie_corr, - sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, - stream); - } - cudaMemcpy2DAsync(d_group_sums + col, n_cols * sizeof(double), - buf.d_group_sums, sb_cols * sizeof(double), - sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToDevice, stream); - cudaMemcpy2DAsync(d_group_sq_sums + col, n_cols * sizeof(double), - buf.d_group_sq_sums, sb_cols * sizeof(double), - sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToDevice, stream); - cudaMemcpy2DAsync(d_group_nnz + col, n_cols * sizeof(double), - buf.d_group_nnz, sb_cols * sizeof(double), - sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToDevice, stream); - - col += sb_cols; - batch_idx++; - } - - for (int s = 0; s < n_streams; s++) { - cudaError_t err = cudaStreamSynchronize(streams[s]); - if (err != cudaSuccess) - throw std::runtime_error( - std::string("CUDA error in sparse host CSC streaming: ") + - cudaGetErrorString(err)); - } - - for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); -} - -/** - * Host-streaming dense OVR pipeline. - * - * Templated on the host dtype (InT = float or double). 
Each sub-batch is - * copied to the device in its native dtype once; a fused cast+accumulate - * kernel writes a float32 view for the sort and accumulates per-group - * sum/sum-sq/nnz in float64 from the original-precision values. The - * existing sort + rank pipeline then runs on the float32 keys. - * - * Output pointers ({d_rank_sums, d_tie_corr, d_group_sums, d_group_sq_sums, - * d_group_nnz}) point to caller-provided CuPy memory of the full output - * shape; sub-batch kernels scatter directly into them via D2D. - * - * GPU memory stays at O(sub_batch * n_rows), now with a small extra - * InT-sized sub-batch buffer per stream. - */ -template -static void ovr_streaming_dense_host_impl( - const InT* h_block, const int* h_group_codes, double* d_rank_sums, - double* d_tie_corr, double* d_group_sums, double* d_group_sq_sums, - double* d_group_nnz, int n_rows, int n_cols, int n_groups, - bool compute_tie_corr, int sub_batch_cols) { - if (n_rows == 0 || n_cols == 0) return; - - int n_streams = N_STREAMS; - if (n_cols < n_streams * sub_batch_cols) - n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; - - size_t sub_items = (size_t)n_rows * sub_batch_cols; - size_t cub_temp_bytes = 0; - { - auto* fk = reinterpret_cast(1); - auto* iv = reinterpret_cast(1); - cub::DeviceSegmentedRadixSort::SortPairs( - nullptr, cub_temp_bytes, fk, fk, iv, iv, (int)sub_items, - sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); - } - - std::vector streams(n_streams); - for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); - - // Allocate per-stream buffers via RMM pool - RmmPool pool; - int* d_group_codes = pool.alloc(n_rows); - struct StreamBuf { - InT* d_block_orig; - float* d_block_f32; - float* keys_out; - int* vals_in; - int* vals_out; - int* seg_offsets; - uint8_t* cub_temp; - double* d_rank_sums; - double* d_tie_corr; - double* d_group_sums; - double* d_group_sq_sums; - double* d_group_nnz; - }; - std::vector bufs(n_streams); - for (int s = 0; s < n_streams; s++) { 
- bufs[s].d_block_orig = pool.alloc(sub_items); - bufs[s].d_block_f32 = pool.alloc(sub_items); - bufs[s].keys_out = pool.alloc(sub_items); - bufs[s].vals_in = pool.alloc(sub_items); - bufs[s].vals_out = pool.alloc(sub_items); - bufs[s].seg_offsets = pool.alloc(sub_batch_cols + 1); - bufs[s].cub_temp = pool.alloc(cub_temp_bytes); - bufs[s].d_rank_sums = - pool.alloc((size_t)n_groups * sub_batch_cols); - bufs[s].d_tie_corr = pool.alloc(sub_batch_cols); - bufs[s].d_group_sums = - pool.alloc((size_t)n_groups * sub_batch_cols); - bufs[s].d_group_sq_sums = - pool.alloc((size_t)n_groups * sub_batch_cols); - bufs[s].d_group_nnz = - pool.alloc((size_t)n_groups * sub_batch_cols); - } - - // Group codes on GPU (transferred once) - cudaMemcpy(d_group_codes, h_group_codes, n_rows * sizeof(int), - cudaMemcpyHostToDevice); - - int tpb_rank = round_up_to_warp(n_rows); - bool use_gmem = false; - size_t smem_rank = ovr_smem_config(n_groups, use_gmem); - int tpb_cast = UTIL_BLOCK_SIZE; - size_t smem_cast = (size_t)(3 * n_groups) * sizeof(double); - - // Pin only the host input. Outputs live on the device (caller-owned). 
- HostRegisterGuard _pin_block(const_cast(h_block), - (size_t)n_rows * n_cols * sizeof(InT)); - - int col = 0; - int batch_idx = 0; - while (col < n_cols) { - int sb_cols = std::min(sub_batch_cols, n_cols - col); - int sb_items = n_rows * sb_cols; - int s = batch_idx % n_streams; - auto stream = streams[s]; - auto& buf = bufs[s]; - - // H2D: column sub-batch in native dtype (F-order → contiguous) - cudaMemcpyAsync(buf.d_block_orig, h_block + (long long)col * n_rows, - sb_items * sizeof(InT), cudaMemcpyHostToDevice, stream); - - // Cast to float32 for sort + accumulate stats in float64 - ovr_cast_and_accumulate_dense_kernel - <<>>( - buf.d_block_orig, buf.d_block_f32, d_group_codes, - buf.d_group_sums, buf.d_group_sq_sums, buf.d_group_nnz, n_rows, - sb_cols, n_groups); - CUDA_CHECK_LAST_ERROR(ovr_cast_and_accumulate_dense_kernel); - - // Fill segment offsets + row indices - upload_linear_offsets(buf.seg_offsets, sb_cols, n_rows, stream); - fill_row_indices_kernel<<>>( - buf.vals_in, n_rows, sb_cols); - CUDA_CHECK_LAST_ERROR(fill_row_indices_kernel); - - // Sort - size_t temp = cub_temp_bytes; - cub::DeviceSegmentedRadixSort::SortPairs( - buf.cub_temp, temp, buf.d_block_f32, buf.keys_out, buf.vals_in, - buf.vals_out, sb_items, sb_cols, buf.seg_offsets, - buf.seg_offsets + 1, BEGIN_BIT, END_BIT, stream); - - // Fused rank sums (stats already captured by the cast kernel) - if (use_gmem) { - cudaMemsetAsync(buf.d_rank_sums, 0, - (size_t)n_groups * sb_cols * sizeof(double), - stream); - } - rank_sums_from_sorted_kernel<<>>( - buf.keys_out, buf.vals_out, d_group_codes, buf.d_rank_sums, - buf.d_tie_corr, n_rows, sb_cols, n_groups, compute_tie_corr, - use_gmem); - CUDA_CHECK_LAST_ERROR(rank_sums_from_sorted_kernel); - - // D2D: scatter sub-batch results into the caller's GPU buffers - cudaMemcpy2DAsync(d_rank_sums + col, n_cols * sizeof(double), - buf.d_rank_sums, sb_cols * sizeof(double), - sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToDevice, stream); - if 
(compute_tie_corr) { - cudaMemcpyAsync(d_tie_corr + col, buf.d_tie_corr, - sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, - stream); - } - cudaMemcpy2DAsync(d_group_sums + col, n_cols * sizeof(double), - buf.d_group_sums, sb_cols * sizeof(double), - sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToDevice, stream); - cudaMemcpy2DAsync(d_group_sq_sums + col, n_cols * sizeof(double), - buf.d_group_sq_sums, sb_cols * sizeof(double), - sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToDevice, stream); - cudaMemcpy2DAsync(d_group_nnz + col, n_cols * sizeof(double), - buf.d_group_nnz, sb_cols * sizeof(double), - sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToDevice, stream); - - col += sb_cols; - batch_idx++; - } - - for (int s = 0; s < n_streams; s++) { - cudaError_t err = cudaStreamSynchronize(streams[s]); - if (err != cudaSuccess) - throw std::runtime_error( - std::string("CUDA error in wilcoxon streaming: ") + - cudaGetErrorString(err)); - } - - for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); -} - -// ============================================================================ -// Sparse-aware CSC OVR streaming (sort only stored nonzeros) -// ============================================================================ - -static void ovr_sparse_csc_streaming_impl( - const float* csc_data, const int* csc_indices, const int* csc_indptr, - const int* group_codes, const double* group_sizes, double* rank_sums, - double* tie_corr, int n_rows, int n_cols, int n_groups, - bool compute_tie_corr, int sub_batch_cols) { - if (n_rows == 0 || n_cols == 0) return; - - // Read indptr to host for batch planning - std::vector h_indptr(n_cols + 1); - cudaMemcpy(h_indptr.data(), csc_indptr, (n_cols + 1) * sizeof(int), - cudaMemcpyDeviceToHost); - - int n_streams = N_STREAMS; - if (n_cols < n_streams * sub_batch_cols) - n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; - - // Find max nnz across any sub-batch for buffer sizing - 
size_t max_nnz = 0; - for (int col = 0; col < n_cols; col += sub_batch_cols) { - int sb_cols = std::min(sub_batch_cols, n_cols - col); - size_t nnz = (size_t)(h_indptr[col + sb_cols] - h_indptr[col]); - if (nnz > max_nnz) max_nnz = nnz; - } - - // CUB temp size for max_nnz items - size_t cub_temp_bytes = 0; - if (max_nnz > 0) { - auto* fk = reinterpret_cast(1); - auto* iv = reinterpret_cast(1); - cub::DeviceSegmentedRadixSort::SortPairs( - nullptr, cub_temp_bytes, fk, fk, iv, iv, (int)max_nnz, - sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); - } - - std::vector streams(n_streams); - for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); - - int tpb = UTIL_BLOCK_SIZE; - bool rank_use_gmem = false; - size_t smem_bytes = sparse_ovr_smem_config(n_groups, rank_use_gmem); - - RmmPool pool; - struct StreamBuf { - float* keys_out; - int* vals_out; - int* seg_offsets; - uint8_t* cub_temp; - double* sub_rank_sums; - double* sub_tie_corr; - double* d_nz_scratch; // gmem-only - }; - std::vector bufs(n_streams); - for (int s = 0; s < n_streams; s++) { - bufs[s].keys_out = pool.alloc(max_nnz); - bufs[s].vals_out = pool.alloc(max_nnz); - bufs[s].seg_offsets = pool.alloc(sub_batch_cols + 1); - bufs[s].cub_temp = pool.alloc(cub_temp_bytes); - bufs[s].sub_rank_sums = - pool.alloc((size_t)n_groups * sub_batch_cols); - bufs[s].sub_tie_corr = pool.alloc(sub_batch_cols); - bufs[s].d_nz_scratch = - rank_use_gmem - ? 
pool.alloc((size_t)n_groups * sub_batch_cols) - : nullptr; - } - - cudaDeviceSynchronize(); - - int col = 0; - int batch_idx = 0; - while (col < n_cols) { - int sb_cols = std::min(sub_batch_cols, n_cols - col); - int s = batch_idx % n_streams; - auto stream = streams[s]; - auto& buf = bufs[s]; - - int ptr_start = h_indptr[col]; - int ptr_end = h_indptr[col + sb_cols]; - int batch_nnz = ptr_end - ptr_start; - - // Compute rebased segment offsets on GPU (avoids host pinned-buffer - // race) - { - int count = sb_cols + 1; - int blk = (count + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; - rebase_indptr_kernel<<>>( - csc_indptr, buf.seg_offsets, col, count); - CUDA_CHECK_LAST_ERROR(rebase_indptr_kernel); - } - - // Sort only stored values (keys=data, vals=row_indices) - if (batch_nnz > 0) { - size_t temp = cub_temp_bytes; - cub::DeviceSegmentedRadixSort::SortPairs( - buf.cub_temp, temp, csc_data + ptr_start, buf.keys_out, - csc_indices + ptr_start, buf.vals_out, batch_nnz, sb_cols, - buf.seg_offsets, buf.seg_offsets + 1, BEGIN_BIT, END_BIT, - stream); - } - - // Sparse rank kernel (handles implicit zeros analytically) - if (rank_use_gmem) { - cudaMemsetAsync(buf.sub_rank_sums, 0, - (size_t)n_groups * sb_cols * sizeof(double), - stream); - cudaMemsetAsync(buf.d_nz_scratch, 0, - (size_t)n_groups * sb_cols * sizeof(double), - stream); - } - rank_sums_sparse_ovr_kernel<<>>( - buf.keys_out, buf.vals_out, buf.seg_offsets, group_codes, - group_sizes, buf.sub_rank_sums, buf.sub_tie_corr, buf.d_nz_scratch, - n_rows, sb_cols, n_groups, compute_tie_corr, rank_use_gmem); - CUDA_CHECK_LAST_ERROR(rank_sums_sparse_ovr_kernel); - - // Scatter results to global output - cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), - buf.sub_rank_sums, sb_cols * sizeof(double), - sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToDevice, stream); - if (compute_tie_corr) { - cudaMemcpyAsync(tie_corr + col, buf.sub_tie_corr, - sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, - stream); - 
} - - col += sb_cols; - batch_idx++; - } - - for (int s = 0; s < n_streams; s++) { - cudaError_t err = cudaStreamSynchronize(streams[s]); - if (err != cudaSuccess) - throw std::runtime_error( - std::string("CUDA error in sparse ovr streaming: ") + - cudaGetErrorString(err)); - } - - for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); -} - -// ============================================================================ -// Sparse-aware CSR OVR streaming (partial CSR→CSC transpose per sub-batch) -// ============================================================================ - -/** - * Sparse-aware OVR streaming pipeline for GPU CSR data. - * - * Phase 0: One histogram kernel counts nnz per column. D2H + host prefix sums - * give exact per-batch nnz and max_batch_nnz for buffer sizing. - * Phase 1: Allocate per-stream buffers sized to max_batch_nnz. - * Phase 2: For each sub-batch: scatter CSR→CSC (partial transpose via - * atomics) → CUB sort only nonzeros → sparse rank kernel. - * - * Compared to the dense CSR path, sort work drops by ~1/sparsity. 
- */ -static void ovr_sparse_csr_streaming_impl( - const float* csr_data, const int* csr_indices, const int* csr_indptr, - const int* group_codes, const double* group_sizes, double* rank_sums, - double* tie_corr, int n_rows, int n_cols, int n_groups, - bool compute_tie_corr, int sub_batch_cols) { - if (n_rows == 0 || n_cols == 0) return; - - // ---- Phase 0: Planning — count nnz per column via histogram ---- - RmmPool pool; - int* d_col_counts = pool.alloc(n_cols); - cudaMemset(d_col_counts, 0, n_cols * sizeof(int)); - { - int blocks = (n_rows + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; - csr_col_histogram_kernel<<>>( - csr_indices, csr_indptr, d_col_counts, n_rows, n_cols); - CUDA_CHECK_LAST_ERROR(csr_col_histogram_kernel); - } - std::vector h_col_counts(n_cols); - cudaMemcpy(h_col_counts.data(), d_col_counts, n_cols * sizeof(int), - cudaMemcpyDeviceToHost); - - // Per-batch prefix sums on host - int n_batches = (n_cols + sub_batch_cols - 1) / sub_batch_cols; - size_t max_batch_nnz = 0; - - // Flat array: n_batches × (sub_batch_cols + 1) offsets - std::vector h_all_offsets((size_t)n_batches * (sub_batch_cols + 1), 0); - std::vector h_batch_nnz(n_batches); - - for (int b = 0; b < n_batches; b++) { - int col_start = b * sub_batch_cols; - int sb_cols = std::min(sub_batch_cols, n_cols - col_start); - int* off = &h_all_offsets[(size_t)b * (sub_batch_cols + 1)]; - off[0] = 0; - for (int i = 0; i < sb_cols; i++) - off[i + 1] = off[i] + h_col_counts[col_start + i]; - h_batch_nnz[b] = (size_t)off[sb_cols]; - if (h_batch_nnz[b] > max_batch_nnz) max_batch_nnz = h_batch_nnz[b]; - } - - // Upload all batch offsets to GPU in one shot (~20 KB) - int* d_all_offsets = - pool.alloc((size_t)n_batches * (sub_batch_cols + 1)); - cudaMemcpy(d_all_offsets, h_all_offsets.data(), - h_all_offsets.size() * sizeof(int), cudaMemcpyHostToDevice); - - // ---- Phase 1: Allocate per-stream buffers ---- - size_t cub_temp_bytes = 0; - if (max_batch_nnz > 0) { - auto* fk = reinterpret_cast(1); - auto* 
iv = reinterpret_cast(1); - cub::DeviceSegmentedRadixSort::SortPairs( - nullptr, cub_temp_bytes, fk, fk, iv, iv, (int)max_batch_nnz, - sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); - } - - int n_streams = N_STREAMS; - if (n_batches < n_streams) n_streams = n_batches; - - // CSR path needs 4 sort arrays per stream (scatter intermediates + - // CUB output). Fit stream count to available GPU memory. - size_t per_stream_bytes = - max_batch_nnz * (2 * sizeof(float) + 2 * sizeof(int)) + - (sub_batch_cols + 1 + sub_batch_cols) * sizeof(int) + cub_temp_bytes + - (size_t)n_groups * sub_batch_cols * sizeof(double) + - sub_batch_cols * sizeof(double); - - size_t free_mem = 0, total_mem = 0; - cudaMemGetInfo(&free_mem, &total_mem); - constexpr double MEM_BUDGET_FRAC = 0.8; - size_t budget = (size_t)(free_mem * MEM_BUDGET_FRAC); - while (n_streams > 1 && (size_t)n_streams * per_stream_bytes > budget) - n_streams--; - - std::vector streams(n_streams); - for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); - - int tpb = UTIL_BLOCK_SIZE; - bool rank_use_gmem = false; - size_t smem_bytes = sparse_ovr_smem_config(n_groups, rank_use_gmem); - int scatter_blocks = (n_rows + tpb - 1) / tpb; - - struct StreamBuf { - int* col_offsets; // [sub_batch_cols + 1] CSC-style offsets - int* write_pos; // [sub_batch_cols] atomic write counters - float* csc_vals; // [max_batch_nnz] transposed values - int* csc_row_idx; // [max_batch_nnz] transposed row indices - float* keys_out; // [max_batch_nnz] CUB sort output - int* vals_out; // [max_batch_nnz] CUB sort output - uint8_t* cub_temp; - double* sub_rank_sums; - double* sub_tie_corr; - double* d_nz_scratch; // gmem-only - }; - std::vector bufs(n_streams); - for (int s = 0; s < n_streams; s++) { - bufs[s].col_offsets = pool.alloc(sub_batch_cols + 1); - bufs[s].write_pos = pool.alloc(sub_batch_cols); - bufs[s].csc_vals = pool.alloc(max_batch_nnz); - bufs[s].csc_row_idx = pool.alloc(max_batch_nnz); - bufs[s].keys_out = 
pool.alloc(max_batch_nnz); - bufs[s].vals_out = pool.alloc(max_batch_nnz); - bufs[s].cub_temp = pool.alloc(cub_temp_bytes); - bufs[s].sub_rank_sums = - pool.alloc((size_t)n_groups * sub_batch_cols); - bufs[s].sub_tie_corr = pool.alloc(sub_batch_cols); - bufs[s].d_nz_scratch = - rank_use_gmem - ? pool.alloc((size_t)n_groups * sub_batch_cols) - : nullptr; - } - - cudaDeviceSynchronize(); - - // ---- Phase 2: Stream loop ---- - int col = 0; - for (int b = 0; b < n_batches; b++) { - int sb_cols = std::min(sub_batch_cols, n_cols - col); - int s = b % n_streams; - auto stream = streams[s]; - auto& buf = bufs[s]; - int batch_nnz = (int)h_batch_nnz[b]; - - // D2D copy pre-computed col_offsets for this batch - int* src = d_all_offsets + (size_t)b * (sub_batch_cols + 1); - cudaMemcpyAsync(buf.col_offsets, src, (sb_cols + 1) * sizeof(int), - cudaMemcpyDeviceToDevice, stream); - - // Initialize write_pos = col_offsets[0..sb_cols-1] (same D2D source) - cudaMemcpyAsync(buf.write_pos, src, sb_cols * sizeof(int), - cudaMemcpyDeviceToDevice, stream); - - if (batch_nnz > 0) { - // Scatter CSR → CSC layout for this sub-batch - csr_scatter_to_csc_kernel<<>>( - csr_data, csr_indices, csr_indptr, buf.write_pos, buf.csc_vals, - buf.csc_row_idx, n_rows, col, col + sb_cols); - CUDA_CHECK_LAST_ERROR(csr_scatter_to_csc_kernel); - - // CUB sort only the nonzeros - size_t temp = cub_temp_bytes; - cub::DeviceSegmentedRadixSort::SortPairs( - buf.cub_temp, temp, buf.csc_vals, buf.keys_out, buf.csc_row_idx, - buf.vals_out, batch_nnz, sb_cols, buf.col_offsets, - buf.col_offsets + 1, BEGIN_BIT, END_BIT, stream); - } - - // Sparse rank kernel (handles implicit zeros analytically) - if (rank_use_gmem) { - cudaMemsetAsync(buf.sub_rank_sums, 0, - (size_t)n_groups * sb_cols * sizeof(double), - stream); - cudaMemsetAsync(buf.d_nz_scratch, 0, - (size_t)n_groups * sb_cols * sizeof(double), - stream); - } - rank_sums_sparse_ovr_kernel<<>>( - buf.keys_out, buf.vals_out, buf.col_offsets, group_codes, - 
group_sizes, buf.sub_rank_sums, buf.sub_tie_corr, buf.d_nz_scratch, - n_rows, sb_cols, n_groups, compute_tie_corr, rank_use_gmem); - CUDA_CHECK_LAST_ERROR(rank_sums_sparse_ovr_kernel); - - // Scatter results to global output - cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), - buf.sub_rank_sums, sb_cols * sizeof(double), - sb_cols * sizeof(double), n_groups, - cudaMemcpyDeviceToDevice, stream); - if (compute_tie_corr) { - cudaMemcpyAsync(tie_corr + col, buf.sub_tie_corr, - sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, - stream); - } - - col += sb_cols; - } - - for (int s = 0; s < n_streams; s++) { - cudaError_t err = cudaStreamSynchronize(streams[s]); - if (err != cudaSuccess) - throw std::runtime_error( - std::string("CUDA error in sparse CSR ovr streaming: ") + - cudaGetErrorString(err)); - } - - for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); -} - -// ============================================================================ -// Nanobind module -// ============================================================================ - -template -void register_bindings(nb::module_& m) { - m.doc() = "CUDA kernels for Wilcoxon rank-sum test (OVR)"; - - // ---- Streaming pipelines ---- - - m.def( - "ovr_streaming", - [](gpu_array_f block, - gpu_array_c group_codes, - gpu_array_c rank_sums, - gpu_array_c tie_corr, int n_rows, int n_cols, - int n_groups, bool compute_tie_corr, int sub_batch_cols) { - ovr_streaming_impl(block.data(), group_codes.data(), - rank_sums.data(), tie_corr.data(), n_rows, - n_cols, n_groups, compute_tie_corr, - sub_batch_cols); - }, - "block"_a, "group_codes"_a, "rank_sums"_a, "tie_corr"_a, nb::kw_only(), - "n_rows"_a, "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, - "sub_batch_cols"_a = SUB_BATCH_COLS); - - m.def( - "ovr_sparse_csr", - [](gpu_array_c csr_data, - gpu_array_c csr_indices, - gpu_array_c csr_indptr, - gpu_array_c group_codes, - gpu_array_c group_sizes, - gpu_array_c rank_sums, - gpu_array_c tie_corr, 
int n_rows, int n_cols, - int n_groups, bool compute_tie_corr, int sub_batch_cols) { - ovr_sparse_csr_streaming_impl( - csr_data.data(), csr_indices.data(), csr_indptr.data(), - group_codes.data(), group_sizes.data(), rank_sums.data(), - tie_corr.data(), n_rows, n_cols, n_groups, compute_tie_corr, - sub_batch_cols); - }, - "csr_data"_a, "csr_indices"_a, "csr_indptr"_a, "group_codes"_a, - "group_sizes"_a, "rank_sums"_a, "tie_corr"_a, nb::kw_only(), "n_rows"_a, - "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, - "sub_batch_cols"_a = SUB_BATCH_COLS); - - m.def( - "ovr_sparse_csc", - [](gpu_array_c csc_data, - gpu_array_c csc_indices, - gpu_array_c csc_indptr, - gpu_array_c group_codes, - gpu_array_c group_sizes, - gpu_array_c rank_sums, - gpu_array_c tie_corr, int n_rows, int n_cols, - int n_groups, bool compute_tie_corr, int sub_batch_cols) { - ovr_sparse_csc_streaming_impl( - csc_data.data(), csc_indices.data(), csc_indptr.data(), - group_codes.data(), group_sizes.data(), rank_sums.data(), - tie_corr.data(), n_rows, n_cols, n_groups, compute_tie_corr, - sub_batch_cols); - }, - "csc_data"_a, "csc_indices"_a, "csc_indptr"_a, "group_codes"_a, - "group_sizes"_a, "rank_sums"_a, "tie_corr"_a, nb::kw_only(), "n_rows"_a, - "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, - "sub_batch_cols"_a = SUB_BATCH_COLS); - - // ---- Host-streaming pipelines (host inputs, device outputs) ---- - -#define RSC_OVR_SPARSE_CSC_HOST_BINDING(NAME, InT, IndptrT) \ - m.def( \ - NAME, \ - [](host_array h_data, host_array h_indices, \ - host_array h_indptr, \ - host_array h_group_codes, \ - host_array h_group_sizes, \ - gpu_array_c d_rank_sums, \ - gpu_array_c d_tie_corr, \ - gpu_array_c d_group_sums, \ - gpu_array_c d_group_sq_sums, \ - gpu_array_c d_group_nnz, int n_rows, int n_cols, \ - int n_groups, bool compute_tie_corr, int sub_batch_cols) { \ - ovr_sparse_csc_host_streaming_impl( \ - h_data.data(), h_indices.data(), h_indptr.data(), \ - h_group_codes.data(), h_group_sizes.data(), \ - 
d_rank_sums.data(), d_tie_corr.data(), d_group_sums.data(), \ - d_group_sq_sums.data(), d_group_nnz.data(), n_rows, n_cols, \ - n_groups, compute_tie_corr, sub_batch_cols); \ - }, \ - "h_data"_a, "h_indices"_a, "h_indptr"_a, "h_group_codes"_a, \ - "h_group_sizes"_a, "d_rank_sums"_a, "d_tie_corr"_a, "d_group_sums"_a, \ - "d_group_sq_sums"_a, "d_group_nnz"_a, nb::kw_only(), "n_rows"_a, \ - "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, \ - "sub_batch_cols"_a = SUB_BATCH_COLS) - - RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host", float, int); - RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_i64", float, int64_t); - RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_f64", double, int); - RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_f64_i64", double, - int64_t); -#undef RSC_OVR_SPARSE_CSC_HOST_BINDING - -#define RSC_OVR_DENSE_HOST_BINDING(NAME, InT) \ - m.def( \ - NAME, \ - [](host_array_2d h_block, \ - host_array h_group_codes, \ - gpu_array_c d_rank_sums, \ - gpu_array_c d_tie_corr, \ - gpu_array_c d_group_sums, \ - gpu_array_c d_group_sq_sums, \ - gpu_array_c d_group_nnz, int n_rows, int n_cols, \ - int n_groups, bool compute_tie_corr, int sub_batch_cols) { \ - ovr_streaming_dense_host_impl( \ - h_block.data(), h_group_codes.data(), d_rank_sums.data(), \ - d_tie_corr.data(), d_group_sums.data(), \ - d_group_sq_sums.data(), d_group_nnz.data(), n_rows, n_cols, \ - n_groups, compute_tie_corr, sub_batch_cols); \ - }, \ - "h_block"_a, "h_group_codes"_a, "d_rank_sums"_a, "d_tie_corr"_a, \ - "d_group_sums"_a, "d_group_sq_sums"_a, "d_group_nnz"_a, nb::kw_only(), \ - "n_rows"_a, "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, \ - "sub_batch_cols"_a = SUB_BATCH_COLS) - - RSC_OVR_DENSE_HOST_BINDING("ovr_streaming_dense_host", float); - RSC_OVR_DENSE_HOST_BINDING("ovr_streaming_dense_host_f64", double); -#undef RSC_OVR_DENSE_HOST_BINDING -} +#include "wilcoxon_ovr_kernels.cuh" +#include "wilcoxon_ovr_dense.cuh" +#include "wilcoxon_ovr_sparse.cuh" +#include 
"wilcoxon_ovr_bindings.cuh" NB_MODULE(_wilcoxon_ovr_cuda, m) { REGISTER_GPU_BINDINGS(register_bindings, m); diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_bindings.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_bindings.cuh new file mode 100644 index 00000000..fab780e9 --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_bindings.cuh @@ -0,0 +1,137 @@ +#pragma once + +// ============================================================================ +// Nanobind module +// ============================================================================ + +template +void register_bindings(nb::module_& m) { + m.doc() = "CUDA kernels for Wilcoxon rank-sum test (OVR)"; + + // ---- Streaming pipelines ---- + + m.def( + "ovr_streaming", + [](gpu_array_f block, + gpu_array_c group_codes, + gpu_array_c rank_sums, + gpu_array_c tie_corr, int n_rows, int n_cols, + int n_groups, bool compute_tie_corr, int sub_batch_cols) { + ovr_streaming_impl(block.data(), group_codes.data(), + rank_sums.data(), tie_corr.data(), n_rows, + n_cols, n_groups, compute_tie_corr, + sub_batch_cols); + }, + "block"_a, "group_codes"_a, "rank_sums"_a, "tie_corr"_a, nb::kw_only(), + "n_rows"_a, "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, + "sub_batch_cols"_a = SUB_BATCH_COLS); + + m.def( + "ovr_sparse_csr", + [](gpu_array_c csr_data, + gpu_array_c csr_indices, + gpu_array_c csr_indptr, + gpu_array_c group_codes, + gpu_array_c group_sizes, + gpu_array_c rank_sums, + gpu_array_c tie_corr, int n_rows, int n_cols, + int n_groups, bool compute_tie_corr, int sub_batch_cols) { + ovr_sparse_csr_streaming_impl( + csr_data.data(), csr_indices.data(), csr_indptr.data(), + group_codes.data(), group_sizes.data(), rank_sums.data(), + tie_corr.data(), n_rows, n_cols, n_groups, compute_tie_corr, + sub_batch_cols); + }, + "csr_data"_a, "csr_indices"_a, "csr_indptr"_a, "group_codes"_a, + "group_sizes"_a, "rank_sums"_a, "tie_corr"_a, nb::kw_only(), "n_rows"_a, + "n_cols"_a, 
"n_groups"_a, "compute_tie_corr"_a, + "sub_batch_cols"_a = SUB_BATCH_COLS); + + m.def( + "ovr_sparse_csc", + [](gpu_array_c csc_data, + gpu_array_c csc_indices, + gpu_array_c csc_indptr, + gpu_array_c group_codes, + gpu_array_c group_sizes, + gpu_array_c rank_sums, + gpu_array_c tie_corr, int n_rows, int n_cols, + int n_groups, bool compute_tie_corr, int sub_batch_cols) { + ovr_sparse_csc_streaming_impl( + csc_data.data(), csc_indices.data(), csc_indptr.data(), + group_codes.data(), group_sizes.data(), rank_sums.data(), + tie_corr.data(), n_rows, n_cols, n_groups, compute_tie_corr, + sub_batch_cols); + }, + "csc_data"_a, "csc_indices"_a, "csc_indptr"_a, "group_codes"_a, + "group_sizes"_a, "rank_sums"_a, "tie_corr"_a, nb::kw_only(), "n_rows"_a, + "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, + "sub_batch_cols"_a = SUB_BATCH_COLS); + + // ---- Host-streaming pipelines (host inputs, device outputs) ---- + +#define RSC_OVR_SPARSE_CSC_HOST_BINDING(NAME, InT, IndptrT) \ + m.def( \ + NAME, \ + [](host_array h_data, host_array h_indices, \ + host_array h_indptr, \ + host_array h_group_codes, \ + host_array h_group_sizes, \ + gpu_array_c d_rank_sums, \ + gpu_array_c d_tie_corr, \ + gpu_array_c d_group_sums, \ + gpu_array_c d_group_sq_sums, \ + gpu_array_c d_group_nnz, int n_rows, int n_cols, \ + int n_groups, bool compute_tie_corr, bool compute_sq_sums, \ + bool compute_nnz, int sub_batch_cols) { \ + ovr_sparse_csc_host_streaming_impl( \ + h_data.data(), h_indices.data(), h_indptr.data(), \ + h_group_codes.data(), h_group_sizes.data(), \ + d_rank_sums.data(), d_tie_corr.data(), d_group_sums.data(), \ + d_group_sq_sums.data(), d_group_nnz.data(), n_rows, n_cols, \ + n_groups, compute_tie_corr, compute_sq_sums, compute_nnz, \ + sub_batch_cols); \ + }, \ + "h_data"_a, "h_indices"_a, "h_indptr"_a, "h_group_codes"_a, \ + "h_group_sizes"_a, "d_rank_sums"_a, "d_tie_corr"_a, "d_group_sums"_a, \ + "d_group_sq_sums"_a, "d_group_nnz"_a, nb::kw_only(), "n_rows"_a, \ + "n_cols"_a, 
"n_groups"_a, "compute_tie_corr"_a, \ + "compute_sq_sums"_a = true, "compute_nnz"_a = true, \ + "sub_batch_cols"_a = SUB_BATCH_COLS) + + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host", float, int); + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_i64", float, int64_t); + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_f64", double, int); + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_f64_i64", double, + int64_t); +#undef RSC_OVR_SPARSE_CSC_HOST_BINDING + +#define RSC_OVR_DENSE_HOST_BINDING(NAME, InT) \ + m.def( \ + NAME, \ + [](host_array_2d h_block, \ + host_array h_group_codes, \ + gpu_array_c d_rank_sums, \ + gpu_array_c d_tie_corr, \ + gpu_array_c d_group_sums, \ + gpu_array_c d_group_sq_sums, \ + gpu_array_c d_group_nnz, int n_rows, int n_cols, \ + int n_groups, bool compute_tie_corr, bool compute_sq_sums, \ + bool compute_nnz, int sub_batch_cols) { \ + ovr_streaming_dense_host_impl( \ + h_block.data(), h_group_codes.data(), d_rank_sums.data(), \ + d_tie_corr.data(), d_group_sums.data(), \ + d_group_sq_sums.data(), d_group_nnz.data(), n_rows, n_cols, \ + n_groups, compute_tie_corr, compute_sq_sums, compute_nnz, \ + sub_batch_cols); \ + }, \ + "h_block"_a, "h_group_codes"_a, "d_rank_sums"_a, "d_tie_corr"_a, \ + "d_group_sums"_a, "d_group_sq_sums"_a, "d_group_nnz"_a, nb::kw_only(), \ + "n_rows"_a, "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, \ + "compute_sq_sums"_a = true, "compute_nnz"_a = true, \ + "sub_batch_cols"_a = SUB_BATCH_COLS) + + RSC_OVR_DENSE_HOST_BINDING("ovr_streaming_dense_host", float); + RSC_OVR_DENSE_HOST_BINDING("ovr_streaming_dense_host_f64", double); +#undef RSC_OVR_DENSE_HOST_BINDING +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_dense.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_dense.cuh new file mode 100644 index 00000000..039f6ee6 --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_dense.cuh @@ -0,0 +1,317 @@ +#pragma once + +/** + * Streaming OVR pipeline. 
+ * + * Takes a dense F-order float32 block (n_rows, n_cols) + int32 group_codes, + * splits columns into sub-batches across multiple CUDA streams, and for each: + * 1. CUB SortPairs (float32 keys + int32 row indices) + * 2. Fused rank_sums_from_sorted_kernel + * + * Output: rank_sums (n_groups, n_cols) + tie_corr (n_cols), both float64. + */ +static void ovr_streaming_impl(const float* block, const int* group_codes, + double* rank_sums, double* tie_corr, int n_rows, + int n_cols, int n_groups, bool compute_tie_corr, + int sub_batch_cols) { + if (n_rows == 0 || n_cols == 0) return; + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) + n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + + size_t sub_items = (size_t)n_rows * sub_batch_cols; + size_t cub_temp_bytes = 0; + { + auto* fk = reinterpret_cast(1); + auto* iv = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortPairs( + nullptr, cub_temp_bytes, fk, fk, iv, iv, (int)sub_items, + sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); + } + + // Create streams + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + // Allocate per-stream buffers via RMM pool + RmmPool pool; + struct StreamBuf { + float* keys_out; + int* vals_in; + int* vals_out; + int* seg_offsets; + uint8_t* cub_temp; + double* sub_rank_sums; + double* sub_tie_corr; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].keys_out = pool.alloc(sub_items); + bufs[s].vals_in = pool.alloc(sub_items); + bufs[s].vals_out = pool.alloc(sub_items); + bufs[s].seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = pool.alloc(sub_batch_cols); + } + + int tpb_rank = round_up_to_warp(n_rows); + bool use_gmem = false; + size_t smem_rank = ovr_smem_config(n_groups, use_gmem); + + // Process sub-batches 
round-robin across streams + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int sb_items = n_rows * sb_cols; + int s = batch_idx % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + + // Fill segment offsets + row indices + upload_linear_offsets(buf.seg_offsets, sb_cols, n_rows, stream); + fill_row_indices_kernel<<>>( + buf.vals_in, n_rows, sb_cols); + CUDA_CHECK_LAST_ERROR(fill_row_indices_kernel); + + // Sort: keys = block columns [col, col+sb_cols), already F-order + const float* keys_in = block + (long long)col * n_rows; + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortPairs( + buf.cub_temp, temp, keys_in, buf.keys_out, buf.vals_in, + buf.vals_out, sb_items, sb_cols, buf.seg_offsets, + buf.seg_offsets + 1, BEGIN_BIT, END_BIT, stream); + + // Fused rank sums into sub-batch buffer + if (use_gmem) { + cudaMemsetAsync(buf.sub_rank_sums, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + } + rank_sums_from_sorted_kernel<<>>( + buf.keys_out, buf.vals_out, group_codes, buf.sub_rank_sums, + buf.sub_tie_corr, n_rows, sb_cols, n_groups, compute_tie_corr, + use_gmem); + CUDA_CHECK_LAST_ERROR(rank_sums_from_sorted_kernel); + + // Copy sub-batch results to global output (row-major scatter) + // rank_sums is (n_groups, n_cols) row-major: group g, col c → + // [g*n_cols+c] sub output is (n_groups, sb_cols): group g, local col lc + // → [g*sb_cols+lc] + cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpyAsync(tie_corr + col, buf.sub_tie_corr, + sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, + stream); + } + + col += sb_cols; + batch_idx++; + } + + // Sync all streams + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw 
std::runtime_error( + std::string("CUDA error in wilcoxon streaming: ") + + cudaGetErrorString(err)); + } + + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} + +/** + * Host-streaming dense OVR pipeline. + * + * Templated on the host dtype (InT = float or double). Each sub-batch is + * copied to the device in its native dtype once; a fused cast+accumulate + * kernel writes a float32 view for the sort and accumulates per-group + * sum/sum-sq/nnz in float64 from the original-precision values. The + * existing sort + rank pipeline then runs on the float32 keys. + * + * Output pointers ({d_rank_sums, d_tie_corr, d_group_sums, d_group_sq_sums, + * d_group_nnz}) point to caller-provided CuPy memory of the full output + * shape; sub-batch kernels scatter directly into them via D2D. + * + * GPU memory stays at O(sub_batch * n_rows), now with a small extra + * InT-sized sub-batch buffer per stream. + */ +template +static void ovr_streaming_dense_host_impl( + const InT* h_block, const int* h_group_codes, double* d_rank_sums, + double* d_tie_corr, double* d_group_sums, double* d_group_sq_sums, + double* d_group_nnz, int n_rows, int n_cols, int n_groups, + bool compute_tie_corr, bool compute_sq_sums, bool compute_nnz, + int sub_batch_cols) { + if (n_rows == 0 || n_cols == 0) return; + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) + n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + + size_t sub_items = (size_t)n_rows * sub_batch_cols; + size_t cub_temp_bytes = 0; + { + auto* fk = reinterpret_cast(1); + auto* iv = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortPairs( + nullptr, cub_temp_bytes, fk, fk, iv, iv, (int)sub_items, + sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + // Allocate per-stream buffers via RMM pool + RmmPool pool; + int* d_group_codes = pool.alloc(n_rows); + struct StreamBuf { + 
InT* d_block_orig; + float* d_block_f32; + float* keys_out; + int* vals_in; + int* vals_out; + int* seg_offsets; + uint8_t* cub_temp; + double* d_rank_sums; + double* d_tie_corr; + double* d_group_sums; + double* d_group_sq_sums; + double* d_group_nnz; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].d_block_orig = pool.alloc(sub_items); + bufs[s].d_block_f32 = pool.alloc(sub_items); + bufs[s].keys_out = pool.alloc(sub_items); + bufs[s].vals_in = pool.alloc(sub_items); + bufs[s].vals_out = pool.alloc(sub_items); + bufs[s].seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].d_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].d_tie_corr = pool.alloc(sub_batch_cols); + bufs[s].d_group_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].d_group_sq_sums = + compute_sq_sums + ? pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + bufs[s].d_group_nnz = + compute_nnz ? pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + } + + // Group codes on GPU (transferred once) + cudaMemcpy(d_group_codes, h_group_codes, n_rows * sizeof(int), + cudaMemcpyHostToDevice); + + int tpb_rank = round_up_to_warp(n_rows); + bool use_gmem = false; + size_t smem_rank = ovr_smem_config(n_groups, use_gmem); + int tpb_cast = UTIL_BLOCK_SIZE; + size_t smem_cast = (size_t)n_groups * sizeof(double); + if (compute_nnz) { + smem_cast = (size_t)(3 * n_groups) * sizeof(double); + } else if (compute_sq_sums) { + smem_cast = (size_t)(2 * n_groups) * sizeof(double); + } + + // Pin only the host input. Outputs live on the device (caller-owned). 
+ HostRegisterGuard _pin_block(const_cast(h_block), + (size_t)n_rows * n_cols * sizeof(InT)); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int sb_items = n_rows * sb_cols; + int s = batch_idx % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + + // H2D: column sub-batch in native dtype (F-order → contiguous) + cudaMemcpyAsync(buf.d_block_orig, h_block + (long long)col * n_rows, + sb_items * sizeof(InT), cudaMemcpyHostToDevice, stream); + + // Cast to float32 for sort + accumulate stats in float64 + ovr_cast_and_accumulate_dense_kernel + <<>>( + buf.d_block_orig, buf.d_block_f32, d_group_codes, + buf.d_group_sums, buf.d_group_sq_sums, buf.d_group_nnz, n_rows, + sb_cols, n_groups, compute_sq_sums, compute_nnz); + CUDA_CHECK_LAST_ERROR(ovr_cast_and_accumulate_dense_kernel); + + // Fill segment offsets + row indices + upload_linear_offsets(buf.seg_offsets, sb_cols, n_rows, stream); + fill_row_indices_kernel<<>>( + buf.vals_in, n_rows, sb_cols); + CUDA_CHECK_LAST_ERROR(fill_row_indices_kernel); + + // Sort + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortPairs( + buf.cub_temp, temp, buf.d_block_f32, buf.keys_out, buf.vals_in, + buf.vals_out, sb_items, sb_cols, buf.seg_offsets, + buf.seg_offsets + 1, BEGIN_BIT, END_BIT, stream); + + // Fused rank sums (stats already captured by the cast kernel) + if (use_gmem) { + cudaMemsetAsync(buf.d_rank_sums, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + } + rank_sums_from_sorted_kernel<<>>( + buf.keys_out, buf.vals_out, d_group_codes, buf.d_rank_sums, + buf.d_tie_corr, n_rows, sb_cols, n_groups, compute_tie_corr, + use_gmem); + CUDA_CHECK_LAST_ERROR(rank_sums_from_sorted_kernel); + + // D2D: scatter sub-batch results into the caller's GPU buffers + cudaMemcpy2DAsync(d_rank_sums + col, n_cols * sizeof(double), + buf.d_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + 
cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpyAsync(d_tie_corr + col, buf.d_tie_corr, + sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, + stream); + } + cudaMemcpy2DAsync(d_group_sums + col, n_cols * sizeof(double), + buf.d_group_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_sq_sums) { + cudaMemcpy2DAsync(d_group_sq_sums + col, n_cols * sizeof(double), + buf.d_group_sq_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + if (compute_nnz) { + cudaMemcpy2DAsync(d_group_nnz + col, n_cols * sizeof(double), + buf.d_group_nnz, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + + col += sb_cols; + batch_idx++; + } + + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in wilcoxon streaming: ") + + cudaGetErrorString(err)); + } + + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh new file mode 100644 index 00000000..a94338b7 --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh @@ -0,0 +1,102 @@ +#pragma once + +/** Count nonzeros per column from CSR. One thread per row. 
*/ +__global__ void csr_col_histogram_kernel(const int* __restrict__ indices, + const int* __restrict__ indptr, + int* __restrict__ col_counts, + int n_rows, int n_cols) { + int row = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= n_rows) return; + int rs = indptr[row]; + int re = indptr[row + 1]; + for (int p = rs; p < re; ++p) { + int c = indices[p]; + if (c < n_cols) atomicAdd(&col_counts[c], 1); + } +} + +/** + * Scatter CSR nonzeros into CSC layout for columns [col_start, col_stop). + * write_pos[c - col_start] must be initialized to the prefix-sum offset + * for column c. Each thread atomically claims a unique destination slot. + */ +__global__ void csr_scatter_to_csc_kernel( + const float* __restrict__ data, const int* __restrict__ indices, + const int* __restrict__ indptr, int* __restrict__ write_pos, + float* __restrict__ csc_vals, int* __restrict__ csc_row_idx, int n_rows, + int col_start, int col_stop) { + int row = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= n_rows) return; + int rs = indptr[row]; + int re = indptr[row + 1]; + // Binary search for col_start (overflow-safe midpoint) + int lo = rs, hi = re; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (indices[m] < col_start) + lo = m + 1; + else + hi = m; + } + for (int p = lo; p < re; ++p) { + int c = indices[p]; + if (c >= col_stop) break; + int dest = atomicAdd(&write_pos[c - col_start], 1); + csc_vals[dest] = data[p]; + csc_row_idx[dest] = row; + } +} + +/** + * Decide whether to use shared or global memory for OVR rank accumulators. + * Returns the smem size to request and sets use_gmem accordingly. 
+ */ +static int query_max_smem_per_block() { + static int cached = -1; + if (cached < 0) { + int device; + cudaGetDevice(&device); + cudaDeviceGetAttribute(&cached, cudaDevAttrMaxSharedMemoryPerBlock, + device); + } + return cached; +} + +static size_t ovr_smem_config(int n_groups, bool& use_gmem) { + size_t need = (size_t)(n_groups + 32) * sizeof(double); + if ((int)need <= query_max_smem_per_block()) { + use_gmem = false; + return need; + } + // Fall back to global memory accumulators; only need warp buf in smem + use_gmem = true; + return 32 * sizeof(double); +} + +/** + * Decide smem-vs-gmem for the sparse OVR rank kernel. Two accumulator + * arrays (grp_sums + grp_nz_count) of size n_groups each plus warp buf. + */ +static size_t sparse_ovr_smem_config(int n_groups, bool& use_gmem) { + size_t need = (size_t)(2 * n_groups + 32) * sizeof(double); + if ((int)need <= query_max_smem_per_block()) { + use_gmem = false; + return need; + } + use_gmem = true; + return 32 * sizeof(double); +} + +/** + * Fill sort values with row indices [0,1,...,n_rows-1] per column. + * Grid: (n_cols,), block: 256 threads. + */ +__global__ void fill_row_indices_kernel(int* __restrict__ vals, int n_rows, + int n_cols) { + int col = blockIdx.x; + if (col >= n_cols) return; + int* out = vals + (long long)col * n_rows; + for (int i = threadIdx.x; i < n_rows; i += blockDim.x) { + out[i] = i; + } +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh new file mode 100644 index 00000000..725e2c41 --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh @@ -0,0 +1,587 @@ +#pragma once + +/** + * Sparse-aware host-streaming CSC OVR pipeline. + * + * Like ovr_streaming_csc_host_impl but sorts only stored nonzeros per column + * instead of extracting dense blocks. GPU memory is O(max_batch_nnz) instead + * of O(sub_batch * n_rows), and sort work is proportional to nnz, not n_rows. 
+ */ +template +static void ovr_sparse_csc_host_streaming_impl( + const InT* h_data, const int* h_indices, const IndptrT* h_indptr, + const int* h_group_codes, const double* h_group_sizes, double* d_rank_sums, + double* d_tie_corr, double* d_group_sums, double* d_group_sq_sums, + double* d_group_nnz, int n_rows, int n_cols, int n_groups, + bool compute_tie_corr, bool compute_sq_sums, bool compute_nnz, + int sub_batch_cols) { + if (n_rows == 0 || n_cols == 0) return; + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) + n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + + // Find max nnz across any sub-batch + size_t max_nnz = 0; + for (int col = 0; col < n_cols; col += sub_batch_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + size_t nnz = (size_t)(h_indptr[col + sb_cols] - h_indptr[col]); + if (nnz > max_nnz) max_nnz = nnz; + } + + // CUB temp size for max_nnz items + size_t cub_temp_bytes = 0; + if (max_nnz > 0) { + auto* fk = reinterpret_cast(1); + auto* iv = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortPairs( + nullptr, cub_temp_bytes, fk, fk, iv, iv, (int)max_nnz, + sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + RmmPool pool; + int* d_group_codes = pool.alloc(n_rows); + double* d_group_sizes = pool.alloc(n_groups); + struct StreamBuf { + InT* d_sparse_data_orig; + float* d_sparse_data_f32; + int* d_sparse_indices; + int* d_seg_offsets; + float* keys_out; + int* vals_out; + uint8_t* cub_temp; + double* d_rank_sums; + double* d_tie_corr; + double* d_group_sums; + double* d_group_sq_sums; + double* d_group_nnz; + double* d_nz_scratch; // gmem-only; non-null when rank_use_gmem + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].d_sparse_data_orig = pool.alloc(max_nnz); + bufs[s].d_sparse_data_f32 = pool.alloc(max_nnz); + bufs[s].d_sparse_indices = 
pool.alloc(max_nnz); + bufs[s].d_seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].keys_out = pool.alloc(max_nnz); + bufs[s].vals_out = pool.alloc(max_nnz); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].d_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].d_tie_corr = pool.alloc(sub_batch_cols); + bufs[s].d_group_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].d_group_sq_sums = + compute_sq_sums + ? pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + bufs[s].d_group_nnz = + compute_nnz ? pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + } + + // Transfer group codes + sizes once + cudaMemcpy(d_group_codes, h_group_codes, n_rows * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_group_sizes, h_group_sizes, n_groups * sizeof(double), + cudaMemcpyHostToDevice); + + // Pre-compute rebased per-batch offsets and upload once (avoids per-batch + // H2D copy from a transient host buffer). + int n_batches = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + std::vector h_all_offsets((size_t)n_batches * (sub_batch_cols + 1), 0); + for (int b = 0; b < n_batches; b++) { + int col_start = b * sub_batch_cols; + int sb = std::min(sub_batch_cols, n_cols - col_start); + IndptrT ptr_start = h_indptr[col_start]; + int* off = &h_all_offsets[(size_t)b * (sub_batch_cols + 1)]; + for (int i = 0; i <= sb; i++) + off[i] = (int)(h_indptr[col_start + i] - ptr_start); + } + int* d_all_offsets = + pool.alloc((size_t)n_batches * (sub_batch_cols + 1)); + cudaMemcpy(d_all_offsets, h_all_offsets.data(), + h_all_offsets.size() * sizeof(int), cudaMemcpyHostToDevice); + + int tpb = UTIL_BLOCK_SIZE; + bool rank_use_gmem = false; + size_t smem_bytes = sparse_ovr_smem_config(n_groups, rank_use_gmem); + size_t smem_cast = (size_t)n_groups * sizeof(double); + if (compute_nnz) { + smem_cast = (size_t)(3 * n_groups) * sizeof(double); + } else if (compute_sq_sums) { + smem_cast = (size_t)(2 * n_groups) * sizeof(double); + } + + 
// In gmem mode the sparse rank kernel accumulates into rank_sums directly + // and needs a per-stream nz_count scratch buffer sized (n_groups, sb_cols). + for (int s = 0; s < n_streams; s++) { + if (rank_use_gmem) { + bufs[s].d_nz_scratch = + pool.alloc((size_t)n_groups * sub_batch_cols); + } else { + bufs[s].d_nz_scratch = nullptr; + } + } + + // Pin only the host input arrays; outputs live on the device. + size_t total_nnz = (size_t)h_indptr[n_cols]; + HostRegisterGuard _pin_data(const_cast(h_data), + total_nnz * sizeof(InT)); + HostRegisterGuard _pin_indices(const_cast(h_indices), + total_nnz * sizeof(int)); + + cudaDeviceSynchronize(); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int s = batch_idx % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + + IndptrT ptr_start = h_indptr[col]; + IndptrT ptr_end = h_indptr[col + sb_cols]; + int batch_nnz = (int)(ptr_end - ptr_start); + + // H2D: transfer sparse data for this column range (native dtype) + if (batch_nnz > 0) { + cudaMemcpyAsync(buf.d_sparse_data_orig, h_data + ptr_start, + (size_t)batch_nnz * sizeof(InT), + cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(buf.d_sparse_indices, h_indices + ptr_start, + (size_t)batch_nnz * sizeof(int), + cudaMemcpyHostToDevice, stream); + } + + // D2D: copy this batch's rebased offsets from the pre-uploaded buffer + int* src = d_all_offsets + (size_t)batch_idx * (sub_batch_cols + 1); + cudaMemcpyAsync(buf.d_seg_offsets, src, (sb_cols + 1) * sizeof(int), + cudaMemcpyDeviceToDevice, stream); + + // Cast to float32 for sort + accumulate stats in float64 + ovr_cast_and_accumulate_sparse_kernel + <<>>( + buf.d_sparse_data_orig, buf.d_sparse_data_f32, + buf.d_sparse_indices, buf.d_seg_offsets, d_group_codes, + buf.d_group_sums, buf.d_group_sq_sums, buf.d_group_nnz, sb_cols, + n_groups, compute_sq_sums, compute_nnz); + CUDA_CHECK_LAST_ERROR(ovr_cast_and_accumulate_sparse_kernel); + + // CUB 
sort only stored nonzeros (float32 keys) + if (batch_nnz > 0) { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortPairs( + buf.cub_temp, temp, buf.d_sparse_data_f32, buf.keys_out, + buf.d_sparse_indices, buf.vals_out, batch_nnz, sb_cols, + buf.d_seg_offsets, buf.d_seg_offsets + 1, BEGIN_BIT, END_BIT, + stream); + } + + // Sparse rank kernel (stats already captured above) + if (rank_use_gmem) { + cudaMemsetAsync(buf.d_rank_sums, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + cudaMemsetAsync(buf.d_nz_scratch, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + } + rank_sums_sparse_ovr_kernel<<>>( + buf.keys_out, buf.vals_out, buf.d_seg_offsets, d_group_codes, + d_group_sizes, buf.d_rank_sums, buf.d_tie_corr, buf.d_nz_scratch, + n_rows, sb_cols, n_groups, compute_tie_corr, rank_use_gmem); + CUDA_CHECK_LAST_ERROR(rank_sums_sparse_ovr_kernel); + + // D2D: scatter sub-batch results into caller's GPU buffers + cudaMemcpy2DAsync(d_rank_sums + col, n_cols * sizeof(double), + buf.d_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpyAsync(d_tie_corr + col, buf.d_tie_corr, + sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, + stream); + } + cudaMemcpy2DAsync(d_group_sums + col, n_cols * sizeof(double), + buf.d_group_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_sq_sums) { + cudaMemcpy2DAsync(d_group_sq_sums + col, n_cols * sizeof(double), + buf.d_group_sq_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + if (compute_nnz) { + cudaMemcpy2DAsync(d_group_nnz + col, n_cols * sizeof(double), + buf.d_group_nnz, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + + col += sb_cols; + batch_idx++; + } + + for (int s = 0; s < n_streams; s++) { + 
cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in sparse host CSC streaming: ") + + cudaGetErrorString(err)); + } + + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} + +// ============================================================================ +// Sparse-aware CSC OVR streaming (sort only stored nonzeros) +// ============================================================================ + +static void ovr_sparse_csc_streaming_impl( + const float* csc_data, const int* csc_indices, const int* csc_indptr, + const int* group_codes, const double* group_sizes, double* rank_sums, + double* tie_corr, int n_rows, int n_cols, int n_groups, + bool compute_tie_corr, int sub_batch_cols) { + if (n_rows == 0 || n_cols == 0) return; + + // Read indptr to host for batch planning + std::vector h_indptr(n_cols + 1); + cudaMemcpy(h_indptr.data(), csc_indptr, (n_cols + 1) * sizeof(int), + cudaMemcpyDeviceToHost); + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) + n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + + // Find max nnz across any sub-batch for buffer sizing + size_t max_nnz = 0; + for (int col = 0; col < n_cols; col += sub_batch_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + size_t nnz = (size_t)(h_indptr[col + sb_cols] - h_indptr[col]); + if (nnz > max_nnz) max_nnz = nnz; + } + + // CUB temp size for max_nnz items + size_t cub_temp_bytes = 0; + if (max_nnz > 0) { + auto* fk = reinterpret_cast(1); + auto* iv = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortPairs( + nullptr, cub_temp_bytes, fk, fk, iv, iv, (int)max_nnz, + sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + int tpb = UTIL_BLOCK_SIZE; + bool rank_use_gmem = false; + size_t smem_bytes = sparse_ovr_smem_config(n_groups, 
rank_use_gmem); + + RmmPool pool; + struct StreamBuf { + float* keys_out; + int* vals_out; + int* seg_offsets; + uint8_t* cub_temp; + double* sub_rank_sums; + double* sub_tie_corr; + double* d_nz_scratch; // gmem-only + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].keys_out = pool.alloc(max_nnz); + bufs[s].vals_out = pool.alloc(max_nnz); + bufs[s].seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = pool.alloc(sub_batch_cols); + bufs[s].d_nz_scratch = + rank_use_gmem + ? pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + } + + cudaDeviceSynchronize(); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int s = batch_idx % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + + int ptr_start = h_indptr[col]; + int ptr_end = h_indptr[col + sb_cols]; + int batch_nnz = ptr_end - ptr_start; + + // Compute rebased segment offsets on GPU (avoids host pinned-buffer + // race) + { + int count = sb_cols + 1; + int blk = (count + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + rebase_indptr_kernel<<>>( + csc_indptr, buf.seg_offsets, col, count); + CUDA_CHECK_LAST_ERROR(rebase_indptr_kernel); + } + + // Sort only stored values (keys=data, vals=row_indices) + if (batch_nnz > 0) { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortPairs( + buf.cub_temp, temp, csc_data + ptr_start, buf.keys_out, + csc_indices + ptr_start, buf.vals_out, batch_nnz, sb_cols, + buf.seg_offsets, buf.seg_offsets + 1, BEGIN_BIT, END_BIT, + stream); + } + + // Sparse rank kernel (handles implicit zeros analytically) + if (rank_use_gmem) { + cudaMemsetAsync(buf.sub_rank_sums, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + cudaMemsetAsync(buf.d_nz_scratch, 0, + (size_t)n_groups * sb_cols * sizeof(double), + 
stream); + } + rank_sums_sparse_ovr_kernel<<>>( + buf.keys_out, buf.vals_out, buf.seg_offsets, group_codes, + group_sizes, buf.sub_rank_sums, buf.sub_tie_corr, buf.d_nz_scratch, + n_rows, sb_cols, n_groups, compute_tie_corr, rank_use_gmem); + CUDA_CHECK_LAST_ERROR(rank_sums_sparse_ovr_kernel); + + // Scatter results to global output + cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpyAsync(tie_corr + col, buf.sub_tie_corr, + sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, + stream); + } + + col += sb_cols; + batch_idx++; + } + + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in sparse ovr streaming: ") + + cudaGetErrorString(err)); + } + + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} + +// ============================================================================ +// Sparse-aware CSR OVR streaming (partial CSR→CSC transpose per sub-batch) +// ============================================================================ + +/** + * Sparse-aware OVR streaming pipeline for GPU CSR data. + * + * Phase 0: One histogram kernel counts nnz per column. D2H + host prefix sums + * give exact per-batch nnz and max_batch_nnz for buffer sizing. + * Phase 1: Allocate per-stream buffers sized to max_batch_nnz. + * Phase 2: For each sub-batch: scatter CSR→CSC (partial transpose via + * atomics) → CUB sort only nonzeros → sparse rank kernel. + * + * Compared to the dense CSR path, sort work drops by ~1/sparsity. 
+ */ +static void ovr_sparse_csr_streaming_impl( + const float* csr_data, const int* csr_indices, const int* csr_indptr, + const int* group_codes, const double* group_sizes, double* rank_sums, + double* tie_corr, int n_rows, int n_cols, int n_groups, + bool compute_tie_corr, int sub_batch_cols) { + if (n_rows == 0 || n_cols == 0) return; + + // ---- Phase 0: Planning — count nnz per column via histogram ---- + RmmPool pool; + int* d_col_counts = pool.alloc(n_cols); + cudaMemset(d_col_counts, 0, n_cols * sizeof(int)); + { + int blocks = (n_rows + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + csr_col_histogram_kernel<<>>( + csr_indices, csr_indptr, d_col_counts, n_rows, n_cols); + CUDA_CHECK_LAST_ERROR(csr_col_histogram_kernel); + } + std::vector h_col_counts(n_cols); + cudaMemcpy(h_col_counts.data(), d_col_counts, n_cols * sizeof(int), + cudaMemcpyDeviceToHost); + + // Per-batch prefix sums on host + int n_batches = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + size_t max_batch_nnz = 0; + + // Flat array: n_batches × (sub_batch_cols + 1) offsets + std::vector h_all_offsets((size_t)n_batches * (sub_batch_cols + 1), 0); + std::vector h_batch_nnz(n_batches); + + for (int b = 0; b < n_batches; b++) { + int col_start = b * sub_batch_cols; + int sb_cols = std::min(sub_batch_cols, n_cols - col_start); + int* off = &h_all_offsets[(size_t)b * (sub_batch_cols + 1)]; + off[0] = 0; + for (int i = 0; i < sb_cols; i++) + off[i + 1] = off[i] + h_col_counts[col_start + i]; + h_batch_nnz[b] = (size_t)off[sb_cols]; + if (h_batch_nnz[b] > max_batch_nnz) max_batch_nnz = h_batch_nnz[b]; + } + + // Upload all batch offsets to GPU in one shot (~20 KB) + int* d_all_offsets = + pool.alloc((size_t)n_batches * (sub_batch_cols + 1)); + cudaMemcpy(d_all_offsets, h_all_offsets.data(), + h_all_offsets.size() * sizeof(int), cudaMemcpyHostToDevice); + + // ---- Phase 1: Allocate per-stream buffers ---- + size_t cub_temp_bytes = 0; + if (max_batch_nnz > 0) { + auto* fk = reinterpret_cast(1); + auto* 
iv = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortPairs( + nullptr, cub_temp_bytes, fk, fk, iv, iv, (int)max_batch_nnz, + sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); + } + + int n_streams = N_STREAMS; + if (n_batches < n_streams) n_streams = n_batches; + + // CSR path needs 4 sort arrays per stream (scatter intermediates + + // CUB output). Fit stream count to available GPU memory. + size_t per_stream_bytes = + max_batch_nnz * (2 * sizeof(float) + 2 * sizeof(int)) + + (sub_batch_cols + 1 + sub_batch_cols) * sizeof(int) + cub_temp_bytes + + (size_t)n_groups * sub_batch_cols * sizeof(double) + + sub_batch_cols * sizeof(double); + + size_t free_mem = 0, total_mem = 0; + cudaMemGetInfo(&free_mem, &total_mem); + constexpr double MEM_BUDGET_FRAC = 0.8; + size_t budget = (size_t)(free_mem * MEM_BUDGET_FRAC); + while (n_streams > 1 && (size_t)n_streams * per_stream_bytes > budget) + n_streams--; + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + int tpb = UTIL_BLOCK_SIZE; + bool rank_use_gmem = false; + size_t smem_bytes = sparse_ovr_smem_config(n_groups, rank_use_gmem); + int scatter_blocks = (n_rows + tpb - 1) / tpb; + + struct StreamBuf { + int* col_offsets; // [sub_batch_cols + 1] CSC-style offsets + int* write_pos; // [sub_batch_cols] atomic write counters + float* csc_vals; // [max_batch_nnz] transposed values + int* csc_row_idx; // [max_batch_nnz] transposed row indices + float* keys_out; // [max_batch_nnz] CUB sort output + int* vals_out; // [max_batch_nnz] CUB sort output + uint8_t* cub_temp; + double* sub_rank_sums; + double* sub_tie_corr; + double* d_nz_scratch; // gmem-only + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].col_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].write_pos = pool.alloc(sub_batch_cols); + bufs[s].csc_vals = pool.alloc(max_batch_nnz); + bufs[s].csc_row_idx = pool.alloc(max_batch_nnz); + bufs[s].keys_out = 
pool.alloc(max_batch_nnz); + bufs[s].vals_out = pool.alloc(max_batch_nnz); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = pool.alloc(sub_batch_cols); + bufs[s].d_nz_scratch = + rank_use_gmem + ? pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + } + + cudaDeviceSynchronize(); + + // ---- Phase 2: Stream loop ---- + int col = 0; + for (int b = 0; b < n_batches; b++) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int s = b % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + int batch_nnz = (int)h_batch_nnz[b]; + + // D2D copy pre-computed col_offsets for this batch + int* src = d_all_offsets + (size_t)b * (sub_batch_cols + 1); + cudaMemcpyAsync(buf.col_offsets, src, (sb_cols + 1) * sizeof(int), + cudaMemcpyDeviceToDevice, stream); + + // Initialize write_pos = col_offsets[0..sb_cols-1] (same D2D source) + cudaMemcpyAsync(buf.write_pos, src, sb_cols * sizeof(int), + cudaMemcpyDeviceToDevice, stream); + + if (batch_nnz > 0) { + // Scatter CSR → CSC layout for this sub-batch + csr_scatter_to_csc_kernel<<>>( + csr_data, csr_indices, csr_indptr, buf.write_pos, buf.csc_vals, + buf.csc_row_idx, n_rows, col, col + sb_cols); + CUDA_CHECK_LAST_ERROR(csr_scatter_to_csc_kernel); + + // CUB sort only the nonzeros + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortPairs( + buf.cub_temp, temp, buf.csc_vals, buf.keys_out, buf.csc_row_idx, + buf.vals_out, batch_nnz, sb_cols, buf.col_offsets, + buf.col_offsets + 1, BEGIN_BIT, END_BIT, stream); + } + + // Sparse rank kernel (handles implicit zeros analytically) + if (rank_use_gmem) { + cudaMemsetAsync(buf.sub_rank_sums, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + cudaMemsetAsync(buf.d_nz_scratch, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + } + rank_sums_sparse_ovr_kernel<<>>( + buf.keys_out, buf.vals_out, buf.col_offsets, group_codes, + 
group_sizes, buf.sub_rank_sums, buf.sub_tie_corr, buf.d_nz_scratch, + n_rows, sb_cols, n_groups, compute_tie_corr, rank_use_gmem); + CUDA_CHECK_LAST_ERROR(rank_sums_sparse_ovr_kernel); + + // Scatter results to global output + cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpyAsync(tie_corr + col, buf.sub_tie_corr, + sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, + stream); + } + + col += sb_cols; + } + + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in sparse CSR ovr streaming: ") + + cudaGetErrorString(err)); + } + + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py b/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py index 6a9aa016..866121e1 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py @@ -1,6 +1,7 @@ from __future__ import annotations import sys +import warnings from functools import partial from typing import TYPE_CHECKING, Literal @@ -8,6 +9,7 @@ import pandas as pd from ._core import _RankGenes +from ._utils import NoTestGroupsError if TYPE_CHECKING: from collections.abc import Iterable @@ -42,6 +44,7 @@ def rank_genes_groups( pre_load: bool = False, n_bins: int | None = None, bin_range: Literal["log1p", "auto"] | None = None, + skip_empty_groups: bool = False, **kwds, ) -> None: """ @@ -121,6 +124,14 @@ def rank_genes_groups( ``'log1p'`` uses a fixed [0, 15] range suitable for most log1p-normalized data. ``'auto'`` computes the actual data range. Use this for z-scored or unnormalized data. 
+ skip_empty_groups + If ``True``, silently drop groups with fewer than 2 cells (issuing + a ``RuntimeWarning``) instead of raising a ``ValueError``. Useful + when iterating over data subsets (e.g. per cell type) where some + categorical levels may be empty. If no test groups remain after + filtering (only the reference has >=2 cells), the call returns + without updating ``adata.uns``. The reference group is never + dropped — if it has <2 cells the call still fails. **kwds Additional arguments passed to the method. For `'logreg'`, these are passed to :class:`cuml.linear_model.LogisticRegression`. @@ -187,17 +198,29 @@ def rank_genes_groups( msg = f"mask_var has wrong shape: {mask_var_array.shape[0]} != {adata.n_vars}" raise ValueError(msg) - test_obj = _RankGenes( - adata, - groups, - groupby, - mask_var=mask_var_array, - reference=reference, - use_raw=use_raw, - layer=layer, - comp_pts=pts, - pre_load=pre_load, - ) + try: + test_obj = _RankGenes( + adata, + groups, + groupby, + mask_var=mask_var_array, + reference=reference, + use_raw=use_raw, + layer=layer, + comp_pts=pts, + pre_load=pre_load, + skip_empty_groups=skip_empty_groups, + ) + except NoTestGroupsError as e: + # skip_empty_groups=True contract: no test groups left → no-op. + # Do not write to adata.uns so downstream loops can detect the + # missing key and skip this subset. + warnings.warn( + f"rank_genes_groups: skipping — {e}", + RuntimeWarning, + stacklevel=2, + ) + return # Determine n_genes_user n_genes_user = n_genes @@ -217,11 +240,21 @@ def rank_genes_groups( **kwds, ) - # Build output - test_obj.stats.columns = test_obj.stats.columns.swaplevel() - + # Use a U-width tight to the actual gene names rather than the scanpy + # default of U50. For a 1948-group × 18k-gene workload this cuts the + # names structured array from ~7 GB → ~3 GB. Must match the width that + # compute_statistics used when converting var_names (see _core.py), so + # the final stack → structured-array view is a pure memcpy. 
+ _vn = np.asarray(test_obj.var_names) + if _vn.dtype.kind == "U": + max_name_len = _vn.dtype.itemsize // 4 + elif len(_vn): + max_name_len = max(len(str(n)) for n in _vn) + else: + max_name_len = 50 + names_dtype = f"U{max(max_name_len, 1)}" dtypes = { - "names": "U50", + "names": names_dtype, "scores": "float32", "logfoldchanges": "float32", "pvals": "float64", @@ -252,13 +285,47 @@ def rank_genes_groups( if method == "wilcoxon": adata.uns[key_added]["params"]["tie_correct"] = tie_correct - for col in test_obj.stats.columns.levels[0]: - if col in dtypes: - adata.uns[key_added][col] = test_obj.stats[col].to_records( - index=False, column_dtypes=dtypes[col] + # Assemble scanpy-compatible structured arrays directly from per-group + # arrays, without going through a wide pandas DataFrame + to_records — + # that pipeline was ~4 s of pure Python overhead on workloads with + # thousands of groups. + group_names = test_obj.group_names + if group_names: + for stat, per_group_arrays in test_obj.results.items(): + if per_group_arrays is None or stat not in dtypes: + continue + adata.uns[key_added][stat] = _build_structured( + group_names, per_group_arrays, dtypes[stat] ) +def _build_structured( + group_names: list[str], + per_group_arrays: list[np.ndarray], + field_dtype: str, +) -> np.ndarray: + """Build a scanpy-style structured array with one field per group. + + Equivalent to assigning ``sa[name] = arr`` in a loop, but ~20× faster + on wide workloads (1000+ groups): fills a contiguous 2-D (n, n_groups) + buffer row-by-row and reinterprets it as a structured view, so the + bulk write pattern matches the structured-array memory layout. + """ + n = per_group_arrays[0].shape[0] + # Stack directly into the target dtype — np.stack with `dtype=` avoids + # the separate astype pass. axis=1 produces (n, n_groups) C-contig + # exactly matching the structured memory layout below, so .view() is + # zero-copy. Use 'unsafe' casting for string targets (object→U50). 
+ if np.dtype(field_dtype).kind == "U": + stacked = np.stack( + per_group_arrays, axis=1, casting="unsafe", dtype=field_dtype + ) + else: + stacked = np.stack(per_group_arrays, axis=1, dtype=field_dtype) + dtype = np.dtype([(name, field_dtype) for name in group_names]) + return stacked.view(dtype).reshape(n) + + if TYPE_CHECKING: from warnings import deprecated else: diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_core.py b/src/rapids_singlecell/tools/_rank_genes_groups/_core.py index d89a079a..03a24e24 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_core.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_core.py @@ -1,18 +1,28 @@ from __future__ import annotations +from concurrent.futures import ThreadPoolExecutor +from os import cpu_count from typing import TYPE_CHECKING, Literal, assert_never import cupy as cp import numpy as np import pandas as pd -from statsmodels.stats.multitest import multipletests from rapids_singlecell._compat import DaskArray from rapids_singlecell.get import X_to_GPU from rapids_singlecell.get._aggregated import Aggregate from rapids_singlecell.preprocessing._utils import _check_gpu_X -from ._utils import EPS, _select_groups, _select_top_n +from ._utils import ( + EPS, + _benjamini_hochberg, + _select_groups, + _select_top_n, +) + +POSTPROCESS_PARALLEL_GROUPS = 256 +POSTPROCESS_PARALLEL_GENES = 1024 +POSTPROCESS_MAX_WORKERS = 8 if TYPE_CHECKING: from collections.abc import Iterable @@ -38,6 +48,7 @@ def __init__( layer: str | None = None, comp_pts: bool = False, pre_load: bool = False, + skip_empty_groups: bool = False, ) -> None: # Handle groups parameter if groups == "all" or groups is None: @@ -63,7 +74,10 @@ def __init__( raise ValueError(msg) self.groups_order, self.group_codes, self.group_sizes = _select_groups( - self.labels, selected + self.labels, + selected, + skip_empty_groups=skip_empty_groups, + reference=reference, ) # Get data matrix @@ -114,9 +128,19 @@ def __init__( self.vars_rest: 
np.ndarray | None = None self.pts_rest: np.ndarray | None = None - self.stats: pd.DataFrame | None = None + # Per-stat × per-group arrays. results[stat] is either None (stat not + # computed) or a list of 1-D numpy arrays in group_names order. This + # replaces the old DataFrame-based pipeline, which burned ~4 s per call + # on wide workloads (1000+ groups) in pandas DataFrame + to_records. + self.results: dict[str, list[np.ndarray] | None] = { + "names": None, + "scores": None, + "pvals": None, + "pvals_adj": None, + "logfoldchanges": None, + } + self.group_names: list[str] = [] self._compute_stats_in_chunks: bool = False - self._ref_chunk_computed: set[int] = set() def _init_stats_arrays(self, n_genes: int) -> None: """Pre-allocate stats arrays before chunk loop.""" @@ -213,104 +237,6 @@ def _basic_stats(self) -> None: self.vars = cp.asnumpy(vars_) self.pts = cp.asnumpy(pts) if pts is not None else None - def _accumulate_chunk_stats_vs_rest( - self, - block: cp.ndarray, - start: int, - stop: int, - *, - group_matrix: cp.ndarray, - group_sizes_dev: cp.ndarray, - n_cells: int, - ) -> None: - """Compute and store stats for one gene chunk (vs rest mode).""" - if not self._compute_stats_in_chunks: - return # Stats already computed via Aggregate - - rest_sizes = n_cells - group_sizes_dev - - # Group sums and sum of squares - group_sums = group_matrix.T @ block - group_sum_sq = group_matrix.T @ (block**2) - - # Means - chunk_means = group_sums / group_sizes_dev[:, None] - self.means[:, start:stop] = cp.asnumpy(chunk_means) - - # Variances (with Bessel correction) - chunk_vars = group_sum_sq / group_sizes_dev[:, None] - chunk_means**2 - chunk_vars *= group_sizes_dev[:, None] / (group_sizes_dev[:, None] - 1) - self.vars[:, start:stop] = cp.asnumpy(chunk_vars) - - # Pts (fraction expressing) - if self.comp_pts: - group_nnz = group_matrix.T @ (block != 0).astype(cp.float64) - self.pts[:, start:stop] = cp.asnumpy(group_nnz / group_sizes_dev[:, None]) - - # Rest statistics - 
if self.ireference is None: - total_sum = block.sum(axis=0) - total_sum_sq = (block**2).sum(axis=0) - - rest_sums = total_sum[None, :] - group_sums - rest_means = rest_sums / rest_sizes[:, None] - self.means_rest[:, start:stop] = cp.asnumpy(rest_means) - - rest_sum_sq = total_sum_sq[None, :] - group_sum_sq - rest_vars = rest_sum_sq / rest_sizes[:, None] - rest_means**2 - rest_vars *= rest_sizes[:, None] / (rest_sizes[:, None] - 1) - self.vars_rest[:, start:stop] = cp.asnumpy(rest_vars) - - if self.comp_pts: - total_nnz = (block != 0).sum(axis=0) - rest_nnz = total_nnz[None, :] - group_nnz - self.pts_rest[:, start:stop] = cp.asnumpy( - rest_nnz / rest_sizes[:, None] - ) - - def _accumulate_chunk_stats_with_ref( - self, - block: cp.ndarray, - start: int, - stop: int, - *, - group_index: int, - group_mask_gpu: cp.ndarray, - n_group: int, - n_ref: int, - ) -> None: - """Compute and store stats for one gene chunk (with reference mode).""" - if not self._compute_stats_in_chunks: - return # Stats already computed via Aggregate - - # Group stats - group_data = block[group_mask_gpu] - group_mean = group_data.mean(axis=0) - self.means[group_index, start:stop] = cp.asnumpy(group_mean) - - if n_group > 1: - group_var = group_data.var(axis=0, ddof=1) - self.vars[group_index, start:stop] = cp.asnumpy(group_var) - - if self.comp_pts: - group_nnz = (group_data != 0).sum(axis=0) - self.pts[group_index, start:stop] = cp.asnumpy(group_nnz / n_group) - - # Reference stats (only compute once, on first non-reference group) - if start not in self._ref_chunk_computed: - self._ref_chunk_computed.add(start) - ref_data = block[~group_mask_gpu] - ref_mean = ref_data.mean(axis=0) - self.means[self.ireference, start:stop] = cp.asnumpy(ref_mean) - - if n_ref > 1: - ref_var = ref_data.var(axis=0, ddof=1) - self.vars[self.ireference, start:stop] = cp.asnumpy(ref_var) - - if self.comp_pts: - ref_nnz = (ref_data != 0).sum(axis=0) - self.pts[self.ireference, start:stop] = cp.asnumpy(ref_nnz / n_ref) 
- def t_test( self, method: Literal["t-test", "t-test_overestim_var"] ) -> list[tuple[int, NDArray, NDArray]]: @@ -408,56 +334,136 @@ def compute_statistics( n_genes = self.X.shape[1] - # Collect all stats data first to avoid DataFrame fragmentation - stats_data: dict[tuple[str, str], np.ndarray] = {} - - for group_index, scores, pvals in test_results: - group_name = str(self.groups_order[group_index]) + if not test_results: + self.group_names = [] + return + group_indices = np.array([gi for gi, _, _ in test_results], dtype=np.int64) + self.group_names = [str(self.groups_order[gi]) for gi in group_indices] + has_lfc = self.means is not None + + # Vectorised log-fold-change across all test groups — avoids 1948 + # individual numpy ops in the hot path (was ~1.5 s of Python + # overhead on wide workloads). Cast to the output dtype here so + # _build_structured never needs a full-array astype pass. + lfc_all: np.ndarray | None = None + if has_lfc: + mean_groups = self.means[group_indices] + if self.ireference is None: + mean_rests = self.means_rest[group_indices] + else: + mean_rests = self.means[self.ireference][None, :] + foldchanges = (self.expm1_func(mean_groups) + EPS) / ( + self.expm1_func(mean_rests) + EPS + ) + lfc_all = np.log2(foldchanges).astype(np.float32, copy=False) + + names_list: list[np.ndarray] = [] + scores_list: list[np.ndarray] = [] + pvals_list: list[np.ndarray] = [] + pvals_adj_list: list[np.ndarray] = [] + lfc_list: list[np.ndarray] = [] + has_pvals = False + # Pre-convert var_names to fixed-width unicode ONCE. Without this, + # per-group indexing returns object arrays, and np.stack(..., dtype='U') + # ends up doing ~35 M object→string conversions inside the hot loop — + # that was ~1.5 s on the 1948-group / 18k-gene workload. Using the + # target width directly turns the final stack into a pure memcpy. 
+ _vn = np.asarray(self.var_names) + if _vn.dtype.kind == "U": + var_names_arr = _vn + else: + max_len = max((len(str(n)) for n in _vn), default=1) + var_names_arr = _vn.astype(f"U{max_len}") + + def _process_result( + ti: int, + ) -> tuple[ + int, + np.ndarray | None, + np.ndarray, + np.ndarray | None, + np.ndarray | None, + np.ndarray | None, + ]: + _, scores, pvals = test_results[ti] if n_genes_user is not None: scores_sort = np.abs(scores) if rankby_abs else scores global_indices = _select_top_n(scores_sort, n_genes_user) + names = var_names_arr[global_indices] else: global_indices = slice(None) + names = None - if n_genes_user is not None: - stats_data[group_name, "names"] = np.asarray(self.var_names)[ - global_indices - ] - - stats_data[group_name, "scores"] = scores[global_indices] + scores_out = scores[global_indices] + pvals_out = None + pvals_adj_out = None if pvals is not None: - stats_data[group_name, "pvals"] = pvals[global_indices] + pvals_out = pvals[global_indices] if corr_method == "benjamini-hochberg": - pvals_clean = np.array(pvals, copy=True) - pvals_clean[np.isnan(pvals_clean)] = 1.0 - _, pvals_adj, _, _ = multipletests( - pvals_clean, alpha=0.05, method="fdr_bh" - ) + pvals_adj = _benjamini_hochberg(pvals) elif corr_method == "bonferroni": pvals_adj = np.minimum(pvals * n_genes, 1.0) - stats_data[group_name, "pvals_adj"] = pvals_adj[global_indices] - - # Compute logfoldchanges - if self.means is not None: - mean_group = self.means[group_index] - if self.ireference is None: - mean_rest = self.means_rest[group_index] - else: - mean_rest = self.means[self.ireference] - foldchanges = (self.expm1_func(mean_group) + EPS) / ( - self.expm1_func(mean_rest) + EPS - ) - stats_data[group_name, "logfoldchanges"] = np.log2( - foldchanges[global_indices] - ) - - # Create DataFrame all at once to avoid fragmentation - if stats_data: - self.stats = pd.DataFrame(stats_data) - self.stats.columns = pd.MultiIndex.from_tuples(self.stats.columns) - if n_genes_user 
is None: - self.stats.index = self.var_names + pvals_adj_out = pvals_adj[global_indices] + + lfc_out = None + if lfc_all is not None: + lfc_out = lfc_all[ti][global_indices] + + return ti, names, scores_out, pvals_out, pvals_adj_out, lfc_out + + def _process_range( + start: int, stop: int + ) -> list[ + tuple[ + int, + np.ndarray | None, + np.ndarray, + np.ndarray | None, + np.ndarray | None, + np.ndarray | None, + ] + ]: + return [_process_result(ti) for ti in range(start, stop)] + + n_results = len(test_results) + use_parallel_post = ( + n_results >= POSTPROCESS_PARALLEL_GROUPS + and n_genes >= POSTPROCESS_PARALLEL_GENES + ) + if use_parallel_post: + workers = min(POSTPROCESS_MAX_WORKERS, cpu_count() or 1, n_results) + chunk = (n_results + workers - 1) // workers + ranges = [ + (start, min(start + chunk, n_results)) + for start in range(0, n_results, chunk) + ] + with ThreadPoolExecutor(max_workers=workers) as executor: + processed_chunks = executor.map(lambda r: _process_range(*r), ranges) + processed = [ + item for chunk_out in processed_chunks for item in chunk_out + ] else: - self.stats = None + processed = _process_range(0, n_results) + + for _, names, scores_out, pvals_out, pvals_adj_out, lfc_out in processed: + if names is not None: + names_list.append(names) + scores_list.append(scores_out) + if pvals_out is not None and pvals_adj_out is not None: + has_pvals = True + pvals_list.append(pvals_out) + pvals_adj_list.append(pvals_adj_out) + if lfc_out is not None: + lfc_list.append(lfc_out) + + if self.group_names: + self.results["scores"] = scores_list + if n_genes_user is not None: + self.results["names"] = names_list + if has_pvals: + self.results["pvals"] = pvals_list + self.results["pvals_adj"] = pvals_adj_list + if lfc_all is not None: + self.results["logfoldchanges"] = lfc_list diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py b/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py index 890ff6e7..c46ec26d 100644 --- 
a/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py @@ -1,26 +1,30 @@ from __future__ import annotations +import warnings from typing import TYPE_CHECKING -import cupy as cp -import cupyx.scipy.sparse as cpsp import numpy as np -import scipy.sparse as sp - -from rapids_singlecell.preprocessing._utils import _sparse_to_dense if TYPE_CHECKING: import pandas as pd from numpy.typing import NDArray EPS = 1e-9 -WARP_SIZE = 32 -MAX_THREADS_PER_BLOCK = 512 + + +class NoTestGroupsError(ValueError): + """Raised when skip_empty_groups=True and no test groups remain after + filtering. The public ``rank_genes_groups`` catches this and returns + quietly (after emitting a ``RuntimeWarning``) so callers can iterate + over data subsets without wrapping each call in try/except.""" def _select_groups( labels: pd.Series, selected: list | None, + *, + skip_empty_groups: bool = False, + reference: str | None = None, ) -> tuple[NDArray, NDArray[np.int32], NDArray[np.int64]]: """Build integer group codes from a categorical Series. @@ -31,6 +35,17 @@ def _select_groups( selected Group names to keep, or ``None`` for all groups. Must already include the reference group if applicable. + skip_empty_groups + If ``True``, drop groups with fewer than 2 cells instead of raising, + emitting a ``RuntimeWarning`` that lists the dropped groups. Useful + when iterating over data subsets (e.g. per cell type) where some + categorical levels have no cells. The reference group is never + silently dropped — if it has <2 cells, a ``ValueError`` is raised. + reference + Name of the reference group (``None`` or ``"rest"`` if there is + no fixed reference). Only used when ``skip_empty_groups`` is + ``True`` to validate that the reference is not the one being + dropped. 
Returns ------- @@ -51,27 +66,65 @@ def _select_groups( cat_order = {str(c): i for i, c in enumerate(all_categories)} selected.sort(key=lambda x: cat_order.get(str(x), len(all_categories))) + # First pass: compute sizes for the currently-selected groups so we can + # optionally drop empty/singleton ones before assigning final codes. + orig_codes_all = labels.cat.codes.to_numpy() + + def _compute_codes_and_sizes( + sel: list, + ) -> tuple[NDArray[np.int32], NDArray[np.int64]]: + n = len(sel) + str_to_sel = {str(name): idx for idx, name in enumerate(sel)} + lookup = np.full(len(all_categories) + 1, n, dtype=np.int32) + for cat_idx, cat_name in enumerate(all_categories): + sel_idx = str_to_sel.get(str(cat_name)) + if sel_idx is not None: + lookup[cat_idx] = sel_idx + codes = lookup[orig_codes_all] + sizes = np.bincount(codes, minlength=n + 1)[:n].astype(np.int64) + return codes, sizes + + _, preview_sizes = _compute_codes_and_sizes(selected) + + if skip_empty_groups: + ref_str = str(reference) if reference not in (None, "rest") else None + empty_idx = [i for i, s in enumerate(preview_sizes) if s < 2] + if empty_idx: + empty_names = [str(selected[i]) for i in empty_idx] + if ref_str is not None and ref_str in empty_names: + msg = ( + f"Reference group {ref_str!r} has <2 cells; cannot run " + "with skip_empty_groups=True." + ) + raise ValueError(msg) + warnings.warn( + f"Dropping {len(empty_names)} group(s) with <2 cells: " + f"{', '.join(empty_names)}", + RuntimeWarning, + stacklevel=3, + ) + selected = [g for i, g in enumerate(selected) if i not in set(empty_idx)] + + # Need at least one test group once the reference is excluded. + ref_str = str(reference) if reference not in (None, "rest") else None + n_test = sum(1 for g in selected if str(g) != ref_str) + if n_test == 0: + msg = ( + "No test groups with >=2 cells remain after filtering " + "(only the reference has enough cells)." 
+ ) + raise NoTestGroupsError(msg) + n_groups = len(selected) groups_order = np.array(selected) - # Map original category index → selected group index - str_to_sel = {str(name): idx for idx, name in enumerate(selected)} - orig_to_sel: dict[int, int] = {} - for cat_idx, cat_name in enumerate(all_categories): - sel_idx = str_to_sel.get(str(cat_name)) - if sel_idx is not None: - orig_to_sel[cat_idx] = sel_idx - - orig_codes = labels.cat.codes.to_numpy() - group_codes = np.full(len(orig_codes), n_groups, dtype=np.int32) - for orig_idx, sel_idx in orig_to_sel.items(): - group_codes[orig_codes == orig_idx] = sel_idx + if n_groups == 0: + msg = "No groups with >=2 cells remain after filtering." + raise ValueError(msg) - group_sizes = np.bincount(group_codes, minlength=n_groups + 1)[:n_groups].astype( - np.int64 - ) + group_codes, group_sizes = _compute_codes_and_sizes(selected) - # Validate singlet groups + # Validate singlet groups (only triggers when skip_empty_groups=False). invalid_groups = {str(selected[i]) for i in range(n_groups) if group_sizes[i] < 2} if invalid_groups: msg = ( @@ -83,56 +136,31 @@ def _select_groups( return groups_order, group_codes, group_sizes -def _round_up_to_warp(n: int) -> int: - """Round up to nearest multiple of WARP_SIZE, capped at MAX_THREADS_PER_BLOCK.""" - return min(MAX_THREADS_PER_BLOCK, ((n + WARP_SIZE - 1) // WARP_SIZE) * WARP_SIZE) - - def _select_top_n(scores: NDArray, n_top: int) -> NDArray: """Select indices of top n scores. Uses argpartition + argsort for O(n + k log k) complexity where k = n_top. This is faster than full sorting when k << n. 
""" - n_from = scores.shape[0] - reference_indices = np.arange(n_from, dtype=int) + if n_top >= scores.shape[0]: + return np.argsort(scores)[::-1] partition = np.argpartition(scores, -n_top)[-n_top:] - partial_indices = np.argsort(scores[partition])[::-1] - global_indices = reference_indices[partition][partial_indices] - return global_indices + return partition[np.argsort(scores[partition])[::-1]] -def _csc_columns_to_gpu(X_csc, start: int, stop: int, n_rows: int) -> cp.ndarray: - """ - Extract columns from a CSC matrix via direct indptr pointer slicing. +def _benjamini_hochberg(pvals: NDArray) -> NDArray: + """Adjust p-values with the Benjamini-Hochberg FDR procedure.""" + pvals_clean = np.array(pvals, copy=True) + pvals_clean[np.isnan(pvals_clean)] = 1.0 - Works for both scipy and CuPy CSC matrices. Much faster than - ``X[:, start:stop]`` which rebuilds index arrays internally. - """ - s_ptr = int(X_csc.indptr[start]) - e_ptr = int(X_csc.indptr[stop]) - chunk_data = cp.asarray(X_csc.data[s_ptr:e_ptr]) - chunk_indices = cp.asarray(X_csc.indices[s_ptr:e_ptr]) - chunk_indptr = cp.asarray(X_csc.indptr[start : stop + 1] - s_ptr) - csc_chunk = cpsp.csc_matrix( - (chunk_data, chunk_indices, chunk_indptr), shape=(n_rows, stop - start) - ) - return _sparse_to_dense(csc_chunk, order="F") - - -def _get_column_block(X, start: int, stop: int) -> cp.ndarray: - """Extract a column block as a dense F-order CuPy array (native dtype).""" - match X: - case sp.csc_matrix() | sp.csc_array(): - return _csc_columns_to_gpu(X, start, stop, X.shape[0]) - case sp.spmatrix() | sp.sparray(): - chunk = cpsp.csc_matrix(X[:, start:stop].tocsc()) - return _sparse_to_dense(chunk, order="F") - case cpsp.csc_matrix(): - return _csc_columns_to_gpu(X, start, stop, X.shape[0]) - case cpsp.spmatrix(): - return _sparse_to_dense(X[:, start:stop], order="F") - case np.ndarray() | cp.ndarray(): - return cp.asfortranarray(cp.asarray(X[:, start:stop])) - case _: - raise ValueError(f"Unsupported matrix type: 
{type(X)}") + n_tests = pvals_clean.size + order = np.argsort(pvals_clean) + ordered = pvals_clean[order] + ranks = np.arange(1, n_tests + 1, dtype=ordered.dtype) / n_tests + adjusted = ordered / ranks + np.minimum.accumulate(adjusted[::-1], out=adjusted[::-1]) + np.minimum(adjusted, 1.0, out=adjusted) + + out = np.empty_like(adjusted) + out[order] = adjusted + return out diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py index 1773c5cd..45833d78 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py @@ -33,6 +33,7 @@ def _fill_basic_stats_from_accumulators( group_sizes: np.ndarray, *, n_cells: int, + compute_vars: bool = True, ) -> None: """Populate rg.means/vars/pts (+ *_rest) from streamed accumulators. @@ -44,22 +45,30 @@ def _fill_basic_stats_from_accumulators( """ n = cp.asarray(group_sizes, dtype=cp.float64)[:, None] means = group_sums / n - group_ss = group_sq_sums - n * means**2 - vars_ = cp.maximum(group_ss / cp.maximum(n - 1, 1), 0) rg.means = cp.asnumpy(means) - rg.vars = cp.asnumpy(vars_) + if compute_vars: + group_ss = group_sq_sums - n * means**2 + vars_ = cp.maximum(group_ss / cp.maximum(n - 1, 1), 0) + rg.vars = cp.asnumpy(vars_) + else: + rg.vars = np.zeros_like(rg.means) rg.pts = cp.asnumpy(group_nnz / n) if rg.comp_pts else None if rg.ireference is None: n_rest = cp.float64(n_cells) - n total_sum = group_sums.sum(axis=0, keepdims=True) - total_sq_sum = group_sq_sums.sum(axis=0, keepdims=True) rest_sums = total_sum - group_sums rest_means = rest_sums / n_rest - rest_ss = (total_sq_sum - group_sq_sums) - n_rest * rest_means**2 rg.means_rest = cp.asnumpy(rest_means) - rg.vars_rest = cp.asnumpy(cp.maximum(rest_ss / cp.maximum(n_rest - 1, 1), 0)) + if compute_vars: + total_sq_sum = group_sq_sums.sum(axis=0, keepdims=True) + rest_ss = (total_sq_sum - group_sq_sums) - n_rest * 
rest_means**2 + rg.vars_rest = cp.asnumpy( + cp.maximum(rest_ss / cp.maximum(n_rest - 1, 1), 0) + ) + else: + rg.vars_rest = np.zeros_like(rg.means_rest) if rg.comp_pts: total_nnz = group_nnz.sum(axis=0, keepdims=True) rg.pts_rest = cp.asnumpy((total_nnz - group_nnz) / n_rest) @@ -82,6 +91,7 @@ def _fill_ovo_stats_from_accumulators( group_sizes: NDArray, test_group_indices: list[int], n_ref: int, + compute_vars: bool = True, ) -> None: """Populate rg.means/vars/pts from OVO stats slots. @@ -102,10 +112,10 @@ def _fill(slot: int, size: int, gi: int) -> None: if size <= 0: return sums = group_sums_slots[slot] - sq = group_sq_sums_slots[slot] mean = sums / size rg.means[gi] = cp.asnumpy(mean) - if size > 1: + if compute_vars and size > 1: + sq = group_sq_sums_slots[slot] ss = sq - size * mean**2 rg.vars[gi] = cp.asnumpy(cp.maximum(ss / max(size - 1, 1), 0)) if rg.comp_pts: @@ -366,11 +376,19 @@ def _wilcoxon_vs_rest( # counts into caller-provided CuPy buffers, so means/vars/pts can # be derived without uploading the full matrix. Outputs live on # the GPU and feed directly into the z-score / p-value math below. 
+ compute_vars = False + compute_nnz = rg.comp_pts rank_sums = cp.empty((n_groups, n_total_genes), dtype=cp.float64) tie_corr = cp.ones(n_total_genes, dtype=cp.float64) group_sums = cp.empty((n_groups, n_total_genes), dtype=cp.float64) - group_sq_sums = cp.empty((n_groups, n_total_genes), dtype=cp.float64) - group_nnz = cp.empty((n_groups, n_total_genes), dtype=cp.float64) + group_sq_sums = cp.empty( + (n_groups, n_total_genes) if compute_vars else (1, 1), + dtype=cp.float64, + ) + group_nnz = cp.empty( + (n_groups, n_total_genes) if compute_nnz else (1, 1), + dtype=cp.float64, + ) if host_csc: group_sizes_np = group_sizes.astype(np.float64, copy=False) @@ -403,6 +421,8 @@ def _wilcoxon_vs_rest( n_cols=n_total_genes, n_groups=n_groups, compute_tie_corr=tie_correct, + compute_sq_sums=compute_vars, + compute_nnz=compute_nnz, sub_batch_cols=STREAMING_SUB_BATCH, ) else: @@ -425,6 +445,8 @@ def _wilcoxon_vs_rest( n_cols=n_total_genes, n_groups=n_groups, compute_tie_corr=tie_correct, + compute_sq_sums=compute_vars, + compute_nnz=compute_nnz, sub_batch_cols=STREAMING_SUB_BATCH, ) @@ -436,6 +458,7 @@ def _wilcoxon_vs_rest( group_nnz, group_sizes.astype(np.float64, copy=False), n_cells=n_cells, + compute_vars=compute_vars, ) else: # GPU data or host CSR → transfer to GPU, use GPU kernels @@ -453,10 +476,14 @@ def _wilcoxon_vs_rest( if cpsp.isspmatrix_csc(X_gpu): # Sparse-aware path: sort only stored nonzeros, # handle zeros analytically. 
+ csc_data = X_gpu.data.astype(cp.float32, copy=False) + csc_indices = X_gpu.indices.astype(cp.int32, copy=False) + csc_indptr = X_gpu.indptr.astype(cp.int32, copy=False) + cp.cuda.get_current_stream().synchronize() _ovr.ovr_sparse_csc( - X_gpu.data.astype(cp.float32, copy=False), - X_gpu.indices.astype(cp.int32, copy=False), - X_gpu.indptr.astype(cp.int32, copy=False), + csc_data, + csc_indices, + csc_indptr, group_codes_gpu, group_sizes_dev, rank_sums, @@ -468,10 +495,18 @@ def _wilcoxon_vs_rest( sub_batch_cols=STREAMING_SUB_BATCH, ) elif cpsp.isspmatrix_csr(X_gpu): + csr_gpu = X_gpu + if not csr_gpu.has_sorted_indices: + csr_gpu = csr_gpu.copy() + csr_gpu.sort_indices() + csr_data = csr_gpu.data.astype(cp.float32, copy=False) + csr_indices = csr_gpu.indices.astype(cp.int32, copy=False) + csr_indptr = csr_gpu.indptr.astype(cp.int32, copy=False) + cp.cuda.get_current_stream().synchronize() _ovr.ovr_sparse_csr( - X_gpu.data.astype(cp.float32, copy=False), - X_gpu.indices.astype(cp.int32, copy=False), - X_gpu.indptr.astype(cp.int32, copy=False), + csr_data, + csr_indices, + csr_indptr, group_codes_gpu, group_sizes_dev, rank_sums, @@ -484,6 +519,7 @@ def _wilcoxon_vs_rest( ) else: dense_f32 = cp.asfortranarray(X_gpu.astype(cp.float32, copy=False)) + cp.cuda.get_current_stream().synchronize() _ovr.ovr_streaming( dense_f32, group_codes_gpu, @@ -508,7 +544,7 @@ def _wilcoxon_vs_rest( cp.nan_to_num(z, copy=False) p_values = cupyx_special.erfc(cp.abs(z) * cp.float64(cp.sqrt(0.5))) - all_z = z.get() + all_z = z.astype(cp.float32).get() all_p = p_values.get() return [(gi, all_z[gi], all_p[gi]) for gi in range(n_groups)] @@ -540,53 +576,85 @@ def _wilcoxon_with_reference( n_ref = int(group_sizes[ireference]) codes = rg.group_codes - # ---- build row-index arrays ---- - test_group_indices: list[int] = [] - all_grp_rows: list[np.ndarray] = [] - offsets = [0] - for gi in range(n_groups): - if gi == ireference: - continue - rows = np.where(codes == gi)[0] - 
test_group_indices.append(gi) - all_grp_rows.append(rows) - offsets.append(offsets[-1] + len(rows)) - + # ---- build row-index arrays via CSR-style cat offsets (O(n)) ---- + # scipy's coo→csr conversion is a vectorised counting sort — ~15× faster + # than np.argsort(stable) on this shape (1.8 ms vs 27 ms for 534 k cells). + # indptr[g] .. indptr[g+1] bracket the rows of group g in indices. + n_rows_incl_sentinel = n_groups + 1 # last slot holds "unselected" rows + _csr = sp.coo_matrix( + ( + np.ones(n_cells, dtype=np.int8), + (codes, np.arange(n_cells, dtype=np.int32)), + ), + shape=(n_rows_incl_sentinel, n_cells), + ).tocsr() + offsets_full = _csr.indptr # int32, length n_groups + 2 + sorted_rows = _csr.indices # int32, length n_cells + + test_group_indices: list[int] = [gi for gi in range(n_groups) if gi != ireference] if not test_group_indices: return [] + test_group_indices_np = np.asarray(test_group_indices, dtype=np.intp) + + offsets = [0] + all_grp_rows_parts: list[np.ndarray] = [] + for gi in test_group_indices: + g_start = int(offsets_full[gi]) + g_end = int(offsets_full[gi + 1]) + all_grp_rows_parts.append(sorted_rows[g_start:g_end]) + offsets.append(offsets[-1] + (g_end - g_start)) - all_grp_row_ids_np = np.concatenate(all_grp_rows) - grp_offsets_gpu = cp.asarray(offsets, dtype=cp.int32) + all_grp_row_ids_np = np.concatenate(all_grp_rows_parts) n_test = len(test_group_indices) n_all_grp = len(all_grp_row_ids_np) - ref_row_ids_np = np.where(codes == ireference)[0] - - # ---- warn for small groups ---- - for gi in test_group_indices: - n_group = int(group_sizes[gi]) - if n_group <= MIN_GROUP_SIZE_WARNING or n_ref <= MIN_GROUP_SIZE_WARNING: - warnings.warn( - f"Group {rg.groups_order[gi]} has size {n_group} " - f"(reference {n_ref}); normal approximation " - "of the Wilcoxon statistic may be inaccurate.", - RuntimeWarning, - stacklevel=4, + ref_row_ids_np = sorted_rows[ + int(offsets_full[ireference]) : int(offsets_full[ireference + 1]) + ] + + # ---- warn 
for small groups (single aggregated warning for the batch + # rather than one per group — emitting a warning per group is O(n_test) + # Python overhead that dominates on workloads with thousands of groups). + small_test = [ + str(rg.groups_order[gi]) + for gi in test_group_indices + if int(group_sizes[gi]) <= MIN_GROUP_SIZE_WARNING + ] + ref_small = n_ref <= MIN_GROUP_SIZE_WARNING + if small_test or ref_small: + parts = [] + if small_test: + parts.append( + f"{len(small_test)} test group(s) have size " + f"<= {MIN_GROUP_SIZE_WARNING} (first few: " + f"{', '.join(small_test[:5])}{'...' if len(small_test) > 5 else ''})" ) + if ref_small: + parts.append(f"reference has size {n_ref}") + warnings.warn( + f"Small groups detected: {'; '.join(parts)}. normal " + "approximation of the Wilcoxon statistic may be inaccurate.", + RuntimeWarning, + stacklevel=4, + ) test_sizes = cp.asarray( - [group_sizes[gi] for gi in test_group_indices], dtype=cp.float64 + group_sizes[test_group_indices_np].astype(np.float64, copy=False) ) - # ---- build row maps (numpy, for both host and GPU CSC paths) ---- - ref_row_map_np = np.full(n_cells, -1, dtype=np.int32) - ref_row_map_np[ref_row_ids_np] = np.arange(n_ref, dtype=np.int32) - grp_row_map_np = np.full(n_cells, -1, dtype=np.int32) - grp_row_map_np[all_grp_row_ids_np] = np.arange(n_all_grp, dtype=np.int32) offsets_np = np.asarray(offsets, dtype=np.int32) # ---- host-streaming paths: skip bulk transfer ---- host_sparse = isinstance(X, sp.spmatrix | sp.sparray) host_dense = isinstance(X, np.ndarray) + + # ---- build row maps only for paths that need original-row lookup ---- + ref_row_map_np = grp_row_map_np = None + if (host_sparse and X.format == "csc") or (not host_sparse and not host_dense): + ref_row_map_np = np.full(n_cells, -1, dtype=np.int32) + ref_row_map_np[ref_row_ids_np] = np.arange(n_ref, dtype=np.int32) + grp_row_map_np = np.full(n_cells, -1, dtype=np.int32) + grp_row_map_np[all_grp_row_ids_np] = np.arange(n_all_grp, 
dtype=np.int32) + if host_sparse or host_dense: # Output buffers live on the GPU (caller-provided CuPy memory); # kernels write directly into them, and rank_sums / tie_corr @@ -598,30 +666,51 @@ def _wilcoxon_with_reference( # Unselected cells carry the sentinel (n_groups_stats) which the # kernel skips. n_groups_stats = n_test + 1 - stats_codes_np = np.full(n_cells, n_groups_stats, dtype=np.int32) - for i, gi in enumerate(test_group_indices): - stats_codes_np[codes == gi] = i - stats_codes_np[codes == ireference] = n_test + compute_vars = False + compute_nnz = rg.comp_pts + stats_code_lookup = np.full(n_groups + 1, n_groups_stats, dtype=np.int32) + stats_code_lookup[test_group_indices_np] = np.arange(n_test, dtype=np.int32) + stats_code_lookup[ireference] = n_test + stats_codes_np = stats_code_lookup[codes] group_sums = cp.empty((n_groups_stats, n_total_genes), dtype=cp.float64) - group_sq_sums = cp.empty((n_groups_stats, n_total_genes), dtype=cp.float64) - group_nnz = cp.empty((n_groups_stats, n_total_genes), dtype=cp.float64) + group_sq_sums = cp.empty( + (n_groups_stats, n_total_genes) if compute_vars else (1,), + dtype=cp.float64, + ) + group_nnz = cp.empty( + (n_groups_stats, n_total_genes) if compute_nnz else (1,), + dtype=cp.float64, + ) if host_sparse and X.format == "csc": is_f64 = X.data.dtype == np.float64 + is_idx64 = X.indices.dtype == np.int64 is_i64 = X.indptr.dtype == np.int64 - if is_f64 and is_i64: - _csc_host_fn = _wc.ovo_streaming_csc_host_f64_i64 - elif is_f64: - _csc_host_fn = _wc.ovo_streaming_csc_host_f64 + if is_f64: + if is_idx64 and is_i64: + _csc_host_fn = _wc.ovo_streaming_csc_host_f64_idx64_i64 + elif is_idx64: + _csc_host_fn = _wc.ovo_streaming_csc_host_f64_idx64 + elif is_i64: + _csc_host_fn = _wc.ovo_streaming_csc_host_f64_i64 + else: + _csc_host_fn = _wc.ovo_streaming_csc_host_f64 + elif is_idx64 and is_i64: + _csc_host_fn = _wc.ovo_streaming_csc_host_idx64_i64 + elif is_idx64: + _csc_host_fn = _wc.ovo_streaming_csc_host_idx64 
elif is_i64: _csc_host_fn = _wc.ovo_streaming_csc_host_i64 else: _csc_host_fn = _wc.ovo_streaming_csc_host data_arr = X.data if is_f64 else X.data.astype(np.float32, copy=False) + indices_arr = ( + X.indices if is_idx64 else X.indices.astype(np.int32, copy=False) + ) _csc_host_fn( data_arr, - X.indices.astype(np.int32, copy=False), + indices_arr, X.indptr, ref_row_map_np, grp_row_map_np, @@ -639,42 +728,63 @@ def _wilcoxon_with_reference( n_groups=n_test, n_groups_stats=n_groups_stats, compute_tie_corr=tie_correct, + compute_sq_sums=compute_vars, + compute_nnz=compute_nnz, sub_batch_cols=STREAMING_SUB_BATCH, ) elif host_sparse: csr = X.tocsr() if X.format != "csr" else X + if not csr.has_sorted_indices: + csr = csr.copy() + csr.sort_indices() is_f64 = csr.data.dtype == np.float64 + is_idx64 = csr.indices.dtype == np.int64 is_i64 = csr.indptr.dtype == np.int64 - if is_f64 and is_i64: - _csr_host_fn = _wc.ovo_streaming_csr_host_f64_i64 - elif is_f64: - _csr_host_fn = _wc.ovo_streaming_csr_host_f64 + + # Zero-copy mapped: pin full CSR, upload indptr + row_ids, GPU + # kernels gather per-pack rows via UVA reads. 
+ if is_f64: + if is_idx64 and is_i64: + _csr_host_fn = _wc.ovo_streaming_csr_host_f64_idx64_i64 + elif is_idx64: + _csr_host_fn = _wc.ovo_streaming_csr_host_f64_idx64 + elif is_i64: + _csr_host_fn = _wc.ovo_streaming_csr_host_f64_i64 + else: + _csr_host_fn = _wc.ovo_streaming_csr_host_f64 + elif is_idx64 and is_i64: + _csr_host_fn = _wc.ovo_streaming_csr_host_idx64_i64 + elif is_idx64: + _csr_host_fn = _wc.ovo_streaming_csr_host_idx64 elif is_i64: _csr_host_fn = _wc.ovo_streaming_csr_host_i64 else: _csr_host_fn = _wc.ovo_streaming_csr_host data_arr = csr.data if is_f64 else csr.data.astype(np.float32, copy=False) + indices_arr = ( + csr.indices if is_idx64 else csr.indices.astype(np.int32, copy=False) + ) _csr_host_fn( data_arr, - csr.indices.astype(np.int32, copy=False), + indices_arr, csr.indptr, ref_row_ids_np.astype(np.int32, copy=False), all_grp_row_ids_np.astype(np.int32, copy=False), offsets_np, - stats_codes_np, rank_sums, tie_corr_arr, group_sums, group_sq_sums, group_nnz, + n_full_rows=n_cells, n_ref=n_ref, n_all_grp=n_all_grp, - n_rows=n_cells, n_cols=n_total_genes, - n_groups=n_test, + n_test=n_test, n_groups_stats=n_groups_stats, - nnz=csr.nnz, compute_tie_corr=tie_correct, + compute_sq_sums=compute_vars, + compute_nnz=compute_nnz, sub_batch_cols=STREAMING_SUB_BATCH, ) else: @@ -703,6 +813,8 @@ def _wilcoxon_with_reference( n_groups=n_test, n_groups_stats=n_groups_stats, compute_tie_corr=tie_correct, + compute_sq_sums=compute_vars, + compute_nnz=compute_nnz, sub_batch_cols=STREAMING_SUB_BATCH, ) @@ -715,11 +827,13 @@ def _wilcoxon_with_reference( group_sizes=group_sizes, test_group_indices=test_group_indices, n_ref=n_ref, + compute_vars=compute_vars, ) else: # ---- GPU path: transfer once, then dispatch ---- X_gpu = _to_gpu_native(X, n_cells, n_total_genes) + grp_offsets_gpu = cp.asarray(offsets_np, dtype=cp.int32) if rg._compute_stats_in_chunks: rg.X = X_gpu @@ -735,10 +849,14 @@ def _wilcoxon_with_reference( if cpsp.isspmatrix_csc(X_gpu): 
ref_row_map = cp.asarray(ref_row_map_np) grp_row_map = cp.asarray(grp_row_map_np) + csc_data = X_gpu.data.astype(cp.float32, copy=False) + csc_indices = X_gpu.indices.astype(cp.int32, copy=False) + csc_indptr = X_gpu.indptr.astype(cp.int32, copy=False) + cp.cuda.get_current_stream().synchronize() _wc.ovo_streaming_csc( - X_gpu.data.astype(cp.float32, copy=False), - X_gpu.indices.astype(cp.int32, copy=False), - X_gpu.indptr.astype(cp.int32, copy=False), + csc_data, + csc_indices, + csc_indptr, ref_row_map, grp_row_map, grp_offsets_gpu, @@ -754,10 +872,17 @@ def _wilcoxon_with_reference( elif cpsp.issparse(X_gpu): # CSR-native: extract ref/grp rows directly csr_gpu = X_gpu.tocsr() if not cpsp.isspmatrix_csr(X_gpu) else X_gpu + if not csr_gpu.has_sorted_indices: + csr_gpu = csr_gpu.copy() + csr_gpu.sort_indices() + csr_data = csr_gpu.data.astype(cp.float32, copy=False) + csr_indices = csr_gpu.indices.astype(cp.int32, copy=False) + csr_indptr = csr_gpu.indptr.astype(cp.int32, copy=False) + cp.cuda.get_current_stream().synchronize() _wc.ovo_streaming_csr( - csr_gpu.data.astype(cp.float32, copy=False), - csr_gpu.indices.astype(cp.int32, copy=False), - csr_gpu.indptr.astype(cp.int32, copy=False), + csr_data, + csr_indices, + csr_indptr, ref_row_ids_gpu, all_grp_row_ids_gpu, grp_offsets_gpu, @@ -784,6 +909,7 @@ def _wilcoxon_with_reference( 1, ) grp_f32 = cp.asfortranarray(grp_block.astype(cp.float32, copy=False)) + cp.cuda.get_current_stream().synchronize() _wc.ovo_streaming( ref_sorted, grp_f32, @@ -813,7 +939,10 @@ def _wilcoxon_with_reference( cp.nan_to_num(z, copy=False) p_values = cupyx_special.erfc(cp.abs(z) * cp.float64(cp.sqrt(0.5))) - all_z = z.get() + # Downcast scores to float32 on GPU (the scanpy output dtype); keeps + # p-values in float64 for downstream BH correction precision. Moving + # the cast off CPU saves ~150 ms per stat per call on wide workloads. 
+ all_z = z.astype(cp.float32).get() all_p = p_values.get() return [(gi, all_z[ti], all_p[ti]) for ti, gi in enumerate(test_group_indices)] From 3030c363625f4a0e7803375309b43d78ddac0fb5 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Fri, 24 Apr 2026 02:48:47 +0200 Subject: [PATCH 21/21] make PR safer --- .gitignore | 1 + notebooks | 2 +- pyproject.toml | 15 +- .../_cuda/wilcoxon/kernels_wilcoxon.cuh | 159 ++++++++ .../_cuda/wilcoxon/wilcoxon_common.cuh | 9 +- .../wilcoxon/wilcoxon_ovo_host_dense.cuh | 15 +- .../wilcoxon/wilcoxon_ovo_host_sparse.cuh | 43 ++- .../_cuda/wilcoxon/wilcoxon_ovr_bindings.cuh | 46 +++ .../_cuda/wilcoxon/wilcoxon_ovr_dense.cuh | 20 +- .../_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh | 32 +- .../_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh | 300 ++++++++++++++- .../tools/_rank_genes_groups/__init__.py | 8 +- .../tools/_rank_genes_groups/_core.py | 13 +- .../tools/_rank_genes_groups/_wilcoxon.py | 352 +++++++++++++----- tests/test_rank_genes_groups_wilcoxon.py | 116 ++++++ 15 files changed, 947 insertions(+), 184 deletions(-) diff --git a/.gitignore b/.gitignore index b7a8a4f6..2735c080 100644 --- a/.gitignore +++ b/.gitignore @@ -51,3 +51,4 @@ CLAUDE.md # tmp_scripts tmp_scripts/ +benchmarks/ diff --git a/notebooks b/notebooks index 4cdaa44f..e5c97b34 160000 --- a/notebooks +++ b/notebooks @@ -1 +1 @@ -Subproject commit 4cdaa44fbd93b6f812fc8d2c72b89180ef92047d +Subproject commit e5c97b34f4acbf919fb3118c987cc5893e5b5fdf diff --git a/pyproject.toml b/pyproject.toml index 7903bcf1..3777740e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,6 +3,10 @@ requires = [ "scikit-build-core>=0.10", "nanobind>=2.0.0", "setuptools-scm>=8", + # librmm headers are needed at build time for the Wilcoxon CUDA kernels. + # The headers are identical across cu12/cu13; the runtime .so is loaded + # from the installed RAPIDS package. 
+ "librmm-cu12>=25.10", ] build-backend = "scikit_build_core.build" @@ -197,14 +201,3 @@ source = [ "./src", "**/site-packages" ] exclude_also = [ "if TYPE_CHECKING:", ] - -[tool.uv] -# librmm headers are needed at build time for the wilcoxon CUDA kernels. -# The headers are identical across cu12/cu13 — only the .so differs (loaded -# at runtime via librmm.load_library()). cu12 is used here as the build-time -# provider; cu13 envs get the same headers. -# NOTE: This is uv-specific. pip users building from source need: -# pip install librmm-cu12 && pip install --no-build-isolation -e . - -[tool.uv.extra-build-dependencies] -rapids-singlecell = [ "librmm-cu12>=25.10" ] diff --git a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh index bcd70dc7..0ce922ae 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh @@ -406,6 +406,35 @@ __global__ void rank_sums_sparse_ovr_kernel( } } +/** + * Decide whether the host cast+stats kernels can use per-block shared memory + * accumulators. Large group counts exceed the dynamic smem launch limit, so + * those cases fall back to direct global-memory atomics after zeroing the + * per-stream output buffers. + */ +static int wilcoxon_cast_max_smem_per_block() { + static int cached = -1; + if (cached < 0) { + int device; + cudaGetDevice(&device); + cudaDeviceGetAttribute(&cached, cudaDevAttrMaxSharedMemoryPerBlock, + device); + } + return cached; +} + +static size_t cast_accumulate_smem_config(int n_groups, bool compute_sq_sums, + bool compute_nnz, bool& use_gmem) { + int n_arrays = 1 + (compute_sq_sums ? 1 : 0) + (compute_nnz ? 
1 : 0); + size_t need = (size_t)n_arrays * n_groups * sizeof(double); + if (need <= (size_t)wilcoxon_cast_max_smem_per_block()) { + use_gmem = false; + return need; + } + use_gmem = true; + return 0; +} + /** * Pre-sort cast-and-accumulate kernel for dense OVR host streaming. * @@ -467,6 +496,36 @@ __global__ void ovr_cast_and_accumulate_dense_kernel( } } +template +__global__ void ovr_cast_and_accumulate_dense_global_kernel( + const InT* __restrict__ block_in, float* __restrict__ block_f32_out, + const int* __restrict__ group_codes, double* __restrict__ group_sums, + double* __restrict__ group_sq_sums, double* __restrict__ group_nnz, + int n_rows, int sb_cols, int n_groups, bool compute_sq_sums = true, + bool compute_nnz = true) { + int col = blockIdx.x; + if (col >= sb_cols) return; + + const InT* src = block_in + (size_t)col * n_rows; + float* dst = block_f32_out + (size_t)col * n_rows; + + for (int r = threadIdx.x; r < n_rows; r += blockDim.x) { + InT v_in = src[r]; + double v = (double)v_in; + dst[r] = (float)v_in; + int g = group_codes[r]; + if (g < n_groups) { + atomicAdd(&group_sums[(size_t)g * sb_cols + col], v); + if (compute_sq_sums) { + atomicAdd(&group_sq_sums[(size_t)g * sb_cols + col], v * v); + } + if (compute_nnz && v != 0.0) { + atomicAdd(&group_nnz[(size_t)g * sb_cols + col], 1.0); + } + } + } +} + /** * Pre-sort cast-and-accumulate kernel for sparse OVR host streaming. 
* @@ -529,3 +588,103 @@ __global__ void ovr_cast_and_accumulate_sparse_kernel( } } } + +template +__global__ void ovr_cast_and_accumulate_sparse_global_kernel( + const InT* __restrict__ data_in, float* __restrict__ data_f32_out, + const IndexT* __restrict__ indices, const int* __restrict__ col_seg_offsets, + const int* __restrict__ group_codes, double* __restrict__ group_sums, + double* __restrict__ group_sq_sums, double* __restrict__ group_nnz, + int sb_cols, int n_groups, bool compute_sq_sums = true, + bool compute_nnz = true) { + int col = blockIdx.x; + if (col >= sb_cols) return; + + int seg_start = col_seg_offsets[col]; + int seg_end = col_seg_offsets[col + 1]; + + for (int i = seg_start + threadIdx.x; i < seg_end; i += blockDim.x) { + InT v_in = data_in[i]; + double v = (double)v_in; + data_f32_out[i] = (float)v_in; + int row = (int)indices[i]; + int g = group_codes[row]; + if (g < n_groups) { + atomicAdd(&group_sums[(size_t)g * sb_cols + col], v); + if (compute_sq_sums) { + atomicAdd(&group_sq_sums[(size_t)g * sb_cols + col], v * v); + } + if (compute_nnz && v != 0.0) { + atomicAdd(&group_nnz[(size_t)g * sb_cols + col], 1.0); + } + } + } +} + +template +static void launch_ovr_cast_and_accumulate_dense( + const InT* d_block_orig, float* d_block_f32, const int* d_group_codes, + double* d_group_sums, double* d_group_sq_sums, double* d_group_nnz, + int n_rows, int sb_cols, int n_groups, bool compute_sq_sums, + bool compute_nnz, int tpb, size_t smem_cast, bool use_gmem, + cudaStream_t stream) { + if (use_gmem) { + size_t stats_items = (size_t)n_groups * sb_cols; + cudaMemsetAsync(d_group_sums, 0, stats_items * sizeof(double), stream); + if (compute_sq_sums) { + cudaMemsetAsync(d_group_sq_sums, 0, stats_items * sizeof(double), + stream); + } + if (compute_nnz) { + cudaMemsetAsync(d_group_nnz, 0, stats_items * sizeof(double), + stream); + } + ovr_cast_and_accumulate_dense_global_kernel + <<>>( + d_block_orig, d_block_f32, d_group_codes, d_group_sums, + 
d_group_sq_sums, d_group_nnz, n_rows, sb_cols, n_groups, + compute_sq_sums, compute_nnz); + CUDA_CHECK_LAST_ERROR(ovr_cast_and_accumulate_dense_global_kernel); + } else { + ovr_cast_and_accumulate_dense_kernel + <<>>( + d_block_orig, d_block_f32, d_group_codes, d_group_sums, + d_group_sq_sums, d_group_nnz, n_rows, sb_cols, n_groups, + compute_sq_sums, compute_nnz); + CUDA_CHECK_LAST_ERROR(ovr_cast_and_accumulate_dense_kernel); + } +} + +template +static void launch_ovr_cast_and_accumulate_sparse( + const InT* d_data_orig, float* d_data_f32, const IndexT* d_indices, + const int* d_col_offsets, const int* d_group_codes, double* d_group_sums, + double* d_group_sq_sums, double* d_group_nnz, int sb_cols, int n_groups, + bool compute_sq_sums, bool compute_nnz, int tpb, size_t smem_cast, + bool use_gmem, cudaStream_t stream) { + if (use_gmem) { + size_t stats_items = (size_t)n_groups * sb_cols; + cudaMemsetAsync(d_group_sums, 0, stats_items * sizeof(double), stream); + if (compute_sq_sums) { + cudaMemsetAsync(d_group_sq_sums, 0, stats_items * sizeof(double), + stream); + } + if (compute_nnz) { + cudaMemsetAsync(d_group_nnz, 0, stats_items * sizeof(double), + stream); + } + ovr_cast_and_accumulate_sparse_global_kernel + <<>>( + d_data_orig, d_data_f32, d_indices, d_col_offsets, + d_group_codes, d_group_sums, d_group_sq_sums, d_group_nnz, + sb_cols, n_groups, compute_sq_sums, compute_nnz); + CUDA_CHECK_LAST_ERROR(ovr_cast_and_accumulate_sparse_global_kernel); + } else { + ovr_cast_and_accumulate_sparse_kernel + <<>>( + d_data_orig, d_data_f32, d_indices, d_col_offsets, + d_group_codes, d_group_sums, d_group_sq_sums, d_group_nnz, + sb_cols, n_groups, compute_sq_sums, compute_nnz); + CUDA_CHECK_LAST_ERROR(ovr_cast_and_accumulate_sparse_kernel); + } +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_common.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_common.cuh index c0456b3f..8ac0f247 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_common.cuh +++ 
b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_common.cuh @@ -53,9 +53,9 @@ struct HostRegisterGuard { void* ptr = nullptr; HostRegisterGuard() = default; - HostRegisterGuard(void* p, size_t bytes, unsigned int flags = 0) : ptr(p) { - if (ptr) { - cudaError_t err = cudaHostRegister(ptr, bytes, flags); + HostRegisterGuard(void* p, size_t bytes, unsigned int flags = 0) { + if (p && bytes > 0) { + cudaError_t err = cudaHostRegister(p, bytes, flags); if (err != cudaSuccess) { // Already-registered memory is fine; anything else means the // subsequent kernels would read garbage from an unmapped @@ -63,13 +63,14 @@ struct HostRegisterGuard { if (err == cudaErrorHostMemoryAlreadyRegistered) { cudaGetLastError(); // clear sticky error flag } else { - ptr = nullptr; // don't unregister in dtor throw std::runtime_error( std::string("cudaHostRegister failed (") + std::to_string((size_t)bytes) + " bytes, flags=" + std::to_string(flags) + "): " + cudaGetErrorString(err)); } + } else { + ptr = p; } } } diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_dense.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_dense.cuh index 37d23ab4..458d6667 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_dense.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_dense.cuh @@ -160,7 +160,9 @@ static void ovo_streaming_dense_host_impl( int tpb_rank = round_up_to_warp(std::min(max_grp_size, MAX_THREADS_PER_BLOCK)); - size_t smem_cast = (size_t)(3 * n_groups_stats) * sizeof(double); + bool cast_use_gmem = false; + size_t smem_cast = cast_accumulate_smem_config( + n_groups_stats, compute_sq_sums, compute_nnz, cast_use_gmem); // Pin only the host input; outputs live on the device. 
HostRegisterGuard _pin_block(const_cast(h_block), @@ -182,12 +184,11 @@ static void ovo_streaming_dense_host_impl( sb_dense * sizeof(InT), cudaMemcpyHostToDevice, stream); // ---- Cast to float32 for sort + accumulate stats in float64 ---- - ovr_cast_and_accumulate_dense_kernel - <<>>( - buf.d_block_orig, buf.d_block_f32, d_stats_codes, - buf.d_group_sums, buf.d_group_sq_sums, buf.d_group_nnz, n_rows, - sb_cols, n_groups_stats, compute_sq_sums, compute_nnz); - CUDA_CHECK_LAST_ERROR(ovr_cast_and_accumulate_dense_kernel); + launch_ovr_cast_and_accumulate_dense( + buf.d_block_orig, buf.d_block_f32, d_stats_codes, buf.d_group_sums, + buf.d_group_sq_sums, buf.d_group_nnz, n_rows, sb_cols, + n_groups_stats, compute_sq_sums, compute_nnz, UTIL_BLOCK_SIZE, + smem_cast, cast_use_gmem, stream); // ---- Gather ref rows, sort ---- dense_gather_rows_kernel<<>>( diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh index f2b4e7ae..32721a41 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh @@ -75,6 +75,21 @@ static void ovo_streaming_csc_host_impl( RmmPool pool; + int n_batches = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + std::vector h_all_offsets((size_t)n_batches * (sub_batch_cols + 1), 0); + for (int b = 0; b < n_batches; b++) { + int col_start = b * sub_batch_cols; + int sb = std::min(sub_batch_cols, n_cols - col_start); + IndptrT ptr_start = h_indptr[col_start]; + int* off = &h_all_offsets[(size_t)b * (sub_batch_cols + 1)]; + for (int i = 0; i <= sb; i++) + off[i] = (int)(h_indptr[col_start + i] - ptr_start); + } + int* d_all_offsets = + pool.alloc((size_t)n_batches * (sub_batch_cols + 1)); + cudaMemcpy(d_all_offsets, h_all_offsets.data(), + h_all_offsets.size() * sizeof(int), cudaMemcpyHostToDevice); + // GPU copies of row maps + group offsets + stats codes (uploaded once) int* 
d_ref_row_map = pool.alloc(n_rows); int* d_grp_row_map = pool.alloc(n_rows); @@ -154,7 +169,9 @@ static void ovo_streaming_csc_host_impl( int tpb_rank = round_up_to_warp(std::min(max_grp_size, MAX_THREADS_PER_BLOCK)); - size_t smem_cast = (size_t)(3 * n_groups_stats) * sizeof(double); + bool cast_use_gmem = false; + size_t smem_cast = cast_accumulate_smem_config( + n_groups_stats, compute_sq_sums, compute_nnz, cast_use_gmem); // Pin only the sparse input arrays; outputs live on the device. size_t total_nnz = (size_t)h_indptr[n_cols]; @@ -181,22 +198,16 @@ static void ovo_streaming_csc_host_impl( nnz * sizeof(InT), cudaMemcpyHostToDevice, stream); cudaMemcpyAsync(buf.d_sparse_indices, h_indices + ptr_start, nnz * sizeof(IndexT), cudaMemcpyHostToDevice, stream); - { - std::vector h_adj(sb_cols + 1); - for (int i = 0; i <= sb_cols; i++) - h_adj[i] = (int)(h_indptr[col + i] - ptr_start); - cudaMemcpy(buf.d_indptr, h_adj.data(), (sb_cols + 1) * sizeof(int), - cudaMemcpyHostToDevice); - } + int* src = d_all_offsets + (size_t)batch_idx * (sub_batch_cols + 1); + cudaMemcpyAsync(buf.d_indptr, src, (sb_cols + 1) * sizeof(int), + cudaMemcpyDeviceToDevice, stream); // ---- Cast to float32 for sort + accumulate stats in float64 ---- - ovr_cast_and_accumulate_sparse_kernel - <<>>( - buf.d_sparse_data_orig, buf.d_sparse_data_f32, - buf.d_sparse_indices, buf.d_indptr, d_stats_codes, - buf.d_group_sums, buf.d_group_sq_sums, buf.d_group_nnz, sb_cols, - n_groups_stats, compute_sq_sums, compute_nnz); - CUDA_CHECK_LAST_ERROR(ovr_cast_and_accumulate_sparse_kernel); + launch_ovr_cast_and_accumulate_sparse( + buf.d_sparse_data_orig, buf.d_sparse_data_f32, buf.d_sparse_indices, + buf.d_indptr, d_stats_codes, buf.d_group_sums, buf.d_group_sq_sums, + buf.d_group_nnz, sb_cols, n_groups_stats, compute_sq_sums, + compute_nnz, UTIL_BLOCK_SIZE, smem_cast, cast_use_gmem, stream); // ---- Extract ref from CSC via row_map, sort ---- cudaMemsetAsync(buf.ref_dense, 0, sb_ref_actual * sizeof(float), @@ 
-470,7 +481,7 @@ static void ovo_streaming_csr_host_impl( // Linux x86-64, but the API is the safe/portable way). InT* d_data_zc = nullptr; IndexT* d_indices_zc = nullptr; - { + if (full_nnz > 0) { cudaError_t e1 = cudaHostGetDevicePointer((void**)&d_data_zc, const_cast(h_data), 0); cudaError_t e2 = cudaHostGetDevicePointer( diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_bindings.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_bindings.cuh index fab780e9..72ae1938 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_bindings.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_bindings.cuh @@ -106,6 +106,52 @@ void register_bindings(nb::module_& m) { int64_t); #undef RSC_OVR_SPARSE_CSC_HOST_BINDING +#define RSC_OVR_SPARSE_CSR_HOST_BINDING(NAME, InT, IndexT, IndptrT) \ + m.def( \ + NAME, \ + [](host_array h_data, host_array h_indices, \ + host_array h_indptr, \ + host_array h_group_codes, \ + host_array h_group_sizes, \ + gpu_array_c d_rank_sums, \ + gpu_array_c d_tie_corr, \ + gpu_array_c d_group_sums, \ + gpu_array_c d_group_sq_sums, \ + gpu_array_c d_group_nnz, int n_rows, int n_cols, \ + int n_groups, bool compute_tie_corr, bool compute_sq_sums, \ + bool compute_nnz, int sub_batch_cols) { \ + ovr_sparse_csr_host_streaming_impl( \ + h_data.data(), h_indices.data(), h_indptr.data(), \ + h_group_codes.data(), h_group_sizes.data(), \ + d_rank_sums.data(), d_tie_corr.data(), d_group_sums.data(), \ + d_group_sq_sums.data(), d_group_nnz.data(), n_rows, n_cols, \ + n_groups, compute_tie_corr, compute_sq_sums, compute_nnz, \ + sub_batch_cols); \ + }, \ + "h_data"_a, "h_indices"_a, "h_indptr"_a, "h_group_codes"_a, \ + "h_group_sizes"_a, "d_rank_sums"_a, "d_tie_corr"_a, "d_group_sums"_a, \ + "d_group_sq_sums"_a, "d_group_nnz"_a, nb::kw_only(), "n_rows"_a, \ + "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, \ + "compute_sq_sums"_a = true, "compute_nnz"_a = true, \ + "sub_batch_cols"_a = SUB_BATCH_COLS) + + 
RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host", float, int, int); + RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_i64", float, int, + int64_t); + RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_idx64", float, int64_t, + int); + RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_idx64_i64", float, + int64_t, int64_t); + RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_f64", double, int, + int); + RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_f64_i64", double, int, + int64_t); + RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_f64_idx64", double, + int64_t, int); + RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_f64_idx64_i64", double, + int64_t, int64_t); +#undef RSC_OVR_SPARSE_CSR_HOST_BINDING + #define RSC_OVR_DENSE_HOST_BINDING(NAME, InT) \ m.def( \ NAME, \ diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_dense.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_dense.cuh index 039f6ee6..2adb5b7b 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_dense.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_dense.cuh @@ -217,12 +217,9 @@ static void ovr_streaming_dense_host_impl( bool use_gmem = false; size_t smem_rank = ovr_smem_config(n_groups, use_gmem); int tpb_cast = UTIL_BLOCK_SIZE; - size_t smem_cast = (size_t)n_groups * sizeof(double); - if (compute_nnz) { - smem_cast = (size_t)(3 * n_groups) * sizeof(double); - } else if (compute_sq_sums) { - smem_cast = (size_t)(2 * n_groups) * sizeof(double); - } + bool cast_use_gmem = false; + size_t smem_cast = cast_accumulate_smem_config(n_groups, compute_sq_sums, + compute_nnz, cast_use_gmem); // Pin only the host input. Outputs live on the device (caller-owned). 
HostRegisterGuard _pin_block(const_cast(h_block), @@ -242,12 +239,11 @@ static void ovr_streaming_dense_host_impl( sb_items * sizeof(InT), cudaMemcpyHostToDevice, stream); // Cast to float32 for sort + accumulate stats in float64 - ovr_cast_and_accumulate_dense_kernel - <<>>( - buf.d_block_orig, buf.d_block_f32, d_group_codes, - buf.d_group_sums, buf.d_group_sq_sums, buf.d_group_nnz, n_rows, - sb_cols, n_groups, compute_sq_sums, compute_nnz); - CUDA_CHECK_LAST_ERROR(ovr_cast_and_accumulate_dense_kernel); + launch_ovr_cast_and_accumulate_dense( + buf.d_block_orig, buf.d_block_f32, d_group_codes, buf.d_group_sums, + buf.d_group_sq_sums, buf.d_group_nnz, n_rows, sb_cols, n_groups, + compute_sq_sums, compute_nnz, tpb_cast, smem_cast, cast_use_gmem, + stream); // Fill segment offsets + row indices upload_linear_offsets(buf.seg_offsets, sb_cols, n_rows, stream); diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh index a94338b7..006002b9 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh @@ -1,16 +1,17 @@ #pragma once /** Count nonzeros per column from CSR. One thread per row. 
*/ -__global__ void csr_col_histogram_kernel(const int* __restrict__ indices, - const int* __restrict__ indptr, +template +__global__ void csr_col_histogram_kernel(const IndexT* __restrict__ indices, + const IndptrT* __restrict__ indptr, int* __restrict__ col_counts, int n_rows, int n_cols) { int row = blockIdx.x * blockDim.x + threadIdx.x; if (row >= n_rows) return; - int rs = indptr[row]; - int re = indptr[row + 1]; - for (int p = rs; p < re; ++p) { - int c = indices[p]; + IndptrT rs = indptr[row]; + IndptrT re = indptr[row + 1]; + for (IndptrT p = rs; p < re; ++p) { + int c = (int)indices[p]; if (c < n_cols) atomicAdd(&col_counts[c], 1); } } @@ -20,26 +21,27 @@ __global__ void csr_col_histogram_kernel(const int* __restrict__ indices, * write_pos[c - col_start] must be initialized to the prefix-sum offset * for column c. Each thread atomically claims a unique destination slot. */ +template __global__ void csr_scatter_to_csc_kernel( - const float* __restrict__ data, const int* __restrict__ indices, - const int* __restrict__ indptr, int* __restrict__ write_pos, - float* __restrict__ csc_vals, int* __restrict__ csc_row_idx, int n_rows, + const InT* __restrict__ data, const IndexT* __restrict__ indices, + const IndptrT* __restrict__ indptr, int* __restrict__ write_pos, + InT* __restrict__ csc_vals, int* __restrict__ csc_row_idx, int n_rows, int col_start, int col_stop) { int row = blockIdx.x * blockDim.x + threadIdx.x; if (row >= n_rows) return; - int rs = indptr[row]; - int re = indptr[row + 1]; + IndptrT rs = indptr[row]; + IndptrT re = indptr[row + 1]; // Binary search for col_start (overflow-safe midpoint) - int lo = rs, hi = re; + IndptrT lo = rs, hi = re; while (lo < hi) { - int m = lo + ((hi - lo) >> 1); + IndptrT m = lo + ((hi - lo) >> 1); if (indices[m] < col_start) lo = m + 1; else hi = m; } - for (int p = lo; p < re; ++p) { - int c = indices[p]; + for (IndptrT p = lo; p < re; ++p) { + int c = (int)indices[p]; if (c >= col_stop) break; int dest = 
atomicAdd(&write_pos[c - col_start], 1); csc_vals[dest] = data[p]; diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh index 725e2c41..0f74a2c8 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh @@ -109,12 +109,9 @@ static void ovr_sparse_csc_host_streaming_impl( int tpb = UTIL_BLOCK_SIZE; bool rank_use_gmem = false; size_t smem_bytes = sparse_ovr_smem_config(n_groups, rank_use_gmem); - size_t smem_cast = (size_t)n_groups * sizeof(double); - if (compute_nnz) { - smem_cast = (size_t)(3 * n_groups) * sizeof(double); - } else if (compute_sq_sums) { - smem_cast = (size_t)(2 * n_groups) * sizeof(double); - } + bool cast_use_gmem = false; + size_t smem_cast = cast_accumulate_smem_config(n_groups, compute_sq_sums, + compute_nnz, cast_use_gmem); // In gmem mode the sparse rank kernel accumulates into rank_sums directly // and needs a per-stream nz_count scratch buffer sized (n_groups, sb_cols). 
@@ -164,13 +161,12 @@ static void ovr_sparse_csc_host_streaming_impl( cudaMemcpyDeviceToDevice, stream); // Cast to float32 for sort + accumulate stats in float64 - ovr_cast_and_accumulate_sparse_kernel - <<>>( - buf.d_sparse_data_orig, buf.d_sparse_data_f32, - buf.d_sparse_indices, buf.d_seg_offsets, d_group_codes, - buf.d_group_sums, buf.d_group_sq_sums, buf.d_group_nnz, sb_cols, - n_groups, compute_sq_sums, compute_nnz); - CUDA_CHECK_LAST_ERROR(ovr_cast_and_accumulate_sparse_kernel); + launch_ovr_cast_and_accumulate_sparse( + buf.d_sparse_data_orig, buf.d_sparse_data_f32, buf.d_sparse_indices, + buf.d_seg_offsets, d_group_codes, buf.d_group_sums, + buf.d_group_sq_sums, buf.d_group_nnz, sb_cols, n_groups, + compute_sq_sums, compute_nnz, tpb, smem_cast, cast_use_gmem, + stream); // CUB sort only stored nonzeros (float32 keys) if (batch_nnz > 0) { @@ -239,6 +235,284 @@ static void ovr_sparse_csc_host_streaming_impl( for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); } +// ============================================================================ +// Sparse-aware host-streaming CSR OVR pipeline. +// ============================================================================ + +/** + * Host CSR variant of the sparse OVR stream. + * + * The CSR input stays in host memory. We count columns once on the CPU, then + * use mapped pinned CSR arrays for bounded per-column-batch CSR->CSC scatter + * on the GPU. This avoids both a full host->device sparse upload and any + * whole-matrix CSR->CSC conversion. 
+ */ +template +static void ovr_sparse_csr_host_streaming_impl( + const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, + const int* h_group_codes, const double* h_group_sizes, double* d_rank_sums, + double* d_tie_corr, double* d_group_sums, double* d_group_sq_sums, + double* d_group_nnz, int n_rows, int n_cols, int n_groups, + bool compute_tie_corr, bool compute_sq_sums, bool compute_nnz, + int sub_batch_cols) { + if (n_rows == 0 || n_cols == 0) return; + + RmmPool pool; + size_t total_nnz = (size_t)h_indptr[n_rows]; + + // ---- Phase 0: CPU planning in native CSR order ---- + std::vector h_col_counts(n_cols, 0); + for (int row = 0; row < n_rows; row++) { + IndptrT rs = h_indptr[row]; + IndptrT re = h_indptr[row + 1]; + for (IndptrT p = rs; p < re; ++p) { + int c = (int)h_indices[p]; + if (c >= 0 && c < n_cols) h_col_counts[c]++; + } + } + + int n_batches = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + size_t max_batch_nnz = 0; + std::vector h_all_offsets((size_t)n_batches * (sub_batch_cols + 1), 0); + std::vector h_batch_nnz(n_batches); + for (int b = 0; b < n_batches; b++) { + int col_start = b * sub_batch_cols; + int sb_cols = std::min(sub_batch_cols, n_cols - col_start); + int* off = &h_all_offsets[(size_t)b * (sub_batch_cols + 1)]; + for (int i = 0; i < sb_cols; i++) + off[i + 1] = off[i] + h_col_counts[col_start + i]; + h_batch_nnz[b] = (size_t)off[sb_cols]; + if (h_batch_nnz[b] > max_batch_nnz) max_batch_nnz = h_batch_nnz[b]; + } + + int* d_all_offsets = + pool.alloc((size_t)n_batches * (sub_batch_cols + 1)); + cudaMemcpy(d_all_offsets, h_all_offsets.data(), + h_all_offsets.size() * sizeof(int), cudaMemcpyHostToDevice); + + // ---- Phase 1: allocate per-stream bounded work buffers ---- + size_t cub_temp_bytes = 0; + if (max_batch_nnz > 0) { + auto* fk = reinterpret_cast(1); + auto* iv = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortPairs( + nullptr, cub_temp_bytes, fk, fk, iv, iv, (int)max_batch_nnz, + sub_batch_cols, iv, iv + 
1, BEGIN_BIT, END_BIT); + } + + int tpb = UTIL_BLOCK_SIZE; + bool rank_use_gmem = false; + size_t smem_bytes = sparse_ovr_smem_config(n_groups, rank_use_gmem); + bool cast_use_gmem = false; + size_t smem_cast = cast_accumulate_smem_config(n_groups, compute_sq_sums, + compute_nnz, cast_use_gmem); + + int n_streams = N_STREAMS; + if (n_batches < n_streams) n_streams = n_batches; + + size_t per_stream_bytes = + max_batch_nnz * (sizeof(InT) + sizeof(float) + 2 * sizeof(int)) + + (sub_batch_cols + 1 + sub_batch_cols) * sizeof(int) + cub_temp_bytes + + 2 * (size_t)n_groups * sub_batch_cols * sizeof(double) + + sub_batch_cols * sizeof(double); + if (compute_sq_sums) { + per_stream_bytes += (size_t)n_groups * sub_batch_cols * sizeof(double); + } + if (compute_nnz) { + per_stream_bytes += (size_t)n_groups * sub_batch_cols * sizeof(double); + } + if (rank_use_gmem) { + per_stream_bytes += (size_t)n_groups * sub_batch_cols * sizeof(double); + } + + size_t free_mem = 0, total_mem = 0; + cudaMemGetInfo(&free_mem, &total_mem); + constexpr double MEM_BUDGET_FRAC = 0.8; + size_t budget = (size_t)(free_mem * MEM_BUDGET_FRAC); + while (n_streams > 1 && (size_t)n_streams * per_stream_bytes > budget) + n_streams--; + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + // Pin the source CSR arrays as mapped memory. The scatter kernel reads + // only the requested column window from each row. 
+ HostRegisterGuard pin_data; + HostRegisterGuard pin_indices; + InT* d_data_zc = nullptr; + IndexT* d_indices_zc = nullptr; + if (total_nnz > 0) { + pin_data = + HostRegisterGuard(const_cast(h_data), total_nnz * sizeof(InT), + cudaHostRegisterMapped); + pin_indices = HostRegisterGuard(const_cast(h_indices), + total_nnz * sizeof(IndexT), + cudaHostRegisterMapped); + cudaError_t e1 = cudaHostGetDevicePointer((void**)&d_data_zc, + const_cast(h_data), 0); + cudaError_t e2 = cudaHostGetDevicePointer( + (void**)&d_indices_zc, const_cast(h_indices), 0); + if (e1 != cudaSuccess || e2 != cudaSuccess) { + throw std::runtime_error( + std::string("cudaHostGetDevicePointer failed: ") + + cudaGetErrorString(e1 != cudaSuccess ? e1 : e2)); + } + } + + IndptrT* d_indptr_full = pool.alloc(n_rows + 1); + cudaMemcpy(d_indptr_full, h_indptr, (n_rows + 1) * sizeof(IndptrT), + cudaMemcpyHostToDevice); + + int* d_group_codes = pool.alloc(n_rows); + double* d_group_sizes = pool.alloc(n_groups); + cudaMemcpy(d_group_codes, h_group_codes, n_rows * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_group_sizes, h_group_sizes, n_groups * sizeof(double), + cudaMemcpyHostToDevice); + + int scatter_blocks = (n_rows + tpb - 1) / tpb; + + struct StreamBuf { + int* col_offsets; + int* write_pos; + InT* csc_vals_orig; + float* csc_vals_f32; + int* csc_row_idx; + float* keys_out; + int* vals_out; + uint8_t* cub_temp; + double* sub_rank_sums; + double* sub_tie_corr; + double* sub_group_sums; + double* sub_group_sq_sums; + double* sub_group_nnz; + double* d_nz_scratch; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].col_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].write_pos = pool.alloc(sub_batch_cols); + bufs[s].csc_vals_orig = pool.alloc(max_batch_nnz); + bufs[s].csc_vals_f32 = pool.alloc(max_batch_nnz); + bufs[s].csc_row_idx = pool.alloc(max_batch_nnz); + bufs[s].keys_out = pool.alloc(max_batch_nnz); + bufs[s].vals_out = pool.alloc(max_batch_nnz); + 
bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = pool.alloc(sub_batch_cols); + bufs[s].sub_group_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_group_sq_sums = + compute_sq_sums + ? pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + bufs[s].sub_group_nnz = + compute_nnz ? pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + bufs[s].d_nz_scratch = + rank_use_gmem + ? pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + } + + cudaDeviceSynchronize(); + + // ---- Phase 2: bounded CSR->CSC scatter + GPU rank batches ---- + int col = 0; + for (int b = 0; b < n_batches; b++) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int s = b % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + int batch_nnz = (int)h_batch_nnz[b]; + + int* src = d_all_offsets + (size_t)b * (sub_batch_cols + 1); + cudaMemcpyAsync(buf.col_offsets, src, (sb_cols + 1) * sizeof(int), + cudaMemcpyDeviceToDevice, stream); + cudaMemcpyAsync(buf.write_pos, src, sb_cols * sizeof(int), + cudaMemcpyDeviceToDevice, stream); + + if (batch_nnz > 0) { + csr_scatter_to_csc_kernel + <<>>( + d_data_zc, d_indices_zc, d_indptr_full, buf.write_pos, + buf.csc_vals_orig, buf.csc_row_idx, n_rows, col, + col + sb_cols); + CUDA_CHECK_LAST_ERROR(csr_scatter_to_csc_kernel); + } + + launch_ovr_cast_and_accumulate_sparse( + buf.csc_vals_orig, buf.csc_vals_f32, buf.csc_row_idx, + buf.col_offsets, d_group_codes, buf.sub_group_sums, + buf.sub_group_sq_sums, buf.sub_group_nnz, sb_cols, n_groups, + compute_sq_sums, compute_nnz, tpb, smem_cast, cast_use_gmem, + stream); + + if (batch_nnz > 0) { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortPairs( + buf.cub_temp, temp, buf.csc_vals_f32, buf.keys_out, + buf.csc_row_idx, buf.vals_out, batch_nnz, sb_cols, + buf.col_offsets, buf.col_offsets + 1, BEGIN_BIT, END_BIT, + stream); + } + + if 
(rank_use_gmem) { + cudaMemsetAsync(buf.sub_rank_sums, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + cudaMemsetAsync(buf.d_nz_scratch, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + } + rank_sums_sparse_ovr_kernel<<>>( + buf.keys_out, buf.vals_out, buf.col_offsets, d_group_codes, + d_group_sizes, buf.sub_rank_sums, buf.sub_tie_corr, + buf.d_nz_scratch, n_rows, sb_cols, n_groups, compute_tie_corr, + rank_use_gmem); + CUDA_CHECK_LAST_ERROR(rank_sums_sparse_ovr_kernel); + + cudaMemcpy2DAsync(d_rank_sums + col, n_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpyAsync(d_tie_corr + col, buf.sub_tie_corr, + sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, + stream); + } + cudaMemcpy2DAsync(d_group_sums + col, n_cols * sizeof(double), + buf.sub_group_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_sq_sums) { + cudaMemcpy2DAsync(d_group_sq_sums + col, n_cols * sizeof(double), + buf.sub_group_sq_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + if (compute_nnz) { + cudaMemcpy2DAsync(d_group_nnz + col, n_cols * sizeof(double), + buf.sub_group_nnz, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + + col += sb_cols; + } + + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in sparse host CSR streaming: ") + + cudaGetErrorString(err)); + } + + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} + // ============================================================================ // Sparse-aware CSC OVR streaming (sort only stored nonzeros) // 
============================================================================ diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py b/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py index 866121e1..75469fcb 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py @@ -107,10 +107,10 @@ def rank_genes_groups( layer Key from `adata.layers` whose value will be used to perform tests on. chunk_size - Number of genes to process at once for `'wilcoxon_binned'`. - The default is sized dynamically based on ``n_groups`` and - ``n_bins`` to keep histogram memory stable. - Ignored for other methods. + Number of genes to process at once for memory-bounded Wilcoxon + paths. For `'wilcoxon_binned'`, the default is sized dynamically + based on ``n_groups`` and ``n_bins`` to keep histogram memory + stable. pre_load Pre-load the data into GPU memory. Used only for `'wilcoxon'`. n_bins diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_core.py b/src/rapids_singlecell/tools/_rank_genes_groups/_core.py index 03a24e24..ce663b80 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_core.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_core.py @@ -214,16 +214,18 @@ def _basic_stats(self) -> None: # Compute rest statistics if reference='rest' if self.ireference is None: - n_rest = n.sum() - n - means_rest = (sums.sum(axis=0) - sums) / n_rest - rest_ss = (sq_sums.sum(axis=0) - sq_sums) - n_rest * means_rest**2 + n_rest = cp.float64(self.X.shape[0]) - n + total_sums = result["sum"].sum(axis=0, keepdims=True) + total_sq_sums = result["sq_sum"].sum(axis=0, keepdims=True) + means_rest = (total_sums - sums) / n_rest + rest_ss = (total_sq_sums - sq_sums) - n_rest * means_rest**2 vars_rest = cp.maximum(rest_ss / cp.maximum(n_rest - 1, 1), 0) self.means_rest = cp.asnumpy(means_rest) self.vars_rest = cp.asnumpy(vars_rest) if self.comp_pts: - total_count = (pts * 
n).sum(axis=0) + total_count = result["count_nonzero"].sum(axis=0, keepdims=True) self.pts_rest = cp.asnumpy((total_count - pts * n) / n_rest) else: self.pts_rest = None @@ -250,6 +252,7 @@ def wilcoxon( *, tie_correct: bool, use_continuity: bool = False, + chunk_size: int | None = None, ) -> list[tuple[int, NDArray, NDArray]]: """Compute Wilcoxon rank-sum test statistics.""" from ._wilcoxon import wilcoxon @@ -258,6 +261,7 @@ def wilcoxon( self, tie_correct=tie_correct, use_continuity=use_continuity, + chunk_size=chunk_size, ) def wilcoxon_binned( @@ -318,6 +322,7 @@ def compute_statistics( test_results = self.wilcoxon( tie_correct=tie_correct, use_continuity=use_continuity, + chunk_size=chunk_size, ) elif method == "wilcoxon_binned": test_results = self.wilcoxon_binned( diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py index 45833d78..5fcee8d2 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py @@ -25,6 +25,16 @@ # --------------------------------------------------------------------------- +def _resolve_chunk_size(chunk_size: int | None, n_total_genes: int) -> int: + if chunk_size is None: + return max(1, n_total_genes) + chunk_width = int(chunk_size) + if chunk_width <= 0: + msg = "`chunk_size` must be a positive integer." + raise ValueError(msg) + return min(chunk_width, max(1, n_total_genes)) + + def _fill_basic_stats_from_accumulators( rg: _RankGenes, group_sums: cp.ndarray, @@ -34,6 +44,9 @@ def _fill_basic_stats_from_accumulators( *, n_cells: int, compute_vars: bool = True, + total_sums: cp.ndarray | None = None, + total_sq_sums: cp.ndarray | None = None, + total_nnz: cp.ndarray | None = None, ) -> None: """Populate rg.means/vars/pts (+ *_rest) from streamed accumulators. 
@@ -57,20 +70,23 @@ def _fill_basic_stats_from_accumulators( if rg.ireference is None: n_rest = cp.float64(n_cells) - n - total_sum = group_sums.sum(axis=0, keepdims=True) - rest_sums = total_sum - group_sums + if total_sums is None: + total_sums = group_sums.sum(axis=0, keepdims=True) + rest_sums = total_sums - group_sums rest_means = rest_sums / n_rest rg.means_rest = cp.asnumpy(rest_means) if compute_vars: - total_sq_sum = group_sq_sums.sum(axis=0, keepdims=True) - rest_ss = (total_sq_sum - group_sq_sums) - n_rest * rest_means**2 + if total_sq_sums is None: + total_sq_sums = group_sq_sums.sum(axis=0, keepdims=True) + rest_ss = (total_sq_sums - group_sq_sums) - n_rest * rest_means**2 rg.vars_rest = cp.asnumpy( cp.maximum(rest_ss / cp.maximum(n_rest - 1, 1), 0) ) else: rg.vars_rest = np.zeros_like(rg.means_rest) if rg.comp_pts: - total_nnz = group_nnz.sum(axis=0, keepdims=True) + if total_nnz is None: + total_nnz = group_nnz.sum(axis=0, keepdims=True) rg.pts_rest = cp.asnumpy((total_nnz - group_nnz) / n_rest) else: rg.pts_rest = None @@ -139,7 +155,8 @@ def _to_gpu_native(X, n_rows: int, n_cols: int): if cpsp.issparse(X): return X - # Host sparse → GPU sparse, same format. + # Host sparse → GPU sparse, same format. Wilcoxon kernels are native CSR + # or CSC only; do not hide a whole-matrix sparse format conversion here. # Downcast indices to int32 on host before transfer (column indices # always fit in int32; scipy may use int64 when nnz > 2^31). if isinstance(X, sp.spmatrix | sp.sparray): @@ -152,7 +169,12 @@ def _to_gpu_native(X, n_rows: int, n_cols: int): ), shape=(n_rows, n_cols), ) - csr = X.tocsr() if X.format != "csr" else X + if X.format != "csr": + raise TypeError( + "Wilcoxon sparse input must be CSR or CSC; refusing hidden " + f"full-matrix conversion from {X.format!r}." 
+ ) + csr = X return cpsp.csr_matrix( ( cp.asarray(csr.data), @@ -169,6 +191,128 @@ def _to_gpu_native(X, n_rows: int, n_cols: int): raise TypeError(f"Unsupported matrix type: {type(X)}") +def _host_sparse_fn_and_arrays(module, base_name: str, X, *, support_idx64: bool): + """Select host sparse binding and dtype-normalized arrays.""" + is_f64 = X.data.dtype == np.float64 + is_idx64 = support_idx64 and X.indices.dtype == np.int64 + is_i64 = X.indptr.dtype == np.int64 + suffix = "" + if is_f64: + suffix += "_f64" + if is_idx64: + suffix += "_idx64" + if is_i64: + suffix += "_i64" + fn = getattr(module, base_name + suffix) + data_arr = X.data if is_f64 else X.data.astype(np.float32, copy=False) + indices_arr = X.indices if is_idx64 else X.indices.astype(np.int32, copy=False) + return fn, data_arr, indices_arr + + +def _column_totals_for_host_matrix( + X, *, compute_sq_sums: bool, compute_nnz: bool +) -> tuple[cp.ndarray, cp.ndarray | None, cp.ndarray | None]: + """Compute all-cell column totals without changing sparse format.""" + n_cols = X.shape[1] + + if isinstance(X, sp.spmatrix | sp.sparray): + data = np.asarray(X.data) + values = data.astype(np.float64, copy=False) + + if X.format == "csc": + indptr = np.asarray(X.indptr) + counts = np.diff(indptr) + nonempty = counts > 0 + starts = indptr[:-1][nonempty] + + sums = np.zeros(n_cols, dtype=np.float64) + if starts.size: + sums[nonempty] = np.add.reduceat(values, starts) + + sq_sums = None + if compute_sq_sums: + sq_sums = np.zeros(n_cols, dtype=np.float64) + if starts.size: + sq_sums[nonempty] = np.add.reduceat(values * values, starts) + + nnz = None + if compute_nnz: + nnz = np.zeros(n_cols, dtype=np.float64) + if starts.size: + nnz[nonempty] = np.add.reduceat( + (data != 0).astype(np.float64, copy=False), starts + ) + elif X.format == "csr": + indices = np.asarray(X.indices, dtype=np.intp) + sums = np.bincount(indices, weights=values, minlength=n_cols).astype( + np.float64, copy=False + ) + + sq_sums = ( + 
np.bincount(indices, weights=values * values, minlength=n_cols).astype( + np.float64, copy=False + ) + if compute_sq_sums + else None + ) + nnz = ( + np.bincount( + indices, + weights=(data != 0).astype(np.float64, copy=False), + minlength=n_cols, + ).astype(np.float64, copy=False) + if compute_nnz + else None + ) + else: + raise TypeError( + "Wilcoxon sparse input must be CSR or CSC; refusing hidden " + f"full-matrix conversion from {X.format!r}." + ) + elif isinstance(X, np.ndarray): + sums = np.asarray(X.sum(axis=0, dtype=np.float64), dtype=np.float64) + sq_sums = ( + np.asarray(np.square(X, dtype=np.float64).sum(axis=0), dtype=np.float64) + if compute_sq_sums + else None + ) + nnz = ( + np.asarray(np.count_nonzero(X, axis=0), dtype=np.float64) + if compute_nnz + else None + ) + else: + raise TypeError(f"Unsupported host matrix type: {type(X)}") + + total_sums = cp.asarray(sums.reshape(1, n_cols), dtype=cp.float64) + total_sq_sums = ( + cp.asarray(sq_sums.reshape(1, n_cols), dtype=cp.float64) + if sq_sums is not None + else None + ) + total_nnz = ( + cp.asarray(nnz.reshape(1, n_cols), dtype=cp.float64) + if nnz is not None + else None + ) + return total_sums, total_sq_sums, total_nnz + + +def _host_ovr_totals_if_needed( + X, + group_codes: np.ndarray, + n_groups: int, + *, + compute_sq_sums: bool, + compute_nnz: bool, +) -> tuple[cp.ndarray | None, cp.ndarray | None, cp.ndarray | None]: + if not np.any(group_codes == n_groups): + return None, None, None + return _column_totals_for_host_matrix( + X, compute_sq_sums=compute_sq_sums, compute_nnz=compute_nnz + ) + + def _extract_dense_block( X, row_ids: cp.ndarray | None, @@ -298,6 +442,7 @@ def wilcoxon( *, tie_correct: bool, use_continuity: bool = False, + chunk_size: int | None = None, ) -> list[tuple[int, NDArray, NDArray]]: """Compute Wilcoxon rank-sum test statistics.""" X = rg.X @@ -316,6 +461,7 @@ def wilcoxon( group_sizes, tie_correct=tie_correct, use_continuity=use_continuity, + chunk_size=chunk_size, ) 
return _wilcoxon_vs_rest( rg, @@ -368,9 +514,10 @@ def _wilcoxon_vs_rest( # Determine host-streaming eligibility BEFORE transferring host_csc = isinstance(X, sp.spmatrix | sp.sparray) and X.format == "csc" + host_csr = isinstance(X, sp.spmatrix | sp.sparray) and X.format == "csr" host_dense = isinstance(X, np.ndarray) - if host_csc or host_dense: + if host_csc or host_csr or host_dense: # Host-streaming: sort+rank stays on host→GPU per sub-batch. The # kernel also emits per-group sum, sum-of-squares, and nonzero # counts into caller-provided CuPy buffers, so means/vars/pts can @@ -395,20 +542,12 @@ def _wilcoxon_vs_rest( # Native host dtype is preserved and uploaded once per sub-batch; # a pre-sort kernel casts to float32 for the sort keys while # accumulating stats in float64 from the original values. - is_f64 = X.data.dtype == np.float64 - is_i64 = X.indptr.dtype == np.int64 - if is_f64 and is_i64: - _csc_host_fn = _ovr.ovr_sparse_csc_host_f64_i64 - elif is_f64: - _csc_host_fn = _ovr.ovr_sparse_csc_host_f64 - elif is_i64: - _csc_host_fn = _ovr.ovr_sparse_csc_host_i64 - else: - _csc_host_fn = _ovr.ovr_sparse_csc_host - data_arr = X.data if is_f64 else X.data.astype(np.float32, copy=False) + _csc_host_fn, data_arr, indices_arr = _host_sparse_fn_and_arrays( + _ovr, "ovr_sparse_csc_host", X, support_idx64=False + ) _csc_host_fn( data_arr, - X.indices.astype(np.int32, copy=False), + indices_arr, X.indptr, group_codes, group_sizes_np, @@ -425,6 +564,34 @@ def _wilcoxon_vs_rest( compute_nnz=compute_nnz, sub_batch_cols=STREAMING_SUB_BATCH, ) + elif host_csr: + group_sizes_np = group_sizes.astype(np.float64, copy=False) + csr = X + if not csr.has_sorted_indices: + csr = csr.copy() + csr.sort_indices() + _csr_host_fn, data_arr, indices_arr = _host_sparse_fn_and_arrays( + _ovr, "ovr_sparse_csr_host", csr, support_idx64=True + ) + _csr_host_fn( + data_arr, + indices_arr, + csr.indptr, + group_codes, + group_sizes_np, + rank_sums, + tie_corr, + group_sums, + group_sq_sums, + 
group_nnz, + n_rows=n_cells, + n_cols=n_total_genes, + n_groups=n_groups, + compute_tie_corr=tie_correct, + compute_sq_sums=compute_vars, + compute_nnz=compute_nnz, + sub_batch_cols=STREAMING_SUB_BATCH, + ) else: is_f64 = X.dtype == np.float64 _dense_host_fn = ( @@ -451,6 +618,13 @@ def _wilcoxon_vs_rest( ) if rg._compute_stats_in_chunks: + total_sums, total_sq_sums, total_nnz = _host_ovr_totals_if_needed( + X, + group_codes, + n_groups, + compute_sq_sums=compute_vars, + compute_nnz=compute_nnz, + ) _fill_basic_stats_from_accumulators( rg, group_sums, @@ -459,9 +633,12 @@ def _wilcoxon_vs_rest( group_sizes.astype(np.float64, copy=False), n_cells=n_cells, compute_vars=compute_vars, + total_sums=total_sums, + total_sq_sums=total_sq_sums, + total_nnz=total_nnz, ) else: - # GPU data or host CSR → transfer to GPU, use GPU kernels + # GPU data → use native GPU kernels. X_gpu = _to_gpu_native(X, n_cells, n_total_genes) if rg._compute_stats_in_chunks: @@ -563,6 +740,7 @@ def _wilcoxon_with_reference( *, tie_correct: bool, use_continuity: bool, + chunk_size: int | None, ) -> list[tuple[int, NDArray, NDArray]]: """Wilcoxon test: each group vs a specific reference group. 
@@ -684,29 +862,8 @@ def _wilcoxon_with_reference( ) if host_sparse and X.format == "csc": - is_f64 = X.data.dtype == np.float64 - is_idx64 = X.indices.dtype == np.int64 - is_i64 = X.indptr.dtype == np.int64 - if is_f64: - if is_idx64 and is_i64: - _csc_host_fn = _wc.ovo_streaming_csc_host_f64_idx64_i64 - elif is_idx64: - _csc_host_fn = _wc.ovo_streaming_csc_host_f64_idx64 - elif is_i64: - _csc_host_fn = _wc.ovo_streaming_csc_host_f64_i64 - else: - _csc_host_fn = _wc.ovo_streaming_csc_host_f64 - elif is_idx64 and is_i64: - _csc_host_fn = _wc.ovo_streaming_csc_host_idx64_i64 - elif is_idx64: - _csc_host_fn = _wc.ovo_streaming_csc_host_idx64 - elif is_i64: - _csc_host_fn = _wc.ovo_streaming_csc_host_i64 - else: - _csc_host_fn = _wc.ovo_streaming_csc_host - data_arr = X.data if is_f64 else X.data.astype(np.float32, copy=False) - indices_arr = ( - X.indices if is_idx64 else X.indices.astype(np.int32, copy=False) + _csc_host_fn, data_arr, indices_arr = _host_sparse_fn_and_arrays( + _wc, "ovo_streaming_csc_host", X, support_idx64=True ) _csc_host_fn( data_arr, @@ -732,37 +889,16 @@ def _wilcoxon_with_reference( compute_nnz=compute_nnz, sub_batch_cols=STREAMING_SUB_BATCH, ) - elif host_sparse: - csr = X.tocsr() if X.format != "csr" else X + elif host_sparse and X.format == "csr": + csr = X if not csr.has_sorted_indices: csr = csr.copy() csr.sort_indices() - is_f64 = csr.data.dtype == np.float64 - is_idx64 = csr.indices.dtype == np.int64 - is_i64 = csr.indptr.dtype == np.int64 # Zero-copy mapped: pin full CSR, upload indptr + row_ids, GPU # kernels gather per-pack rows via UVA reads. 
- if is_f64: - if is_idx64 and is_i64: - _csr_host_fn = _wc.ovo_streaming_csr_host_f64_idx64_i64 - elif is_idx64: - _csr_host_fn = _wc.ovo_streaming_csr_host_f64_idx64 - elif is_i64: - _csr_host_fn = _wc.ovo_streaming_csr_host_f64_i64 - else: - _csr_host_fn = _wc.ovo_streaming_csr_host_f64 - elif is_idx64 and is_i64: - _csr_host_fn = _wc.ovo_streaming_csr_host_idx64_i64 - elif is_idx64: - _csr_host_fn = _wc.ovo_streaming_csr_host_idx64 - elif is_i64: - _csr_host_fn = _wc.ovo_streaming_csr_host_i64 - else: - _csr_host_fn = _wc.ovo_streaming_csr_host - data_arr = csr.data if is_f64 else csr.data.astype(np.float32, copy=False) - indices_arr = ( - csr.indices if is_idx64 else csr.indices.astype(np.int32, copy=False) + _csr_host_fn, data_arr, indices_arr = _host_sparse_fn_and_arrays( + _wc, "ovo_streaming_csr_host", csr, support_idx64=True ) _csr_host_fn( data_arr, @@ -787,6 +923,11 @@ def _wilcoxon_with_reference( compute_nnz=compute_nnz, sub_batch_cols=STREAMING_SUB_BATCH, ) + elif host_sparse: + raise TypeError( + "Wilcoxon sparse input must be CSR or CSC; refusing hidden " + f"full-matrix conversion from {X.format!r}." 
+ ) else: is_f64 = X.dtype == np.float64 _dense_host_fn = ( @@ -869,9 +1010,9 @@ def _wilcoxon_with_reference( compute_tie_corr=tie_correct, sub_batch_cols=STREAMING_SUB_BATCH, ) - elif cpsp.issparse(X_gpu): + elif cpsp.isspmatrix_csr(X_gpu): # CSR-native: extract ref/grp rows directly - csr_gpu = X_gpu.tocsr() if not cpsp.isspmatrix_csr(X_gpu) else X_gpu + csr_gpu = X_gpu if not csr_gpu.has_sorted_indices: csr_gpu = csr_gpu.copy() csr_gpu.sort_indices() @@ -895,33 +1036,50 @@ def _wilcoxon_with_reference( compute_tie_corr=tie_correct, sub_batch_cols=STREAMING_SUB_BATCH, ) - else: - # Dense: extract blocks, sort, stream - ref_block = _extract_dense_block(X_gpu, ref_row_ids_gpu, 0, n_total_genes) - grp_block = _extract_dense_block( - X_gpu, all_grp_row_ids_gpu, 0, n_total_genes - ) - ref_sorted = _segmented_sort_columns( - ref_block, - np.array([0, n_ref], dtype=np.int32), - n_ref, - n_total_genes, - 1, - ) - grp_f32 = cp.asfortranarray(grp_block.astype(cp.float32, copy=False)) - cp.cuda.get_current_stream().synchronize() - _wc.ovo_streaming( - ref_sorted, - grp_f32, - grp_offsets_gpu, - rank_sums, - tie_corr_arr, - n_ref=n_ref, - n_all_grp=n_all_grp, - n_cols=n_total_genes, - n_groups=n_test, - compute_tie_corr=tie_correct, + elif cpsp.issparse(X_gpu): + raise TypeError( + "Wilcoxon sparse input must be CSR or CSC; refusing hidden " + "full-matrix GPU sparse conversion." ) + else: + # Dense device data is already resident, but extracting all genes + # for all reference/test rows can still blow up memory. Preserve + # the public chunk_size escape hatch by materializing bounded + # column blocks and stitching the CUDA outputs together. 
+ dense_chunk = _resolve_chunk_size(chunk_size, n_total_genes) + for start in range(0, n_total_genes, dense_chunk): + stop = min(start + dense_chunk, n_total_genes) + sb_cols = stop - start + ref_block = _extract_dense_block(X_gpu, ref_row_ids_gpu, start, stop) + grp_block = _extract_dense_block( + X_gpu, all_grp_row_ids_gpu, start, stop + ) + ref_sorted = _segmented_sort_columns( + ref_block, + np.array([0, n_ref], dtype=np.int32), + n_ref, + sb_cols, + 1, + ) + grp_f32 = cp.asfortranarray(grp_block.astype(cp.float32, copy=False)) + sub_rank_sums = cp.empty((n_test, sb_cols), dtype=cp.float64) + sub_tie_corr = cp.empty((n_test, sb_cols), dtype=cp.float64) + cp.cuda.get_current_stream().synchronize() + _wc.ovo_streaming( + ref_sorted, + grp_f32, + grp_offsets_gpu, + sub_rank_sums, + sub_tie_corr, + n_ref=n_ref, + n_all_grp=n_all_grp, + n_cols=sb_cols, + n_groups=n_test, + compute_tie_corr=tie_correct, + sub_batch_cols=STREAMING_SUB_BATCH, + ) + rank_sums[:, start:stop] = sub_rank_sums + tie_corr_arr[:, start:stop] = sub_tie_corr # ---- z-scores & p-values (vectorised) ---- n_combined = test_sizes + n_ref diff --git a/tests/test_rank_genes_groups_wilcoxon.py b/tests/test_rank_genes_groups_wilcoxon.py index e194d79f..aba9ac4a 100644 --- a/tests/test_rank_genes_groups_wilcoxon.py +++ b/tests/test_rank_genes_groups_wilcoxon.py @@ -14,6 +14,10 @@ def _to_format(X_dense, fmt): """Convert dense numpy array to the specified format.""" + if fmt == "numpy_dense": + return np.asarray(X_dense) + if fmt == "scipy_csr": + return sp.csr_matrix(X_dense) if fmt == "scipy_csc": return sp.csc_matrix(X_dense) if fmt == "cupy_dense": @@ -162,6 +166,118 @@ def test_rank_genes_groups_wilcoxon_subset_and_bonferroni(reference): assert np.all(adjusted <= 1.0) +@pytest.mark.parametrize( + "fmt", + [ + pytest.param("scipy_csr", id="host_csr"), + pytest.param("scipy_csc", id="host_csc"), + pytest.param("cupy_dense", id="device_dense"), + ], +) +def 
test_wilcoxon_subset_rest_stats_match_scanpy(fmt): + """groups=... with reference='rest' must use all other cells for stats.""" + np.random.seed(42) + adata_gpu = sc.datasets.blobs(n_variables=6, n_centers=4, n_observations=160) + adata_gpu.obs["blobs"] = adata_gpu.obs["blobs"].astype("category") + adata_cpu = adata_gpu.copy() + adata_gpu.X = _to_format(adata_gpu.X, fmt) + + kw = { + "groupby": "blobs", + "method": "wilcoxon", + "use_raw": False, + "groups": ["0", "2"], + "reference": "rest", + "pts": True, + "n_genes": 6, + } + rsc.tl.rank_genes_groups(adata_gpu, **kw) + sc.tl.rank_genes_groups(adata_cpu, **kw) + + gpu_result = adata_gpu.uns["rank_genes_groups"] + cpu_result = adata_cpu.uns["rank_genes_groups"] + for field in ("scores", "logfoldchanges", "pvals", "pvals_adj"): + for group in gpu_result[field].dtype.names: + np.testing.assert_allclose( + np.asarray(gpu_result[field][group], dtype=float), + np.asarray(cpu_result[field][group], dtype=float), + rtol=1e-13, + atol=1e-15, + equal_nan=True, + ) + + for key in ("pts", "pts_rest"): + gpu_pts = gpu_result[key] + cpu_pts = cpu_result[key] + for col in gpu_pts.columns: + np.testing.assert_allclose( + gpu_pts[col].values, cpu_pts[col].values, rtol=1e-13, atol=1e-15 + ) + + +@pytest.mark.parametrize("reference", ["rest", "1"]) +@pytest.mark.parametrize("fmt", ["scipy_csr", "scipy_csc"]) +def test_wilcoxon_zero_nnz_host_sparse_does_not_crash(reference, fmt): + obs = pd.DataFrame( + { + "group": pd.Categorical( + ["0"] * 4 + ["1"] * 4 + ["2"] * 4, + categories=["0", "1", "2"], + ) + } + ) + adata = sc.AnnData( + X=_to_format(np.zeros((12, 5), dtype=np.float32), fmt), + obs=obs, + var=pd.DataFrame(index=[f"g{i}" for i in range(5)]), + ) + + rsc.tl.rank_genes_groups( + adata, + "group", + method="wilcoxon", + use_raw=False, + reference=reference, + pts=True, + ) + + result = adata.uns["rank_genes_groups"] + for field in ("scores", "pvals"): + for group in result[field].dtype.names: + assert 
np.all(np.isfinite(np.asarray(result[field][group], dtype=float))) + + +def test_wilcoxon_dense_ovo_chunk_size_matches_unchunked(): + np.random.seed(42) + base = sc.datasets.blobs(n_variables=9, n_centers=3, n_observations=120) + base.obs["blobs"] = base.obs["blobs"].astype("category") + unchunked = base.copy() + chunked = base.copy() + unchunked.X = cp.asarray(unchunked.X) + chunked.X = cp.asarray(chunked.X) + + kw = { + "groupby": "blobs", + "method": "wilcoxon", + "use_raw": False, + "reference": "1", + "tie_correct": True, + "n_genes": 9, + } + rsc.tl.rank_genes_groups(unchunked, **kw) + rsc.tl.rank_genes_groups(chunked, **kw, chunk_size=2) + + for field in ("scores", "pvals", "pvals_adj", "logfoldchanges"): + for group in unchunked.uns["rank_genes_groups"][field].dtype.names: + np.testing.assert_allclose( + np.asarray(unchunked.uns["rank_genes_groups"][field][group], float), + np.asarray(chunked.uns["rank_genes_groups"][field][group], float), + rtol=1e-13, + atol=1e-15, + equal_nan=True, + ) + + @pytest.mark.parametrize( "reference_before,reference_after", [("rest", "rest"), ("1", "One")],