diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 3f2e4447..c506c972 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -74,11 +74,16 @@ jobs: 'name = "rapids-singlecell"', f'name = "rapids-singlecell-cu{cuda}"', ) - # Rename matching extra to "rapids", remove the other + # Rename matching extra to "rapids", remove the other CUDA extra text = text.replace(f'rapids-cu{cuda} =', 'rapids =') - # Remove the other CUDA extra line entirely - lines = text.splitlines(keepends=True) - text = "".join(l for l in lines if f'rapids-cu{other}' not in l) + # Remove the other CUDA extra (handles multi-line TOML arrays) + import re + text = re.sub( + rf'^rapids-cu{other}\s*=\s*\[.*?\]\s*\n', + '', + text, + flags=re.MULTILINE | re.DOTALL, + ) # Set CUDA architectures (replace "native" with CI target archs) text = text.replace( @@ -112,14 +117,31 @@ jobs: CIBW_ENVIRONMENT_PASS_LINUX: SETUPTOOLS_SCM_PRETEND_VERSION CIBW_ENVIRONMENT: > CUDA_PATH=/usr/local/cuda - LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH PATH=/usr/local/cuda/bin:$PATH CIBW_BEFORE_BUILD: > python -m pip install -U pip scikit-build-core cmake ninja nanobind + librmm-cu${{ matrix.cuda_major }} && + RMM_ROOT=$(python -c "import librmm;print(librmm.__path__[0])") && + LOG_ROOT=$(python -c "import rapids_logger;print(rapids_logger.__path__[0])") && + echo "[rsc-build] librmm=$RMM_ROOT" && + echo "[rsc-build] rapids_logger=$LOG_ROOT" && + ln -sf "$RMM_ROOT/lib64/librmm.so" /usr/local/lib/librmm.so && + ln -sf "$LOG_ROOT/lib64/librapids_logger.so" /usr/local/lib/librapids_logger.so && + ldconfig && + python -c "import librmm;print(librmm.__path__[0])" > /tmp/.librmm_dir && + echo "[rsc-build] marker=$(cat /tmp/.librmm_dir)" CIBW_TEST_SKIP: "*" CIBW_TEST_COMMAND: "" - CIBW_REPAIR_WHEEL_COMMAND: "auditwheel repair --exclude libcublas.so.${{ matrix.cuda_major }} --exclude libcublasLt.so.${{ matrix.cuda_major }} --exclude libcudart.so.${{ matrix.cuda_major }} -w {dest_dir} {wheel}" + CIBW_REPAIR_WHEEL_COMMAND: > + auditwheel repair + --exclude libcublas.so.${{ matrix.cuda_major }} + --exclude libcublasLt.so.${{ matrix.cuda_major }} + --exclude libcudart.so.${{ matrix.cuda_major }} + --exclude librmm.so + --exclude librapids_logger.so + -w {dest_dir} {wheel} + && pipx run abi3audit --strict --report {wheel} CIBW_BUILD_VERBOSITY: "1" - uses: actions/upload-artifact@v4 diff --git a/.gitignore b/.gitignore index c0e83438..2735c080 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ __pycache__/ /data/ test-data/ .vscode/ +.codex # Distribution / packaging /dist/ @@ -50,3 +51,4 @@ CLAUDE.md # tmp_scripts tmp_scripts/ +benchmarks/ diff --git a/CMakeLists.txt b/CMakeLists.txt index cacf9849..8e987b11 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -14,6 +14,79 @@ if (RSC_BUILD_EXTENSIONS) find_package(Python REQUIRED COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT}) find_package(nanobind CONFIG REQUIRED) find_package(CUDAToolkit REQUIRED) + + # Find librmm cmake config. + # Searches all plausible locations: env vars, Python prefix, base_prefix + # (survives build isolation), and CI marker files. Works with conda, pip, + # pixi, uv, hatch, and cibuildwheel. + set(_env_roots "") + # Explicit override + if(DEFINED ENV{LIBRMM_DIR}) + file(GLOB _librmm_hints "$ENV{LIBRMM_DIR}/lib*/cmake" + "$ENV{LIBRMM_DIR}/../rapids_logger/lib*/cmake") + list(APPEND CMAKE_PREFIX_PATH ${_librmm_hints}) + endif() + # Environment managers + foreach(_var CONDA_PREFIX VIRTUAL_ENV PIXI_PROJECT_ROOT) + if(DEFINED ENV{${_var}}) + list(APPEND _env_roots "$ENV{${_var}}") + endif() + endforeach() + # Python prefix, base_prefix, and real executable's env root + foreach(_attr prefix base_prefix) + execute_process( + COMMAND "${Python_EXECUTABLE}" -c "import sys; print(sys.${_attr})" + OUTPUT_VARIABLE _pp OUTPUT_STRIP_TRAILING_WHITESPACE ERROR_QUIET) + if(_pp) + list(APPEND _env_roots "${_pp}") + endif() + endforeach() + # Resolve symlinks to find the real Python env (works through build isolation) + execute_process( + COMMAND "${Python_EXECUTABLE}" -c + "import sys,pathlib; print(pathlib.Path(sys.executable).resolve().parents[1])" + OUTPUT_VARIABLE _real_prefix OUTPUT_STRIP_TRAILING_WHITESPACE ERROR_QUIET) + if(_real_prefix) + list(APPEND _env_roots "${_real_prefix}") + endif() + # Direct site-packages search — the most reliable way to find pip-installed + # librmm regardless of venv nesting depth + execute_process( + COMMAND "${Python_EXECUTABLE}" -c + "import site; print(';'.join(site.getsitepackages()))" + OUTPUT_VARIABLE _site_paths OUTPUT_STRIP_TRAILING_WHITESPACE ERROR_QUIET) + if(_site_paths) + foreach(_sp ${_site_paths}) + file(GLOB _sp_hints "${_sp}/librmm/lib*/cmake" + "${_sp}/rapids_logger/lib*/cmake") + foreach(_h ${_sp_hints}) + get_filename_component(_dir "${_h}" DIRECTORY) + list(APPEND CMAKE_PREFIX_PATH "${_dir}") + endforeach() + endforeach() + endif() + # CI/cibuildwheel: CIBW_BEFORE_BUILD writes the librmm path to a marker file + if(EXISTS "/tmp/.librmm_dir") + file(READ "/tmp/.librmm_dir" _rmm_marker) + string(STRIP "${_rmm_marker}" _rmm_marker) + file(GLOB _marker_hints "${_rmm_marker}/lib*/cmake" + "${_rmm_marker}/../rapids_logger/lib*/cmake") + list(APPEND CMAKE_PREFIX_PATH ${_marker_hints}) + endif() + # Search each root for rmm + rapids_logger cmake configs + list(REMOVE_DUPLICATES _env_roots) + foreach(_root ${_env_roots}) + file(GLOB _hints "${_root}/lib/cmake/rmm" + "${_root}/lib/python*/site-packages/librmm/lib*/cmake/rmm" + "${_root}/lib/python*/site-packages/rapids_logger/lib*/cmake/rapids_logger") + foreach(_h ${_hints}) + get_filename_component(_dir "${_h}" DIRECTORY) + list(APPEND CMAKE_PREFIX_PATH "${_dir}") + endforeach() + endforeach() + message(STATUS "rmm search roots: ${_env_roots}") + message(STATUS "rmm CMAKE_PREFIX_PATH: ${CMAKE_PREFIX_PATH}") + find_package(rmm CONFIG REQUIRED) message(STATUS "Building for CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}") else() message(STATUS "RSC_BUILD_EXTENSIONS=OFF -> skipping compiled extensions for docs") @@ -84,7 +157,8 @@ if (RSC_BUILD_EXTENSIONS) add_nb_cuda_module(_edistance_cuda src/rapids_singlecell/_cuda/edistance/edistance.cu) add_nb_cuda_module(_hvg_cuda src/rapids_singlecell/_cuda/hvg/hvg.cu) add_nb_cuda_module(_kde_cuda src/rapids_singlecell/_cuda/kde/kde.cu) - add_nb_cuda_module(_wilcoxon_cuda src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu) + add_nb_cuda_module(_wilcoxon_ovr_cuda src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr.cu) + add_nb_cuda_module(_wilcoxon_ovo_cuda src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo.cu) # Harmony CUDA modules add_nb_cuda_module(_harmony_scatter_cuda src/rapids_singlecell/_cuda/harmony/scatter/scatter.cu) add_nb_cuda_module(_harmony_outer_cuda src/rapids_singlecell/_cuda/harmony/outer/outer.cu) @@ -100,4 +174,8 @@ if (RSC_BUILD_EXTENSIONS) target_link_libraries(_harmony_correction_batched_cuda PRIVATE CUDA::cublas) # Wilcoxon binned histogram CUDA module add_nb_cuda_module(_wilcoxon_binned_cuda src/rapids_singlecell/_cuda/wilcoxon_binned/wilcoxon_binned.cu) + if(rmm_FOUND) + target_link_libraries(_wilcoxon_ovr_cuda PRIVATE rmm::rmm) + target_link_libraries(_wilcoxon_ovo_cuda PRIVATE rmm::rmm) + endif() endif() diff --git a/notebooks b/notebooks index 4cdaa44f..e5c97b34 160000 --- a/notebooks +++ b/notebooks @@ -1 +1 @@ -Subproject commit 4cdaa44fbd93b6f812fc8d2c72b89180ef92047d +Subproject commit e5c97b34f4acbf919fb3118c987cc5893e5b5fdf diff --git a/pyproject.toml b/pyproject.toml index c38e1d00..3777740e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,6 +3,10 @@ requires = [ "scikit-build-core>=0.10", "nanobind>=2.0.0", "setuptools-scm>=8", + # librmm headers are needed at build time for the Wilcoxon CUDA kernels. + # The headers are identical across cu12/cu13; the runtime .so is loaded + # from the installed RAPIDS package. + "librmm-cu12>=25.10", ] build-backend = "scikit_build_core.build" @@ -32,8 +36,22 @@ dependencies = [ ] [project.optional-dependencies] -rapids-cu13 = [ "cupy-cuda13x", "cudf-cu13>=25.10", "cuml-cu13>=25.10", "cugraph-cu13>=25.10", "cuvs-cu13>=25.10" ] -rapids-cu12 = [ "cupy-cuda12x", "cudf-cu12>=25.10", "cuml-cu12>=25.10", "cugraph-cu12>=25.10", "cuvs-cu12>=25.10" ] +rapids-cu13 = [ + "cupy-cuda13x", + "librmm-cu13>=25.10", + "cudf-cu13>=25.10", + "cuml-cu13>=25.10", + "cugraph-cu13>=25.10", + "cuvs-cu13>=25.10", +] +rapids-cu12 = [ + "cupy-cuda12x", + "librmm-cu12>=25.10", + "cudf-cu12>=25.10", + "cuml-cu12>=25.10", + "cugraph-cu12>=25.10", + "cuvs-cu12>=25.10", +] doc = [ "sphinx>=4.5.0", diff --git a/src/rapids_singlecell/_cuda/__init__.py b/src/rapids_singlecell/_cuda/__init__.py index 35e82a0d..741178d6 100644 --- a/src/rapids_singlecell/_cuda/__init__.py +++ b/src/rapids_singlecell/_cuda/__init__.py @@ -13,6 +13,16 @@ import importlib +# Pre-load librmm.so + deps so the dynamic linker can resolve them when +# our nanobind extensions (which link rmm) are imported. This is the same +# pattern used by cuml, cuvs, and other RAPIDS packages. +try: + import librmm + + librmm.load_library() +except (ImportError, OSError): + pass + __all__ = [ "_aggr_cuda", "_aucell_cuda", @@ -44,7 +54,8 @@ "_sparse2dense_cuda", "_spca_cuda", "_wilcoxon_binned_cuda", - "_wilcoxon_cuda", + "_wilcoxon_ovo_cuda", + "_wilcoxon_ovr_cuda", ] diff --git a/src/rapids_singlecell/_cuda/cooc/kernels_cooc.cuh b/src/rapids_singlecell/_cuda/cooc/kernels_cooc.cuh index 4786647d..d5f9b553 100644 --- a/src/rapids_singlecell/_cuda/cooc/kernels_cooc.cuh +++ b/src/rapids_singlecell/_cuda/cooc/kernels_cooc.cuh @@ -353,7 +353,7 @@ __global__ void occur_count_kernel_csr_catpairs_tiled( // Binary search for threshold bin int lo = 0, hi = l_val; while (lo < hi) { - int mid = (lo + hi) >> 1; + int mid = lo + ((hi - lo) >> 1); if (dist_sq <= thresholds[mid]) hi = mid; else diff --git a/src/rapids_singlecell/_cuda/nb_types.h b/src/rapids_singlecell/_cuda/nb_types.h index 905e1e07..eb343815 100644 --- a/src/rapids_singlecell/_cuda/nb_types.h +++ b/src/rapids_singlecell/_cuda/nb_types.h @@ -42,6 +42,13 @@ using gpu_array = nb::ndarray; template using gpu_array_contig = nb::ndarray; +// Host (NumPy) array aliases +template +using host_array = nb::ndarray>; + +template +using host_array_2d = nb::ndarray>; + // Register bindings for both regular CUDA and managed-memory arrays. // Usage: // template diff --git a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh index c89d913a..0ce922ae 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh @@ -3,142 +3,688 @@ #include /** - * Kernel to compute tie correction factor for Wilcoxon test. - * Formula: tc = 1 - sum(t^3 - t) / (n^3 - n) where t is the count of tied - * values. + * Fused rank-sum kernel: walk sorted data, compute per-group rank sums + * and tie correction without materializing a rank matrix. * - * Each block handles one column. Uses binary search to find tie groups. - * Assumes input is sorted column-wise (F-order). + * Each thread processes a CONTIGUOUS chunk of sorted elements, detecting + * tie groups by adjacent comparison (sequential access, no binary search). + * Cross-boundary ties are resolved via binary search at chunk boundaries. + * + * When use_gmem is false, per-group accumulators live in shared memory + * (fast atomics, limited to ~1500 groups on 48 KB devices). When use_gmem + * is true, accumulators write directly to ``rank_sums`` in global memory, + * supporting an arbitrary number of groups. The caller must pre-zero + * ``rank_sums`` before launching in the gmem path. + * + * Shared memory layout: + * use_gmem=false: (n_groups + 32) doubles (accumulators + warp buf) + * use_gmem=true: 32 doubles (warp buf only) */ -__global__ void tie_correction_kernel(const double* __restrict__ sorted_vals, - double* __restrict__ correction, - const int n_rows, const int n_cols) { - // Each block handles one column +__global__ void rank_sums_from_sorted_kernel( + const float* __restrict__ sorted_vals, + const int* __restrict__ sorted_row_idx, const int* __restrict__ group_codes, + double* __restrict__ rank_sums, double* __restrict__ tie_corr, int n_rows, + int n_cols, int n_groups, bool compute_tie_corr, bool use_gmem) { int col = blockIdx.x; if (col >= n_cols) return; - const double* sv = sorted_vals + (size_t)col * n_rows; + extern __shared__ double smem[]; + + double* grp_sums; + if (use_gmem) { + // Global memory path: write directly to output (must be pre-zeroed) + grp_sums = rank_sums + (size_t)col; // stride: n_cols + } else { + // Shared memory path: per-block accumulators + grp_sums = smem; + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + grp_sums[g] = 0.0; + } + __syncthreads(); + } + + const float* sv = sorted_vals + (size_t)col * n_rows; + const int* si = sorted_row_idx + (size_t)col * n_rows; - double local_sum = 0.0; - int tid = threadIdx.x; + int chunk = (n_rows + blockDim.x - 1) / blockDim.x; + int my_start = threadIdx.x * chunk; + int my_end = my_start + chunk; + if (my_end > n_rows) my_end = n_rows; - // Each thread processes positions where it detects END of a tie group - // Start from index 1, check if sv[i-1] != sv[i] (boundary detected) - // When at boundary, use binary search to find tie group size - for (int i = tid + 1; i <= n_rows; i += blockDim.x) { - // Detect boundary: either at the end, or value changed - bool at_boundary = (i == n_rows) || (sv[i] != sv[i - 1]); + double local_tie_sum = 0.0; - if (at_boundary) { - // Found end of tie group at position i-1 - // Binary search for start of this tie group - double val = sv[i - 1]; - int lo = 0, hi = i - 1; + // Stride for accumulator indexing: 1 for shared mem, n_cols for global mem + int acc_stride = use_gmem ? n_cols : 1; + + int i = my_start; + while (i < my_end) { + double val = sv[i]; + + int tie_local_end = i + 1; + while (tie_local_end < my_end && sv[tie_local_end] == val) + ++tie_local_end; + + int tie_global_start = i; + if (i == my_start && i > 0 && sv[i - 1] == val) { + int lo = 0, hi = i; while (lo < hi) { - int mid = (lo + hi) / 2; - if (sv[mid] < val) { + int mid = lo + (hi - lo) / 2; + if (sv[mid] < val) lo = mid + 1; - } else { + else hi = mid; - } } - int tie_count = i - lo; + tie_global_start = lo; + } - // t^3 - t for this tie group - double t = (double)tie_count; - local_sum += t * t * t - t; + int tie_global_end = tie_local_end; + if (tie_local_end == my_end && tie_local_end < n_rows && + sv[tie_local_end] == val) { + int lo = tie_local_end, hi = n_rows - 1; + while (lo < hi) { + int mid = hi - ((hi - lo) >> 1); + if (sv[mid] > val) + hi = mid - 1; + else + lo = mid; + } + tie_global_end = lo + 1; } - } - // Warp-level reduction using shuffle -#pragma unroll - for (int offset = 16; offset > 0; offset >>= 1) { - local_sum += __shfl_down_sync(0xffffffff, local_sum, offset); - } + int total_tie = tie_global_end - tie_global_start; + double avg_rank = (double)(tie_global_start + tie_global_end + 1) / 2.0; - // Cross-warp reduction using small shared memory - __shared__ double warp_sums[32]; - int lane = tid & 31; - int warp_id = tid >> 5; + for (int j = i; j < tie_local_end; ++j) { + int grp = group_codes[si[j]]; + if (grp < n_groups) { + atomicAdd(&grp_sums[grp * acc_stride], avg_rank); + } + } - if (lane == 0) { - warp_sums[warp_id] = local_sum; + if (compute_tie_corr && tie_global_start >= my_start && total_tie > 1) { + double t = (double)total_tie; + local_tie_sum += t * t * t - t; + } + + i = tie_local_end; } + __syncthreads(); - // Final reduction in first warp - // Note: blockDim.x must be a multiple of 32 for correct warp reduction - if (tid < 32) { - double val = (tid < (blockDim.x >> 5)) ? warp_sums[tid] : 0.0; + // Copy shared memory accumulators to global output (smem path only) + if (!use_gmem) { + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + rank_sums[(size_t)g * n_cols + col] = grp_sums[g]; + } + } + + if (compute_tie_corr) { + // Warp buf sits after accumulator array in shared memory. + // gmem path: warp buf starts at smem[0]. + // smem path: n_groups doubles, then warp buf. + int warp_buf_off = use_gmem ? 0 : n_groups; + double* warp_buf = smem + warp_buf_off; #pragma unroll - for (int offset = 16; offset > 0; offset >>= 1) { - val += __shfl_down_sync(0xffffffff, val, offset); - } - if (tid == 0) { - double n = (double)n_rows; - double denom = n * n * n - n; - if (denom > 0) { - correction[col] = 1.0 - val / denom; - } else { - correction[col] = 1.0; + for (int off = 16; off > 0; off >>= 1) + local_tie_sum += __shfl_down_sync(0xffffffff, local_tie_sum, off); + int lane = threadIdx.x & 31; + int wid = threadIdx.x >> 5; + if (lane == 0) warp_buf[wid] = local_tie_sum; + __syncthreads(); + if (threadIdx.x < 32) { + double val = (threadIdx.x < ((blockDim.x + 31) >> 5)) + ? warp_buf[threadIdx.x] + : 0.0; +#pragma unroll + for (int off = 16; off > 0; off >>= 1) + val += __shfl_down_sync(0xffffffff, val, off); + if (threadIdx.x == 0) { + double n = (double)n_rows; + double denom = n * n * n - n; + tie_corr[col] = (denom > 0.0) ? (1.0 - val / denom) : 1.0; } } } } /** - * Kernel to compute average ranks for each column. - * Uses scipy.stats.rankdata 'average' method: ties get the average of the ranks - * they would span. + * Sparse-aware OVR rank-sum kernel for sorted stored values. + * + * After CUB sort the stored values are in ascending order: + * [negatives..., stored_zeros..., positives...] + * Implicit zeros (n_rows − nnz_stored) are inserted analytically + * between negatives and positives to form the full ranking. * - * Each block handles one column. Assumes input is sorted column-wise (F-order). + * Full sorted array (conceptual): + * [negatives..., ALL_zeros (stored+implicit)..., positives...] + * |<- neg_end ->|<------- total_zero -------->| + * + * Rank offsets: + * negative at stored pos i : full pos = i (no shift) + * positive at stored pos i : full pos = i + n_impl_zero (shift right) + * zeros : avg rank = neg_end + (total_zero+1)/2 + * + * Shared-memory layout (doubles): + * grp_sums[n_groups] rank-sum accumulators + * grp_nz_count[n_groups] nonzero-per-group counters + * warp_buf[32] tie-correction reduction scratch + * + * Grid: (sb_cols,) Block: (tpb,) */ -__global__ void average_rank_kernel(const double* __restrict__ sorted_vals, - const int* __restrict__ sorter, - double* __restrict__ ranks, - const int n_rows, const int n_cols) { - // Each thread block handles one column +__global__ void rank_sums_sparse_ovr_kernel( + const float* __restrict__ sorted_vals, + const int* __restrict__ sorted_row_idx, + const int* __restrict__ col_seg_offsets, + const int* __restrict__ group_codes, const double* __restrict__ group_sizes, + double* __restrict__ rank_sums, double* __restrict__ tie_corr, + double* __restrict__ nz_count_scratch, int n_rows, int sb_cols, + int n_groups, bool compute_tie_corr, bool use_gmem) { int col = blockIdx.x; - if (col >= n_cols) return; + if (col >= sb_cols) return; - // Pointers to this column's data - const double* sv = sorted_vals + (size_t)col * n_rows; - const int* si = sorter + (size_t)col * n_rows; - double* rk = ranks + (size_t)col * n_rows; + int seg_start = col_seg_offsets[col]; + int seg_end = col_seg_offsets[col + 1]; + int nnz_stored = seg_end - seg_start; - // Each thread processes multiple rows - for (int i = threadIdx.x; i < n_rows; i += blockDim.x) { - double val = sv[i]; + const float* sv = sorted_vals + seg_start; + const int* si = sorted_row_idx + seg_start; - // Binary search for tie_start (first element equal to val) - int lo = 0, hi = i; + extern __shared__ double smem[]; + double* grp_sums; + double* grp_nz_count; + // Accumulator stride: 1 for shared mem (dense per-block), sb_cols for + // gmem (row-major layout (n_groups, sb_cols) shared across blocks). + int acc_stride; + + if (use_gmem) { + // Output rank_sums doubles as accumulator (pre-zeroed by caller). + grp_sums = rank_sums + (size_t)col; + grp_nz_count = nz_count_scratch + (size_t)col; + acc_stride = sb_cols; + } else { + grp_sums = smem; + grp_nz_count = smem + n_groups; + acc_stride = 1; + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + grp_sums[g] = 0.0; + grp_nz_count[g] = 0.0; + } + __syncthreads(); + } + + // --- Find zero range: neg_end = first val >= 0, pos_start = first val > 0 + // --- + __shared__ int sh_neg_end; + __shared__ int sh_pos_start; + if (threadIdx.x == 0) { + // Binary search: first index where sv[i] >= 0.0 + int lo = 0, hi = nnz_stored; + while (lo < hi) { + int mid = lo + ((hi - lo) >> 1); + if (sv[mid] < 0.0f) + lo = mid + 1; + else + hi = mid; + } + sh_neg_end = lo; + // Binary search: first index where sv[i] > 0.0 + hi = nnz_stored; while (lo < hi) { - int mid = (lo + hi) / 2; - if (sv[mid] < val) { + int mid = lo + ((hi - lo) >> 1); + if (sv[mid] <= 0.0f) lo = mid + 1; - } else { + else hi = mid; + } + sh_pos_start = lo; + } + __syncthreads(); + + int neg_end = sh_neg_end; + int pos_start = sh_pos_start; + int n_stored_zero = pos_start - neg_end; + int n_implicit_zero = n_rows - nnz_stored; + int total_zero = n_implicit_zero + n_stored_zero; + double zero_avg_rank = + (total_zero > 0) ? (double)neg_end + (total_zero + 1.0) / 2.0 : 0.0; + + // Rank offset for positive stored values: + // full_pos(i) = i + n_implicit_zero for i >= pos_start + // So avg_rank for tie group [a,b) of positives: + // = n_implicit_zero + (a + b + 1) / 2 + int offset_pos = n_implicit_zero; + + // --- Count stored values != 0.0 per group --- + for (int i = threadIdx.x; i < nnz_stored; i += blockDim.x) { + if (i < neg_end || i >= pos_start) { // skip stored zeros + int grp = group_codes[si[i]]; + if (grp < n_groups) { + atomicAdd(&grp_nz_count[grp * acc_stride], 1.0); } } - int tie_start = lo; + } + __syncthreads(); - // Binary search for tie_end (last element equal to val) - lo = i; - hi = n_rows - 1; - while (lo < hi) { - int mid = (lo + hi + 1) / 2; - if (sv[mid] > val) { - hi = mid - 1; - } else { - lo = mid; + // --- Zero-rank contribution per group --- + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + double n_zero_in_g = group_sizes[g] - grp_nz_count[g * acc_stride]; + grp_sums[g * acc_stride] = n_zero_in_g * zero_avg_rank; + } + __syncthreads(); + + // --- Walk ALL stored values, skip stored zeros, compute ranks --- + // Chunk over [0, nnz_stored), skip [neg_end, pos_start). + int chunk = (nnz_stored + blockDim.x - 1) / blockDim.x; + int my_start = threadIdx.x * chunk; + int my_end = my_start + chunk; + if (my_end > nnz_stored) my_end = nnz_stored; + + double local_tie_sum = 0.0; + + int i = my_start; + while (i < my_end) { + // Skip stored zeros + if (i >= neg_end && i < pos_start) { + i = pos_start; + continue; + } + + float val = sv[i]; + + int tie_local_end = i + 1; + while (tie_local_end < my_end && sv[tie_local_end] == val) + ++tie_local_end; + // Don't let local tie range cross into stored-zero region + if (val < 0.0f && tie_local_end > neg_end) tie_local_end = neg_end; + + int tie_global_start = i; + if (i == my_start && i > 0 && sv[i - 1] == val) { + // Binary search for first occurrence + int search_lo = (val < 0.0f) ? 0 : pos_start; + int lo = search_lo, hi = i; + while (lo < hi) { + int mid = lo + ((hi - lo) >> 1); + if (sv[mid] < val) + lo = mid + 1; + else + hi = mid; + } + tie_global_start = lo; + } + // Handle thread resuming at pos_start after skipping zeros + if (i == pos_start && i > 0 && pos_start > neg_end && + val == sv[pos_start] && i != my_start) { + // Already at pos_start boundary, tie_global_start = pos_start + tie_global_start = pos_start; + } + + int tie_global_end = tie_local_end; + if (tie_local_end == my_end && tie_local_end < nnz_stored && + tie_local_end != neg_end && sv[tie_local_end] == val) { + int search_hi = (val < 0.0f) ? (neg_end - 1) : (nnz_stored - 1); + int lo = tie_local_end, hi = search_hi; + while (lo < hi) { + int mid = hi - ((hi - lo) >> 1); + if (sv[mid] > val) + hi = mid - 1; + else + lo = mid; + } + tie_global_end = lo + 1; + } + + int total_tie = tie_global_end - tie_global_start; + + // Rank depends on sign: + // negative (i < neg_end): full pos = stored pos (no shift) + // positive (i >= pos_start): full pos = stored pos + n_implicit_zero + double avg_rank; + if (val < 0.0f) { + avg_rank = (double)(tie_global_start + tie_global_end + 1) / 2.0; + } else { + avg_rank = (double)offset_pos + + (double)(tie_global_start + tie_global_end + 1) / 2.0; + } + + for (int j = i; j < tie_local_end; ++j) { + int grp = group_codes[si[j]]; + if (grp < n_groups) { + atomicAdd(&grp_sums[grp * acc_stride], avg_rank); + } + } + + if (compute_tie_corr && tie_global_start >= my_start && total_tie > 1) { + double t = (double)total_tie; + local_tie_sum += t * t * t - t; + } + + i = tie_local_end; + } + + __syncthreads(); + + // Write rank sums to global output (smem path only — gmem path is direct) + if (!use_gmem) { + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + rank_sums[(size_t)g * sb_cols + col] = grp_sums[g]; + } + } + + // Tie correction: warp + block reduction + if (compute_tie_corr) { + // Zero tie group contribution (one thread only) + if (threadIdx.x == 0 && total_zero > 1) { + double tz = (double)total_zero; + local_tie_sum += tz * tz * tz - tz; + } + + // smem path: warp buf after both accumulator arrays (2 * n_groups). + // gmem path: accumulators are in gmem, warp buf starts at smem[0]. + int warp_buf_off = use_gmem ? 0 : 2 * n_groups; + double* warp_buf = smem + warp_buf_off; + +#pragma unroll + for (int off = 16; off > 0; off >>= 1) + local_tie_sum += __shfl_down_sync(0xffffffff, local_tie_sum, off); + int lane = threadIdx.x & 31; + int wid = threadIdx.x >> 5; + if (lane == 0) warp_buf[wid] = local_tie_sum; + __syncthreads(); + if (threadIdx.x < 32) { + double v = (threadIdx.x < ((blockDim.x + 31) >> 5)) + ? warp_buf[threadIdx.x] + : 0.0; +#pragma unroll + for (int off = 16; off > 0; off >>= 1) + v += __shfl_down_sync(0xffffffff, v, off); + if (threadIdx.x == 0) { + double n = (double)n_rows; + double denom = n * n * n - n; + tie_corr[col] = (denom > 0.0) ? (1.0 - v / denom) : 1.0; } } - int tie_end = lo; + } +} - // Average rank for ties: (start + end + 2) / 2 (1-based ranks) - double avg_rank = (double)(tie_start + tie_end + 2) / 2.0; +/** + * Decide whether the host cast+stats kernels can use per-block shared memory + * accumulators. Large group counts exceed the dynamic smem launch limit, so + * those cases fall back to direct global-memory atomics after zeroing the + * per-stream output buffers. + */ +static int wilcoxon_cast_max_smem_per_block() { + static int cached = -1; + if (cached < 0) { + int device; + cudaGetDevice(&device); + cudaDeviceGetAttribute(&cached, cudaDevAttrMaxSharedMemoryPerBlock, + device); + } + return cached; +} + +static size_t cast_accumulate_smem_config(int n_groups, bool compute_sq_sums, + bool compute_nnz, bool& use_gmem) { + int n_arrays = 1 + (compute_sq_sums ? 1 : 0) + (compute_nnz ? 1 : 0); + size_t need = (size_t)n_arrays * n_groups * sizeof(double); + if (need <= (size_t)wilcoxon_cast_max_smem_per_block()) { + use_gmem = false; + return need; + } + use_gmem = true; + return 0; +} - // Write rank to original position - rk[si[i]] = avg_rank; +/** + * Pre-sort cast-and-accumulate kernel for dense OVR host streaming. + * + * Reads a sub-batch block in its native host dtype (InT = float or double), + * writes a float32 copy used as the sort input, and accumulates per-group + * sum, sum-of-squares and nonzero counts in float64. Stats are derived + * from the original-precision values so float64 host input keeps its + * precision while the sort still runs on float32 keys. + * + * Block-per-column layout (grid: (sb_cols,), block: (tpb,)). + * Shared memory: 3 * n_groups doubles (s_sum, s_sq, s_nnz). + */ +template +__global__ void ovr_cast_and_accumulate_dense_kernel( + const InT* __restrict__ block_in, float* __restrict__ block_f32_out, + const int* __restrict__ group_codes, double* __restrict__ group_sums, + double* __restrict__ group_sq_sums, double* __restrict__ group_nnz, + int n_rows, int sb_cols, int n_groups, bool compute_sq_sums = true, + bool compute_nnz = true) { + int col = blockIdx.x; + if (col >= sb_cols) return; + + extern __shared__ double smem[]; + double* s_sum = smem; + double* s_sq = smem + n_groups; + double* s_nnz = smem + 2 * n_groups; + + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + s_sum[g] = 0.0; + if (compute_sq_sums) s_sq[g] = 0.0; + if (compute_nnz) s_nnz[g] = 0.0; + } + __syncthreads(); + + const InT* src = block_in + (size_t)col * n_rows; + float* dst = block_f32_out + (size_t)col * n_rows; + + for (int r = threadIdx.x; r < n_rows; r += blockDim.x) { + InT v_in = src[r]; + double v = (double)v_in; + dst[r] = (float)v_in; + int g = group_codes[r]; + if (g < n_groups) { + atomicAdd(&s_sum[g], v); + if (compute_sq_sums) atomicAdd(&s_sq[g], v * v); + if (compute_nnz && v != 0.0) atomicAdd(&s_nnz[g], 1.0); + } + } + __syncthreads(); + + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + group_sums[(size_t)g * sb_cols + col] = s_sum[g]; + if (compute_sq_sums) { + group_sq_sums[(size_t)g * sb_cols + col] = s_sq[g]; + } + if (compute_nnz) { + group_nnz[(size_t)g * sb_cols + col] = s_nnz[g]; + } + } +} + +template +__global__ void ovr_cast_and_accumulate_dense_global_kernel( + const InT* __restrict__ block_in, float* __restrict__ block_f32_out, + const int* __restrict__ group_codes, double* __restrict__ group_sums, + double* __restrict__ group_sq_sums, double* __restrict__ group_nnz, + int n_rows, int sb_cols, int n_groups, bool compute_sq_sums = true, + bool compute_nnz = true) { + int col = blockIdx.x; + if (col >= sb_cols) return; + + const InT* src = block_in + (size_t)col * n_rows; + float* dst = block_f32_out + (size_t)col * n_rows; + + for (int r = threadIdx.x; r < n_rows; r += blockDim.x) { + InT v_in = src[r]; + double v = (double)v_in; + dst[r] = (float)v_in; + int g = group_codes[r]; + if (g < n_groups) { + atomicAdd(&group_sums[(size_t)g * sb_cols + col], v); + if (compute_sq_sums) { + atomicAdd(&group_sq_sums[(size_t)g * sb_cols + col], v * v); + } + if (compute_nnz && v != 0.0) { + atomicAdd(&group_nnz[(size_t)g * sb_cols + col], 1.0); + } + } + } +} + +/** + * Pre-sort cast-and-accumulate kernel for sparse OVR host streaming. + * + * Sub-batch CSC data is laid out contiguously: values for column c live + * at positions [col_seg_offsets[c], col_seg_offsets[c+1]). For each + * stored value, read the native-dtype InT, write a float32 copy for the + * CUB sort, and accumulate per-group sum/sum-sq/nnz in float64. Implicit + * zeros contribute nothing to any of these stats. + * + * Block-per-column layout (grid: (sb_cols,), block: (tpb,)). + * Shared memory: 3 * n_groups doubles. + */ +template +__global__ void ovr_cast_and_accumulate_sparse_kernel( + const InT* __restrict__ data_in, float* __restrict__ data_f32_out, + const IndexT* __restrict__ indices, const int* __restrict__ col_seg_offsets, + const int* __restrict__ group_codes, double* __restrict__ group_sums, + double* __restrict__ group_sq_sums, double* __restrict__ group_nnz, + int sb_cols, int n_groups, bool compute_sq_sums = true, + bool compute_nnz = true) { + int col = blockIdx.x; + if (col >= sb_cols) return; + + int seg_start = col_seg_offsets[col]; + int seg_end = col_seg_offsets[col + 1]; + + extern __shared__ double smem[]; + double* s_sum = smem; + double* s_sq = smem + n_groups; + double* s_nnz = smem + 2 * n_groups; + + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + s_sum[g] = 0.0; + if (compute_sq_sums) s_sq[g] = 0.0; + if (compute_nnz) s_nnz[g] = 0.0; + } + __syncthreads(); + + for (int i = seg_start + threadIdx.x; i < seg_end; i += blockDim.x) { + InT v_in = data_in[i]; + double v = (double)v_in; + data_f32_out[i] = (float)v_in; + int row = (int)indices[i]; + int g = group_codes[row]; + if (g < n_groups) { + atomicAdd(&s_sum[g], v); + if (compute_sq_sums) atomicAdd(&s_sq[g], v * v); + if (compute_nnz && v != 0.0) atomicAdd(&s_nnz[g], 1.0); + } + } + __syncthreads(); + + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + group_sums[(size_t)g * sb_cols + col] = s_sum[g]; + if (compute_sq_sums) { + group_sq_sums[(size_t)g * sb_cols + col] = s_sq[g]; + } + if (compute_nnz) { + group_nnz[(size_t)g * sb_cols + col] = s_nnz[g]; + } + } +} + +template +__global__ void ovr_cast_and_accumulate_sparse_global_kernel( + const InT* __restrict__ data_in, float* __restrict__ data_f32_out, + const IndexT* __restrict__ indices, const int* __restrict__ col_seg_offsets, + const int* __restrict__ group_codes, double* __restrict__ group_sums, + double* __restrict__ group_sq_sums, double* __restrict__ group_nnz, + int sb_cols, int n_groups, bool compute_sq_sums = true, + bool compute_nnz = true) { + int col = blockIdx.x; + if (col >= sb_cols) return; + + int seg_start = col_seg_offsets[col]; + int seg_end = col_seg_offsets[col + 1]; + + for (int i = seg_start + threadIdx.x; i < seg_end; i += blockDim.x) { + InT v_in = data_in[i]; + double v = (double)v_in; + data_f32_out[i] = (float)v_in; + int row = (int)indices[i]; + int g = group_codes[row]; + if (g < n_groups) { + atomicAdd(&group_sums[(size_t)g * sb_cols + col], v); + if (compute_sq_sums) { + atomicAdd(&group_sq_sums[(size_t)g * sb_cols + col], v * v); + } + if (compute_nnz && v != 0.0) { + atomicAdd(&group_nnz[(size_t)g * sb_cols + col], 1.0); + } + } + } +} + +template +static void launch_ovr_cast_and_accumulate_dense( + const InT* d_block_orig, float* d_block_f32, const int* d_group_codes, + double* d_group_sums, double* d_group_sq_sums, double* d_group_nnz, + int n_rows, int sb_cols, int n_groups, bool compute_sq_sums, + bool compute_nnz, int tpb, size_t smem_cast, bool use_gmem, + cudaStream_t stream) { + if (use_gmem) { + size_t stats_items = (size_t)n_groups * sb_cols; + cudaMemsetAsync(d_group_sums, 0, stats_items * sizeof(double), stream); + if (compute_sq_sums) { + cudaMemsetAsync(d_group_sq_sums, 0, stats_items * sizeof(double), + stream); + } + if (compute_nnz) { + cudaMemsetAsync(d_group_nnz, 0, stats_items * sizeof(double), + stream); + } + ovr_cast_and_accumulate_dense_global_kernel + <<>>( + d_block_orig, d_block_f32, d_group_codes, d_group_sums, + d_group_sq_sums, d_group_nnz, n_rows, sb_cols, n_groups, + compute_sq_sums, compute_nnz); + CUDA_CHECK_LAST_ERROR(ovr_cast_and_accumulate_dense_global_kernel); + } else { + ovr_cast_and_accumulate_dense_kernel + <<>>( + d_block_orig, d_block_f32, d_group_codes, d_group_sums, + d_group_sq_sums, d_group_nnz, n_rows, sb_cols, n_groups, + compute_sq_sums, compute_nnz); + CUDA_CHECK_LAST_ERROR(ovr_cast_and_accumulate_dense_kernel); + } +} + +template +static void launch_ovr_cast_and_accumulate_sparse( + const InT* d_data_orig, float* d_data_f32, const IndexT* d_indices, + const int* d_col_offsets, const int* d_group_codes, double* d_group_sums, + double* d_group_sq_sums, double* d_group_nnz, int sb_cols, int n_groups, + bool compute_sq_sums, bool compute_nnz, int tpb, size_t smem_cast, + bool use_gmem, cudaStream_t stream) { + if (use_gmem) { + size_t stats_items = (size_t)n_groups * sb_cols; + cudaMemsetAsync(d_group_sums, 0, stats_items * sizeof(double), stream); + if (compute_sq_sums) { + cudaMemsetAsync(d_group_sq_sums, 0, stats_items * sizeof(double), + stream); + } + if (compute_nnz) { + cudaMemsetAsync(d_group_nnz, 0, stats_items * sizeof(double), + stream); + } + ovr_cast_and_accumulate_sparse_global_kernel + <<>>( + d_data_orig, d_data_f32, d_indices, d_col_offsets, + d_group_codes, d_group_sums, d_group_sq_sums, d_group_nnz, + sb_cols, n_groups, compute_sq_sums, compute_nnz); + CUDA_CHECK_LAST_ERROR(ovr_cast_and_accumulate_sparse_global_kernel); + } else { + ovr_cast_and_accumulate_sparse_kernel + <<>>( + d_data_orig, d_data_f32, d_indices, d_col_offsets, + d_group_codes, d_group_sums, d_group_sq_sums, d_group_nnz, + sb_cols, n_groups, compute_sq_sums, compute_nnz); + CUDA_CHECK_LAST_ERROR(ovr_cast_and_accumulate_sparse_kernel); } } diff --git a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh new file mode 100644 index 00000000..4bb38b18 --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh @@ -0,0 +1,791 @@ +#pragma once + +#include + +// ============================================================================ +// Warp reduction helper (sum doubles across block via warp_buf) +// ============================================================================ + +__device__ __forceinline__ double block_reduce_sum(double val, + double* warp_buf) { +#pragma unroll + for (int off = 16; off > 0; off >>= 1) + val += __shfl_down_sync(0xffffffff, val, off); + int lane = threadIdx.x & 31; + int wid = threadIdx.x >> 5; + if (lane == 0) warp_buf[wid] = val; + __syncthreads(); + if (threadIdx.x < 32) { + double v2 = (threadIdx.x < ((blockDim.x + 31) >> 5)) + ? warp_buf[threadIdx.x] + : 0.0; +#pragma unroll + for (int off = 16; off > 0; off >>= 1) + v2 += __shfl_down_sync(0xffffffff, v2, off); + return v2; // only lane 0 of warp 0 has the final result + } + return 0.0; +} + +// ============================================================================ +// Parallel tie correction — all threads collaborate. +// +// For each unique value in the combined sorted (ref, grp) arrays, accumulate +// t^3 - t where t = count of that value. Uses two passes: +// 1. Iterate unique values in ref_col, count in both arrays. +// 2. Iterate unique values in grp_col that do NOT appear in ref_col. +// +// Incremental binary search bounds exploit monotonicity within each thread's +// stride to reduce total search work. +// +// Caller must __syncthreads() before calling. warp_buf is reused for +// reduction (32 doubles, shared memory). +// ============================================================================ + +__device__ __forceinline__ void compute_tie_correction_parallel( + const float* ref_col, int n_ref, const float* grp_col, int n_grp, + double* warp_buf, double* out) { + double local_tie = 0.0; + + // Pass 1: unique values in ref_col + int grp_lb = 0, grp_ub = 0; + for (int i = threadIdx.x; i < n_ref; i += blockDim.x) { + if (i == 0 || ref_col[i] != ref_col[i - 1]) { + float v = ref_col[i]; + + // Count in ref: upper_bound from i+1 + int lo = i + 1, hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int cnt_ref = lo - i; + + // Count in grp: incremental lower/upper bound + lo = grp_lb; + hi = n_grp; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (grp_col[m] < v) + lo = m + 1; + else + hi = m; + } + int lb = lo; + grp_lb = lb; + + lo = (grp_ub > lb) ? grp_ub : lb; + hi = n_grp; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (grp_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int cnt_grp = lo - lb; + grp_ub = lo; + + int cnt = cnt_ref + cnt_grp; + if (cnt > 1) { + double t = (double)cnt; + local_tie += t * t * t - t; + } + } + } + + // Pass 2: unique values in grp_col that are absent from ref_col + int ref_lb = 0; + for (int i = threadIdx.x; i < n_grp; i += blockDim.x) { + if (i == 0 || grp_col[i] != grp_col[i - 1]) { + float v = grp_col[i]; + + // Incremental lower_bound in ref + int lo = ref_lb, hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] < v) + lo = m + 1; + else + hi = m; + } + ref_lb = lo; + + if (lo >= n_ref || ref_col[lo] != v) { + // Value not in ref — count in grp only (upper_bound from i+1) + lo = i + 1; + hi = n_grp; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (grp_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int cnt = lo - i; + if (cnt > 1) { + double t = (double)cnt; + local_tie += t * t * t - t; + } + } + } + } + + // Block-wide reduction + double tie_sum = block_reduce_sum(local_tie, warp_buf); + if (threadIdx.x == 0) { + int n = n_ref + n_grp; + double dn = (double)n; + double denom = dn * dn * dn - dn; + *out = (denom > 0.0) ? (1.0 - tie_sum / denom) : 1.0; + } +} + +// ============================================================================ +// Batched rank sums — pre-sorted (binary search, no shared memory sort) +// Used by the OVO streaming pipeline in wilcoxon_streaming.cu. +// +// Incremental binary search: each thread carries forward lower/upper bound +// positions across loop iterations, exploiting the monotonicity of the +// sorted grp_col values within each thread's stride. +// ============================================================================ + +__global__ void batched_rank_sums_presorted_kernel( + const float* __restrict__ ref_sorted, const float* __restrict__ grp_sorted, + const int* __restrict__ grp_offsets, double* __restrict__ rank_sums, + double* __restrict__ tie_corr, int n_ref, int n_all_grp, int n_cols, + int n_groups, bool compute_tie_corr, int skip_n_grp_le /*= 0*/) { + int col = blockIdx.x; + int grp = blockIdx.y; + if (col >= n_cols || grp >= n_groups) return; + + int g_start = grp_offsets[grp]; + int g_end = grp_offsets[grp + 1]; + int n_grp = g_end - g_start; + + // Size-gated dispatch (see ovo_fused_sort_rank_kernel for the contract). + if (n_grp <= skip_n_grp_le) return; + + if (n_grp == 0) { + if (threadIdx.x == 0) { + rank_sums[grp * n_cols + col] = 0.0; + if (compute_tie_corr) tie_corr[grp * n_cols + col] = 1.0; + } + return; + } + + const float* ref_col = ref_sorted + (long long)col * n_ref; + const float* grp_col = grp_sorted + (long long)col * n_all_grp + g_start; + + // Incremental binary search bounds (advance monotonically per thread) + int ref_lb = 0, ref_ub = 0; + int grp_lb = 0, grp_ub = 0; + double local_sum = 0.0; + + for (int i = threadIdx.x; i < n_grp; i += blockDim.x) { + float v = grp_col[i]; + int lo, hi; + + // Lower bound in ref (from ref_lb) + lo = ref_lb; + hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] < v) + lo = m + 1; + else + hi = m; + } + int n_lt_ref = lo; + ref_lb = n_lt_ref; + + // Upper bound in ref (from max(ref_ub, n_lt_ref)) + lo = (ref_ub > n_lt_ref) ? ref_ub : n_lt_ref; + hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int n_eq_ref = lo - n_lt_ref; + ref_ub = lo; + + // Lower bound in grp (from grp_lb) + lo = grp_lb; + hi = n_grp; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (grp_col[m] < v) + lo = m + 1; + else + hi = m; + } + int n_lt_grp = lo; + grp_lb = n_lt_grp; + + // Upper bound in grp (from max(grp_ub, n_lt_grp)) + lo = (grp_ub > n_lt_grp) ? grp_ub : n_lt_grp; + hi = n_grp; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (grp_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int n_eq_grp = lo - n_lt_grp; + grp_ub = lo; + + local_sum += (double)(n_lt_ref + n_lt_grp) + + ((double)(n_eq_ref + n_eq_grp) + 1.0) / 2.0; + } + + __shared__ double warp_buf[32]; + double total = block_reduce_sum(local_sum, warp_buf); + if (threadIdx.x == 0) rank_sums[grp * n_cols + col] = total; + + if (!compute_tie_corr) return; + __syncthreads(); + + compute_tie_correction_parallel(ref_col, n_ref, grp_col, n_grp, warp_buf, + &tie_corr[grp * n_cols + col]); +} + +// ============================================================================ +// Tier 1 fused kernel: smem bitonic sort + binary search rank sums +// For small groups (< ~2K cells). No CUB, no global memory sort buffers. +// Grid: (n_cols, n_groups), Block: min(padded_grp_size, 512) +// Shared memory: padded_grp_size floats + 32 doubles (warp reduction) +// ============================================================================ + +__global__ void ovo_fused_sort_rank_kernel( + const float* __restrict__ ref_sorted, // F-order (n_ref, n_cols) sorted + const float* __restrict__ grp_dense, // F-order (n_all_grp, n_cols) + // unsorted + const int* __restrict__ grp_offsets, // (n_groups + 1,) + double* __restrict__ rank_sums, // (n_groups, n_cols) row-major + double* __restrict__ tie_corr, // (n_groups, n_cols) row-major + int n_ref, int n_all_grp, int n_cols, int n_groups, bool compute_tie_corr, + int padded_grp_size, int skip_n_grp_le /*= 0*/) { + int col = blockIdx.x; + int grp = blockIdx.y; + if (col >= n_cols || grp >= n_groups) return; + + int g_start = grp_offsets[grp]; + int g_end = grp_offsets[grp + 1]; + int n_grp = g_end - g_start; + + // Size-gated dispatch: when co-launched with the Tier 0 warp kernel we + // skip groups it's already handling. Each group owns its own + // rank_sums row, so the two kernels' writes never alias. + if (n_grp <= skip_n_grp_le) return; + + if (n_grp == 0) { + if (threadIdx.x == 0) { + rank_sums[grp * n_cols + col] = 0.0; + if (compute_tie_corr) tie_corr[grp * n_cols + col] = 1.0; + } + return; + } + + // Shared memory: [padded_grp_size floats | 32 doubles for warp reduction] + extern __shared__ char smem_raw[]; + float* grp_smem = (float*)smem_raw; + double* warp_buf = (double*)(smem_raw + padded_grp_size * sizeof(float)); + + // Load group data into shared memory, pad with +INF + const float* grp_col = grp_dense + (long long)col * n_all_grp + g_start; + for (int i = threadIdx.x; i < n_grp; i += blockDim.x) + grp_smem[i] = grp_col[i]; + for (int i = n_grp + threadIdx.x; i < padded_grp_size; i += blockDim.x) + grp_smem[i] = __int_as_float(0x7f800000); // +INF + __syncthreads(); + + // Bitonic sort in shared memory + for (int k = 2; k <= padded_grp_size; k <<= 1) { + for (int j = k >> 1; j > 0; j >>= 1) { + for (int i = threadIdx.x; i < padded_grp_size; i += blockDim.x) { + int ixj = i ^ j; + if (ixj > i) { + bool asc = ((i & k) == 0); + float a = grp_smem[i], b = grp_smem[ixj]; + if (asc ? (a > b) : (a < b)) { + grp_smem[i] = b; + grp_smem[ixj] = a; + } + } + } + __syncthreads(); + } + } + + // Binary search each sorted grp element against sorted ref + // Incremental bounds: values are monotonic within each thread's stride + const float* ref_col = ref_sorted + (long long)col * n_ref; + int ref_lb = 0, ref_ub = 0; + int grp_lb = 0, grp_ub = 0; + double local_sum = 0.0; + + for (int i = threadIdx.x; i < n_grp; i += blockDim.x) { + float v = grp_smem[i]; + int lo, hi; + + lo = ref_lb; + hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] < v) + lo = m + 1; + else + hi = m; + } + int n_lt_ref = lo; + ref_lb = n_lt_ref; + + lo = (ref_ub > n_lt_ref) ? ref_ub : n_lt_ref; + hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int n_eq_ref = lo - n_lt_ref; + ref_ub = lo; + + lo = grp_lb; + hi = n_grp; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (grp_smem[m] < v) + lo = m + 1; + else + hi = m; + } + int n_lt_grp = lo; + grp_lb = n_lt_grp; + + lo = (grp_ub > n_lt_grp) ? grp_ub : n_lt_grp; + hi = n_grp; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (grp_smem[m] <= v) + lo = m + 1; + else + hi = m; + } + int n_eq_grp = lo - n_lt_grp; + grp_ub = lo; + + local_sum += (double)(n_lt_ref + n_lt_grp) + + ((double)(n_eq_ref + n_eq_grp) + 1.0) / 2.0; + } + + // Block reduction → write rank_sums + double total = block_reduce_sum(local_sum, warp_buf); + if (threadIdx.x == 0) rank_sums[grp * n_cols + col] = total; + + if (!compute_tie_corr) return; + __syncthreads(); + + // Parallel tie correction (grp_smem is sorted shared memory) + compute_tie_correction_parallel(ref_col, n_ref, grp_smem, n_grp, warp_buf, + &tie_corr[grp * n_cols + col]); +} + +// ============================================================================ +// Tier 2 helper: tie contribution of the sorted reference alone. +// One block per column. The medium unsorted-rank kernel uses this as a base +// and only adds group-only/overlap deltas from the unsorted group values. +// ============================================================================ + +__global__ void ref_tie_sum_kernel(const float* __restrict__ ref_sorted, + double* __restrict__ ref_tie_sums, int n_ref, + int n_cols) { + int col = blockIdx.x; + if (col >= n_cols) return; + const float* ref_col = ref_sorted + (long long)col * n_ref; + + double local_tie = 0.0; + for (int i = threadIdx.x; i < n_ref; i += blockDim.x) { + if (i == 0 || ref_col[i] != ref_col[i - 1]) { + float v = ref_col[i]; + int lo = i + 1, hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int cnt = lo - i; + if (cnt > 1) { + double t = (double)cnt; + local_tie += t * t * t - t; + } + } + } + + __shared__ double warp_buf[32]; + double total = block_reduce_sum(local_tie, warp_buf); + if (threadIdx.x == 0) ref_tie_sums[col] = total; +} + +// ============================================================================ +// Tier 2 fused kernel: no-sort direct rank for medium groups. +// +// Avoids the smem bitonic sort for groups in (skip_n_grp_le, +// max_n_grp_le]. Ranks are computed from ref binary searches plus an +// in-group scan over unsorted shared values. Tie correction starts from +// ref_tie_sums[col] and adds only group-only / ref-overlap deltas. +// ============================================================================ + +__global__ void ovo_medium_unsorted_rank_kernel( + const float* __restrict__ ref_sorted, const float* __restrict__ grp_dense, + const int* __restrict__ grp_offsets, + const double* __restrict__ ref_tie_sums, double* __restrict__ rank_sums, + double* __restrict__ tie_corr, int n_ref, int n_all_grp, int n_cols, + int n_groups, bool compute_tie_corr, int skip_n_grp_le, int max_n_grp_le) { + int col = blockIdx.x; + int grp = blockIdx.y; + if (col >= n_cols || grp >= n_groups) return; + + int g_start = grp_offsets[grp]; + int g_end = grp_offsets[grp + 1]; + int n_grp = g_end - g_start; + if (n_grp <= skip_n_grp_le || n_grp > max_n_grp_le) return; + + extern __shared__ char smem_raw[]; + float* grp_smem = (float*)smem_raw; + double* warp_buf = (double*)(smem_raw + max_n_grp_le * sizeof(float)); + + const float* grp_col = grp_dense + (long long)col * n_all_grp + g_start; + for (int i = threadIdx.x; i < n_grp; i += blockDim.x) + grp_smem[i] = grp_col[i]; + __syncthreads(); + + const float* ref_col = ref_sorted + (long long)col * n_ref; + double local_sum = 0.0; + double local_tie_delta = 0.0; + + for (int i = threadIdx.x; i < n_grp; i += blockDim.x) { + float v = grp_smem[i]; + + int lo = 0, hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] < v) + lo = m + 1; + else + hi = m; + } + int n_lt_ref = lo; + + hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int n_eq_ref = lo - n_lt_ref; + + int n_lt_grp = 0; + int n_eq_grp = 0; + bool first_in_grp = true; + for (int j = 0; j < n_grp; ++j) { + float w = grp_smem[j]; + if (w < v) ++n_lt_grp; + if (w == v) { + ++n_eq_grp; + if (j < i) first_in_grp = false; + } + } + + local_sum += (double)(n_lt_ref + n_lt_grp) + + ((double)(n_eq_ref + n_eq_grp) + 1.0) / 2.0; + + if (compute_tie_corr && first_in_grp) { + double cg = (double)n_eq_grp; + double cr = (double)n_eq_ref; + double group_tie = (cg > 1.0) ? (cg * cg * cg - cg) : 0.0; + local_tie_delta += group_tie; + if (cr > 0.0) { + double combined = cr + cg; + double ref_tie = (cr > 1.0) ? (cr * cr * cr - cr) : 0.0; + local_tie_delta += combined * combined * combined - combined - + ref_tie - group_tie; + } + } + } + + double total = block_reduce_sum(local_sum, warp_buf); + if (threadIdx.x == 0) rank_sums[grp * n_cols + col] = total; + + if (!compute_tie_corr) return; + __syncthreads(); + + double tie_delta = block_reduce_sum(local_tie_delta, warp_buf); + if (threadIdx.x == 0) { + int n = n_ref + n_grp; + double dn = (double)n; + double denom = dn * dn * dn - dn; + double tie_sum = ref_tie_sums[col] + tie_delta; + tie_corr[grp * n_cols + col] = + (denom > 0.0) ? (1.0 - tie_sum / denom) : 1.0; + } +} + +// ============================================================================ +// Warp-scoped tie correction for Tier 0. +// +// Sorted values live in a 32-lane register (one per lane, with unused lanes +// carrying +INF). Walks unique values via lane-step differentials and +// counts ties across the sorted ref column via binary search. All the +// sync is __syncwarp — no smem, no __syncthreads. +// ============================================================================ + +__device__ __forceinline__ double tier0_tie_sum_warp(const float* ref_col, + int n_ref, float v_lane, + int n_grp, + unsigned int active_mask) { + int lane = threadIdx.x & 31; + double local_tie = 0.0; + + // Pass 1: for each unique value in ref_col, count occurrences in ref and + // in the sorted group (held in register v_lane across 32 lanes). + for (int base = 0; base < n_ref; base += 32) { + int i = base + lane; + bool in_ref_lane = (i < n_ref); + float v = in_ref_lane ? ref_col[i] : 0.0f; + bool is_first = in_ref_lane && ((i == 0) || (v != ref_col[i - 1])); + int cnt_ref = 0; + if (is_first) { + // Count in ref: upper_bound from i+1 + int lo = i + 1, hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] <= v) + lo = m + 1; + else + hi = m; + } + cnt_ref = lo - i; + } + + // Count in grp: look up how many lanes hold v_lane == v. All lanes + // execute the shuffle loop; only lanes owning a unique ref value use + // the result. + int cnt_grp = 0; +#pragma unroll + for (int lane_i = 0; lane_i < TIER0_GROUP_THRESHOLD; ++lane_i) { + float vi = __shfl_sync(0xffffffff, v_lane, lane_i); + if (is_first && lane_i < n_grp && vi == v) ++cnt_grp; + } + + if (is_first) { + int cnt = cnt_ref + cnt_grp; + if (cnt > 1) { + double t = (double)cnt; + local_tie += t * t * t - t; + } + } + } + + // Pass 2: unique values in grp that are absent from ref. + // Walk lanes 0..n_grp-1; for each lane whose v differs from prev lane's, + // binary-search ref for v. If not present, count consecutive matching + // lanes (tie block). + if (lane < n_grp) { + float v = v_lane; + float prev_lane_v = + __shfl_sync(active_mask, v_lane, (lane > 0) ? lane - 1 : 0); + float v_prev = + (lane > 0) ? prev_lane_v : __int_as_float(0xff800000); // -INF + bool first_in_grp = (lane == 0) || (v != v_prev); + bool in_ref = false; + if (first_in_grp) { + // Binary search in ref. + int lo = 0, hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] < v) + lo = m + 1; + else + hi = m; + } + in_ref = (lo < n_ref) && (ref_col[lo] == v); + } + + // Count how many lanes ≥ this lane hold the same v. Keep the shuffle + // uniform across active lanes even though only unique, ref-absent + // group values consume the count. + int cnt = 0; +#pragma unroll + for (int lane_i = 0; lane_i < TIER0_GROUP_THRESHOLD; ++lane_i) { + int src_lane = (lane_i < n_grp) ? lane_i : 0; + float vi = __shfl_sync(active_mask, v_lane, src_lane); + if (first_in_grp && !in_ref && lane_i >= lane && lane_i < n_grp && + vi == v) { + ++cnt; + } + } + if (first_in_grp && !in_ref && cnt > 1) { + double t = (double)cnt; + local_tie += t * t * t - t; + } + } + + // Warp reduce. +#pragma unroll + for (int off = 16; off > 0; off >>= 1) + local_tie += __shfl_down_sync(0xffffffff, local_tie, off); + return local_tie; // meaningful on lane 0. +} + +// ============================================================================ +// Tier 0 fused kernel: warp-per-(col, group) pair, 8 warps packed per block. +// +// Each warp independently: +// 1. Loads ≤ 32 group values into a single register (one per lane, +// padded with +INF). +// 2. Bitonic-sorts via __shfl_xor_sync — no smem, no __syncthreads. +// 3. Binary-searches into sorted ref for each lane's value and +// accumulates the rank-sum term. +// 4. Warp-shuffle reduces to lane 0 and writes rank_sums / tie_corr. +// +// 8 (col, group) pairs per block cuts block count 8× vs the block-per-pair +// Tier 1, and the lack of __syncthreads / smem sort lets each warp run +// independently at full throughput. +// +// Grid: (n_cols, ceil(n_groups / 8)), Block: 256. +// ============================================================================ + +__global__ void ovo_warp_sort_rank_kernel(const float* __restrict__ ref_sorted, + const float* __restrict__ grp_dense, + const int* __restrict__ grp_offsets, + double* __restrict__ rank_sums, + double* __restrict__ tie_corr, + int n_ref, int n_all_grp, int n_cols, + int n_groups, bool compute_tie_corr) { + constexpr int WARPS_PER_BLOCK = 8; + int warp_id = threadIdx.x >> 5; + int lane = threadIdx.x & 31; + + int col = blockIdx.x; + int grp = blockIdx.y * WARPS_PER_BLOCK + warp_id; + if (col >= n_cols || grp >= n_groups) return; + + int g_start = grp_offsets[grp]; + int g_end = grp_offsets[grp + 1]; + int n_grp = g_end - g_start; + + // This kernel only handles groups that fit in a single warp (one value + // per lane). Larger groups are delegated to Tier 1/3 in a co-launched + // kernel; since each group owns its own row in rank_sums/tie_corr, the + // two kernels interlace into the output without conflict. + if (n_grp > TIER0_GROUP_THRESHOLD) return; + + if (n_grp == 0) { + if (lane == 0) { + rank_sums[grp * n_cols + col] = 0.0; + if (compute_tie_corr) tie_corr[grp * n_cols + col] = 1.0; + } + return; + } + + // One value per lane, pad with +INF so sort pushes them to the end. + const float POS_INF = __int_as_float(0x7f800000); + const float* grp_col = grp_dense + (long long)col * n_all_grp + g_start; + float x = (lane < n_grp) ? grp_col[lane] : POS_INF; + unsigned int active_mask = __ballot_sync(0xffffffff, lane < n_grp); + + // Warp-shuffle bitonic sort (ascending) — 32 elements in registers. + for (int k = 1; k <= 16; k <<= 1) { + for (int j = k; j > 0; j >>= 1) { + float y = __shfl_xor_sync(0xffffffff, x, j); + bool asc = (((lane & (k << 1)) == 0)); + bool take_min = (((lane & j) == 0) == asc); + x = take_min ? fminf(x, y) : fmaxf(x, y); + } + } + + // After sort, x[lane] holds the lane-th smallest group value (lanes + // ≥ n_grp hold +INF). Binary-search each value into the sorted ref. + const float* ref_col = ref_sorted + (long long)col * n_ref; + double local_sum = 0.0; + + if (lane < n_grp) { + float v = x; + // Lower bound in ref. + int lo = 0, hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] < v) + lo = m + 1; + else + hi = m; + } + int n_lt_ref = lo; + // Upper bound in ref. + hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int n_eq_ref = lo - n_lt_ref; + + // In-group counts: in the sorted warp-register x, count lanes < this + // one that hold strictly less, and lanes with equal value. + int n_lt_grp = 0; + int n_eq_grp_offset = 0; // tied lanes strictly before this one + int n_eq_grp_after = 1; // count self +#pragma unroll + for (int lane_i = 0; lane_i < TIER0_GROUP_THRESHOLD; ++lane_i) { + if (lane_i >= n_grp) continue; + float vi = __shfl_sync(active_mask, v, lane_i); + if (lane_i < lane) { + if (vi < v) + ++n_lt_grp; + else if (vi == v) + ++n_eq_grp_offset; + } else if (lane_i > lane) { + if (vi == v) ++n_eq_grp_after; + } + } + int n_eq_grp_total = n_eq_grp_offset + n_eq_grp_after; + // Contribution: rank = n_lt_ref + n_lt_grp + (n_eq_ref + + // n_eq_grp_total + 1) / 2, but we sum per lane so each tie lane + // gets the same mid-rank. This matches the Tier 1 accumulation. + local_sum = (double)(n_lt_ref + n_lt_grp) + + ((double)(n_eq_ref + n_eq_grp_total) + 1.0) / 2.0; + } + + // Warp reduce. +#pragma unroll + for (int off = 16; off > 0; off >>= 1) + local_sum += __shfl_down_sync(0xffffffff, local_sum, off); + if (lane == 0) rank_sums[grp * n_cols + col] = local_sum; + + if (!compute_tie_corr) return; + + // Warp-scoped tie correction. + double tie_sum = tier0_tie_sum_warp(ref_col, n_ref, x, n_grp, active_mask); + if (lane == 0) { + int n = n_ref + n_grp; + double dn = (double)n; + double denom = dn * dn * dn - dn; + tie_corr[grp * n_cols + col] = + (denom > 0.0) ? (1.0 - tie_sum / denom) : 1.0; + } +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu deleted file mode 100644 index d25f7d0f..00000000 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu +++ /dev/null @@ -1,72 +0,0 @@ -#include -#include "../nb_types.h" - -#include "kernels_wilcoxon.cuh" - -using namespace nb::literals; - -// Constants for kernel launch configuration -constexpr int WARP_SIZE = 32; -constexpr int MAX_THREADS_PER_BLOCK = 512; - -static inline int round_up_to_warp(int n) { - int rounded = ((n + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; - return (rounded < MAX_THREADS_PER_BLOCK) ? rounded : MAX_THREADS_PER_BLOCK; -} - -static inline void launch_tie_correction(const double* sorted_vals, - double* correction, int n_rows, - int n_cols, cudaStream_t stream) { - int threads_per_block = round_up_to_warp(n_rows); - dim3 block(threads_per_block); - dim3 grid(n_cols); - tie_correction_kernel<<>>(sorted_vals, correction, - n_rows, n_cols); - CUDA_CHECK_LAST_ERROR(tie_correction_kernel); -} - -static inline void launch_average_rank(const double* sorted_vals, - const int* sorter, double* ranks, - int n_rows, int n_cols, - cudaStream_t stream) { - int threads_per_block = round_up_to_warp(n_rows); - dim3 block(threads_per_block); - dim3 grid(n_cols); - average_rank_kernel<<>>(sorted_vals, sorter, ranks, - n_rows, n_cols); - CUDA_CHECK_LAST_ERROR(average_rank_kernel); -} - -template -void register_bindings(nb::module_& m) { - m.doc() = "CUDA kernels for Wilcoxon rank-sum test"; - - // Tie correction kernel - m.def( - "tie_correction", - [](gpu_array_f sorted_vals, - gpu_array correction, int n_rows, int n_cols, - std::uintptr_t stream) { - launch_tie_correction(sorted_vals.data(), correction.data(), n_rows, - n_cols, (cudaStream_t)stream); - }, - "sorted_vals"_a, "correction"_a, nb::kw_only(), "n_rows"_a, "n_cols"_a, - "stream"_a = 0); - - // Average rank kernel - m.def( - "average_rank", - [](gpu_array_f sorted_vals, - gpu_array_f sorter, - gpu_array_f ranks, int n_rows, int n_cols, - std::uintptr_t stream) { - launch_average_rank(sorted_vals.data(), sorter.data(), ranks.data(), - n_rows, n_cols, (cudaStream_t)stream); - }, - "sorted_vals"_a, "sorter"_a, "ranks"_a, nb::kw_only(), "n_rows"_a, - "n_cols"_a, "stream"_a = 0); -} - -NB_MODULE(_wilcoxon_cuda, m) { - REGISTER_GPU_BINDINGS(register_bindings, m); -} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_common.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_common.cuh new file mode 100644 index 00000000..8ac0f247 --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_common.cuh @@ -0,0 +1,260 @@ +#pragma once + +#include +#include + +#include +#include +#if __has_include() +#include // rmm >= 26.02 +#else +#include // rmm 25.x +#endif + +#include "../nb_types.h" // for CUDA_CHECK_LAST_ERROR + +constexpr int WARP_SIZE = 32; +constexpr int MAX_THREADS_PER_BLOCK = 512; +constexpr int N_STREAMS = 4; +constexpr int SUB_BATCH_COLS = 64; +constexpr int BEGIN_BIT = 0; +constexpr int END_BIT = 32; +// Default thread-per-block for utility kernels (extract, gather, offsets, +// etc.). +constexpr int UTIL_BLOCK_SIZE = 256; +// Scratch slots for warp-level reduction (one slot per warp, 32 warps max). +constexpr int WARP_REDUCE_BUF = 32; +// Max group size for the super-fast "warp-per-(col,group)" fused kernel +// (Tier 0). Each warp sorts and ranks one (col, group) pair entirely in +// registers via warp-shuffle bitonic sort — no smem sort buffer, no +// __syncthreads(). Blocks pack 8 warps so block launch overhead is +// amortised 8× across (col, group) work items. This path is the fast +// route for per-celltype perturbation-style workloads where most test +// groups have only a few dozen cells. +constexpr int TIER0_GROUP_THRESHOLD = 32; +// Medium-group cutoff for the unsorted direct-rank kernel. For perturbation +// workloads most groups sit below this range, where avoiding a full smem +// bitonic sort wins despite the O(n^2) in-group count. +constexpr int TIER2_GROUP_THRESHOLD = 512; +// Max group size for the fused smem-sort rank kernel (Tier 1 fast path). +// Beyond this, fall back to CUB segmented sort + binary-search rank kernel. +constexpr int TIER1_GROUP_THRESHOLD = 2500; +// Per-stream dense slab budget (float32 items). Dynamic sub-batching sizes +// each group's column batch so that (n_g × eff_sb_cols) ≤ this. Bigger = +// fewer kernel launches; smaller = less per-stream memory. 64M items × 4B = +// 256 MB per stream dense slab + same for sorted copy ≈ 512 MB / stream. +constexpr size_t GROUP_DENSE_BUDGET_ITEMS = 64 * 1024 * 1024; + +// --------------------------------------------------------------------------- +// RAII guard for cudaHostRegister. Unregisters on scope exit even when an +// exception unwinds — prevents leaked host pinning on stream-sync failures. +// --------------------------------------------------------------------------- +struct HostRegisterGuard { + void* ptr = nullptr; + + HostRegisterGuard() = default; + HostRegisterGuard(void* p, size_t bytes, unsigned int flags = 0) { + if (p && bytes > 0) { + cudaError_t err = cudaHostRegister(p, bytes, flags); + if (err != cudaSuccess) { + // Already-registered memory is fine; anything else means the + // subsequent kernels would read garbage from an unmapped + // pointer, so surface the error immediately. + if (err == cudaErrorHostMemoryAlreadyRegistered) { + cudaGetLastError(); // clear sticky error flag + } else { + throw std::runtime_error( + std::string("cudaHostRegister failed (") + + std::to_string((size_t)bytes) + + " bytes, flags=" + std::to_string(flags) + + "): " + cudaGetErrorString(err)); + } + } else { + ptr = p; + } + } + } + ~HostRegisterGuard() { + if (ptr) cudaHostUnregister(ptr); + } + HostRegisterGuard(const HostRegisterGuard&) = delete; + HostRegisterGuard& operator=(const HostRegisterGuard&) = delete; + HostRegisterGuard(HostRegisterGuard&& other) noexcept : ptr(other.ptr) { + other.ptr = nullptr; + } + HostRegisterGuard& operator=(HostRegisterGuard&& other) noexcept { + if (this != &other) { + if (ptr) cudaHostUnregister(ptr); + ptr = other.ptr; + other.ptr = nullptr; + } + return *this; + } +}; + +// --------------------------------------------------------------------------- +// RMM pool helper — allocate GPU buffers through the current RMM memory +// resource. Buffers are stored in a vector and freed (RAII) when the vector +// is destroyed. +// --------------------------------------------------------------------------- +struct RmmPool { + std::vector bufs; + rmm::device_async_resource_ref mr; + + RmmPool() : mr(rmm::mr::get_current_device_resource()) { + } + + template + T* alloc(size_t count) { + bufs.emplace_back(count * sizeof(T), rmm::cuda_stream_default, mr); + return static_cast(bufs.back().data()); + } +}; + +static inline int round_up_to_warp(int n) { + int rounded = ((n + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + return (rounded < MAX_THREADS_PER_BLOCK) ? rounded : MAX_THREADS_PER_BLOCK; +} + +/** Fill linear segment offsets [0, stride, 2*stride, ..., n_segments*stride] + * on-device. One thread per output slot. */ +__global__ void fill_linear_offsets_kernel(int* __restrict__ out, + int n_segments, int stride) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i <= n_segments) out[i] = i * stride; +} + +/** Fill per-row stats codes for a pack of K groups. + * Given pack_grp_offsets (size K+1, relative to pack start), write + * stats_codes[r] = base_slot + group_idx_of_row_r for r in [0, pack_n_rows). + * Binary search within the K+1 offsets. */ +__global__ void fill_pack_stats_codes_kernel( + const int* __restrict__ pack_grp_offsets, int* __restrict__ stats_codes, + int K, int base_slot) { + int r = blockIdx.x * blockDim.x + threadIdx.x; + int pack_n_rows = pack_grp_offsets[K]; + if (r >= pack_n_rows) return; + int lo = 0, hi = K; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (pack_grp_offsets[m + 1] <= r) + lo = m + 1; + else + hi = m; + } + stats_codes[r] = base_slot + lo; +} + +/** Rebase a slice of indptr: out[i] = indptr[col + i] - indptr[col]. + * Grid-strided: supports arbitrary `count` (no single-block thread limit). + * Templated so that 64-bit global indptrs can produce 32-bit pack-local + * indptrs (per-pack nnz always fits in int32 thanks to the memory budget). + */ +template +__global__ void rebase_indptr_kernel(const IdxIn* __restrict__ indptr, + IdxOut* __restrict__ out, int col, + int count) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < count) out[i] = (IdxOut)(indptr[col + i] - indptr[col]); +} + +/** Fused gather + cast-to-float32 + stats accumulation, reading from mapped + * pinned host memory. Block-per-row; threads in the block cooperate on the + * row's nnz. Each nnz is read from host over PCIe exactly once — no + * intermediate native-dtype GPU buffer, no second GPU pass. + * + * h_data / h_indices: device-accessible pointers into mapped pinned host + * memory (cudaHostRegisterMapped). + * d_indptr_full: full-matrix indptr on device. + * d_row_ids: rows to gather (size n_target_rows). + * d_out_indptr: pre-computed compacted indptr, size n_target_rows+1 with + * out_indptr[i+1] - out_indptr[i] equal to the source row's + * nnz. + * + * Slot dispatch: + * d_stats_codes != nullptr → slot = d_stats_codes[r]; otherwise slot = + * fixed_slot (used for the Ref phase where every row maps to the same + * slot). slot ∉ [0, n_groups_stats) skips accumulation. + */ +template +__global__ void csr_gather_cast_accumulate_mapped_kernel( + const InT* __restrict__ h_data, const IndexT* __restrict__ h_indices, + const IndptrT* __restrict__ d_indptr_full, + const int* __restrict__ d_row_ids, const int* __restrict__ d_out_indptr, + const int* __restrict__ d_stats_codes, int fixed_slot, + float* __restrict__ d_out_data_f32, int* __restrict__ d_out_indices, + double* __restrict__ group_sums, double* __restrict__ group_sq_sums, + double* __restrict__ group_nnz, int n_target_rows, int n_cols, + int n_groups_stats, bool compute_sq_sums, bool compute_nnz) { + int r = blockIdx.x; + if (r >= n_target_rows) return; + int src_row = d_row_ids[r]; + IndptrT rs = d_indptr_full[src_row]; + IndptrT re = d_indptr_full[src_row + 1]; + int row_nnz = (int)(re - rs); + int ds = d_out_indptr[r]; + int slot = (d_stats_codes != nullptr) ? d_stats_codes[r] : fixed_slot; + bool accumulate = (slot >= 0 && slot < n_groups_stats); + for (int i = threadIdx.x; i < row_nnz; i += blockDim.x) { + InT v_in = h_data[rs + i]; + int c = (int)h_indices[rs + i]; + double v = (double)v_in; + d_out_data_f32[ds + i] = (float)v_in; + d_out_indices[ds + i] = c; + if (accumulate) { + atomicAdd(&group_sums[(size_t)slot * n_cols + c], v); + if (compute_sq_sums) { + atomicAdd(&group_sq_sums[(size_t)slot * n_cols + c], v * v); + } + if (compute_nnz && v != 0.0) { + atomicAdd(&group_nnz[(size_t)slot * n_cols + c], 1.0); + } + } + } +} + +/** Fill linear segment offsets [0, stride, 2*stride, ...] on device. + * Runs on the supplied stream so it doesn't serialize multi-stream pipelines. + */ +static inline void upload_linear_offsets(int* d_offsets, int n_segments, + int stride, cudaStream_t stream) { + int count = n_segments + 1; + int blk = (count + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + fill_linear_offsets_kernel<<>>( + d_offsets, n_segments, stride); + CUDA_CHECK_LAST_ERROR(fill_linear_offsets_kernel); +} + +// ============================================================================ +// CSR → dense F-order extraction (templated on data type) +// ============================================================================ + +template +__global__ void csr_extract_dense_kernel(const T* __restrict__ data, + const int* __restrict__ indices, + const int* __restrict__ indptr, + const int* __restrict__ row_ids, + T* __restrict__ out, int n_target, + int col_start, int col_stop) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= n_target) return; + + int row = row_ids[tid]; + int rs = indptr[row]; + int re = indptr[row + 1]; + + int lo = rs, hi = re; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (indices[m] < col_start) + lo = m + 1; + else + hi = m; + } + + for (int p = lo; p < re; ++p) { + int c = indices[p]; + if (c >= col_stop) break; + out[(long long)(c - col_start) * n_target + tid] = data[p]; + } +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo.cu new file mode 100644 index 00000000..d2bb63be --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo.cu @@ -0,0 +1,22 @@ +#include +#include + +#include + +#include "../nb_types.h" +#include "wilcoxon_common.cuh" +#include "kernels_wilcoxon.cuh" +#include "kernels_wilcoxon_ovo.cuh" + +using namespace nb::literals; + +#include "wilcoxon_ovo_kernels.cuh" +#include "wilcoxon_ovo_device_dense.cuh" +#include "wilcoxon_ovo_device_sparse.cuh" +#include "wilcoxon_ovo_host_sparse.cuh" +#include "wilcoxon_ovo_host_dense.cuh" +#include "wilcoxon_ovo_bindings.cuh" + +NB_MODULE(_wilcoxon_ovo_cuda, m) { + REGISTER_GPU_BINDINGS(register_bindings, m); +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_bindings.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_bindings.cuh new file mode 100644 index 00000000..a96e2a4e --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_bindings.cuh @@ -0,0 +1,264 @@ +#pragma once + +template +void register_bindings(nb::module_& m) { + m.doc() = "CUDA kernels for Wilcoxon rank-sum test (OVO)"; + + // ---- Utility bindings (CUB sort, CSR extraction) ---- + + m.def("get_seg_sort_temp_bytes", &get_seg_sort_temp_bytes, "n_items"_a, + "n_segments"_a); + + m.def( + "segmented_sort", + [](gpu_array_c keys_in, + gpu_array_c keys_out, + gpu_array_c offsets, + gpu_array_c cub_temp, int n_items, int n_segments, + std::uintptr_t stream) { + size_t temp_bytes = cub_temp.size(); + cub::DeviceSegmentedRadixSort::SortKeys( + cub_temp.data(), temp_bytes, keys_in.data(), keys_out.data(), + n_items, n_segments, offsets.data(), offsets.data() + 1, 0, 32, + (cudaStream_t)stream); + CUDA_CHECK_LAST_ERROR(DeviceSegmentedRadixSort); + }, + "keys_in"_a, "keys_out"_a, "offsets"_a, "cub_temp"_a, nb::kw_only(), + "n_items"_a, "n_segments"_a, "stream"_a = 0); + + m.def( + "csr_extract_dense", + [](gpu_array_c data, + gpu_array_c indices, + gpu_array_c indptr, + gpu_array_c row_ids, + gpu_array_f out, int n_target, int col_start, + int col_stop, std::uintptr_t stream) { + int tpb = round_up_to_warp(n_target); + int blocks = (n_target + tpb - 1) / tpb; + csr_extract_dense_kernel<<>>( + data.data(), indices.data(), indptr.data(), row_ids.data(), + out.data(), n_target, col_start, col_stop); + CUDA_CHECK_LAST_ERROR(csr_extract_dense_kernel); + }, + "data"_a, "indices"_a, "indptr"_a, "row_ids"_a, "out"_a, nb::kw_only(), + "n_target"_a, "col_start"_a, "col_stop"_a, "stream"_a = 0); + + m.def( + "csr_extract_dense_f32", + [](gpu_array_c data, + gpu_array_c indices, + gpu_array_c indptr, + gpu_array_c row_ids, + gpu_array_f out, int n_target, int col_start, + int col_stop, std::uintptr_t stream) { + int tpb = round_up_to_warp(n_target); + int blocks = (n_target + tpb - 1) / tpb; + csr_extract_dense_kernel<<>>( + data.data(), indices.data(), indptr.data(), row_ids.data(), + out.data(), n_target, col_start, col_stop); + CUDA_CHECK_LAST_ERROR(csr_extract_dense_kernel); + }, + "data"_a, "indices"_a, "indptr"_a, "row_ids"_a, "out"_a, nb::kw_only(), + "n_target"_a, "col_start"_a, "col_stop"_a, "stream"_a = 0); + + // ---- Streaming pipelines ---- + + m.def( + "ovo_streaming_csr", + [](gpu_array_c csr_data, + gpu_array_c csr_indices, + gpu_array_c csr_indptr, + gpu_array_c ref_row_ids, + gpu_array_c grp_row_ids, + gpu_array_c grp_offsets, + gpu_array_c rank_sums, + gpu_array_c tie_corr, int n_ref, int n_all_grp, + int n_cols, int n_groups, bool compute_tie_corr, + int sub_batch_cols) { + ovo_streaming_csr_impl( + csr_data.data(), csr_indices.data(), csr_indptr.data(), + ref_row_ids.data(), grp_row_ids.data(), grp_offsets.data(), + rank_sums.data(), tie_corr.data(), n_ref, n_all_grp, n_cols, + n_groups, compute_tie_corr, sub_batch_cols); + }, + "csr_data"_a, "csr_indices"_a, "csr_indptr"_a, "ref_row_ids"_a, + "grp_row_ids"_a, "grp_offsets"_a, "rank_sums"_a, "tie_corr"_a, + nb::kw_only(), "n_ref"_a, "n_all_grp"_a, "n_cols"_a, "n_groups"_a, + "compute_tie_corr"_a, "sub_batch_cols"_a = SUB_BATCH_COLS); + + m.def( + "ovo_streaming_csc", + [](gpu_array_c csc_data, + gpu_array_c csc_indices, + gpu_array_c csc_indptr, + gpu_array_c ref_row_map, + gpu_array_c grp_row_map, + gpu_array_c grp_offsets, + gpu_array_c rank_sums, + gpu_array_c tie_corr, int n_ref, int n_all_grp, + int n_cols, int n_groups, bool compute_tie_corr, + int sub_batch_cols) { + ovo_streaming_csc_impl( + csc_data.data(), csc_indices.data(), csc_indptr.data(), + ref_row_map.data(), grp_row_map.data(), grp_offsets.data(), + rank_sums.data(), tie_corr.data(), n_ref, n_all_grp, n_cols, + n_groups, compute_tie_corr, sub_batch_cols); + }, + "csc_data"_a, "csc_indices"_a, "csc_indptr"_a, "ref_row_map"_a, + "grp_row_map"_a, "grp_offsets"_a, "rank_sums"_a, "tie_corr"_a, + nb::kw_only(), "n_ref"_a, "n_all_grp"_a, "n_cols"_a, "n_groups"_a, + "compute_tie_corr"_a, "sub_batch_cols"_a = SUB_BATCH_COLS); + + m.def( + "ovo_streaming", + [](gpu_array_f ref_sorted, + gpu_array_f grp_data, + gpu_array_c grp_offsets, + gpu_array_c rank_sums, + gpu_array_c tie_corr, int n_ref, int n_all_grp, + int n_cols, int n_groups, bool compute_tie_corr, + int sub_batch_cols) { + ovo_streaming_impl(ref_sorted.data(), grp_data.data(), + grp_offsets.data(), rank_sums.data(), + tie_corr.data(), n_ref, n_all_grp, n_cols, + n_groups, compute_tie_corr, sub_batch_cols); + }, + "ref_sorted"_a, "grp_data"_a, "grp_offsets"_a, "rank_sums"_a, + "tie_corr"_a, nb::kw_only(), "n_ref"_a, "n_all_grp"_a, "n_cols"_a, + "n_groups"_a, "compute_tie_corr"_a, + "sub_batch_cols"_a = SUB_BATCH_COLS); + + // ---- Host-streaming pipelines (host inputs, device outputs) ---- + +#define RSC_OVO_CSC_HOST_BINDING(NAME, InT, IndexT, IndptrT) \ + m.def( \ + NAME, \ + [](host_array h_data, host_array h_indices, \ + host_array h_indptr, \ + host_array h_ref_row_map, \ + host_array h_grp_row_map, \ + host_array h_grp_offsets, \ + host_array h_stats_codes, \ + gpu_array_c d_rank_sums, \ + gpu_array_c d_tie_corr, \ + gpu_array_c d_group_sums, \ + gpu_array_c d_group_sq_sums, \ + gpu_array_c d_group_nnz, int n_ref, int n_all_grp, \ + int n_rows, int n_cols, int n_groups, int n_groups_stats, \ + bool compute_tie_corr, bool compute_sq_sums, bool compute_nnz, \ + int sub_batch_cols) { \ + ovo_streaming_csc_host_impl( \ + h_data.data(), h_indices.data(), h_indptr.data(), \ + h_ref_row_map.data(), h_grp_row_map.data(), \ + h_grp_offsets.data(), h_stats_codes.data(), \ + d_rank_sums.data(), d_tie_corr.data(), d_group_sums.data(), \ + d_group_sq_sums.data(), d_group_nnz.data(), n_ref, n_all_grp, \ + n_rows, n_cols, n_groups, n_groups_stats, compute_tie_corr, \ + compute_sq_sums, compute_nnz, sub_batch_cols); \ + }, \ + "h_data"_a, "h_indices"_a, "h_indptr"_a, "h_ref_row_map"_a, \ + "h_grp_row_map"_a, "h_grp_offsets"_a, "h_stats_codes"_a, \ + "d_rank_sums"_a, "d_tie_corr"_a, "d_group_sums"_a, \ + "d_group_sq_sums"_a, "d_group_nnz"_a, nb::kw_only(), "n_ref"_a, \ + "n_all_grp"_a, "n_rows"_a, "n_cols"_a, "n_groups"_a, \ + "n_groups_stats"_a, "compute_tie_corr"_a, "compute_sq_sums"_a = true, \ + "compute_nnz"_a = true, "sub_batch_cols"_a = SUB_BATCH_COLS) + + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host", float, int, int); + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_i64", float, int, int64_t); + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_idx64", float, int64_t, + int); + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_idx64_i64", float, int64_t, + int64_t); + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_f64", double, int, int); + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_f64_i64", double, int, + int64_t); + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_f64_idx64", double, + int64_t, int); + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_f64_idx64_i64", double, + int64_t, int64_t); +#undef RSC_OVO_CSC_HOST_BINDING + +#define RSC_OVO_CSR_HOST_BINDING(NAME, InT, IndexT, IndptrT) \ + m.def( \ + NAME, \ + [](host_array h_data, host_array h_indices, \ + host_array h_indptr, \ + host_array h_ref_row_ids, \ + host_array h_grp_row_ids, \ + host_array h_grp_offsets, \ + gpu_array_c d_rank_sums, \ + gpu_array_c d_tie_corr, \ + gpu_array_c d_group_sums, \ + gpu_array_c d_group_sq_sums, \ + gpu_array_c d_group_nnz, int n_full_rows, \ + int n_ref, int n_all_grp, int n_cols, int n_test, \ + int n_groups_stats, bool compute_tie_corr, bool compute_sq_sums, \ + bool compute_nnz, int sub_batch_cols) { \ + ovo_streaming_csr_host_impl( \ + h_data.data(), h_indices.data(), h_indptr.data(), n_full_rows, \ + h_ref_row_ids.data(), n_ref, h_grp_row_ids.data(), \ + h_grp_offsets.data(), n_all_grp, n_test, d_rank_sums.data(), \ + d_tie_corr.data(), d_group_sums.data(), \ + d_group_sq_sums.data(), d_group_nnz.data(), n_cols, \ + n_groups_stats, compute_tie_corr, compute_sq_sums, \ + compute_nnz, sub_batch_cols); \ + }, \ + "h_data"_a, "h_indices"_a, "h_indptr"_a, "h_ref_row_ids"_a, \ + "h_grp_row_ids"_a, "h_grp_offsets"_a, "d_rank_sums"_a, "d_tie_corr"_a, \ + "d_group_sums"_a, "d_group_sq_sums"_a, "d_group_nnz"_a, nb::kw_only(), \ + "n_full_rows"_a, "n_ref"_a, "n_all_grp"_a, "n_cols"_a, "n_test"_a, \ + "n_groups_stats"_a, "compute_tie_corr"_a, "compute_sq_sums"_a = true, \ + "compute_nnz"_a = true, "sub_batch_cols"_a = SUB_BATCH_COLS) + + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host", float, int, int); + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_i64", float, int, int64_t); + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_idx64", float, int64_t, + int); + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_idx64_i64", float, int64_t, + int64_t); + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_f64", double, int, int); + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_f64_i64", double, int, + int64_t); + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_f64_idx64", double, + int64_t, int); + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_f64_idx64_i64", double, + int64_t, int64_t); +#undef RSC_OVO_CSR_HOST_BINDING + +#define RSC_OVO_DENSE_HOST_BINDING(NAME, InT) \ + m.def( \ + NAME, \ + [](host_array_2d h_block, \ + host_array h_ref_row_ids, \ + host_array h_grp_row_ids, \ + host_array h_grp_offsets, \ + host_array h_stats_codes, \ + gpu_array_c d_rank_sums, \ + gpu_array_c d_tie_corr, \ + gpu_array_c d_group_sums, \ + gpu_array_c d_group_sq_sums, \ + gpu_array_c d_group_nnz, int n_ref, int n_all_grp, \ + int n_rows, int n_cols, int n_groups, int n_groups_stats, \ + bool compute_tie_corr, bool compute_sq_sums, bool compute_nnz, \ + int sub_batch_cols) { \ + ovo_streaming_dense_host_impl( \ + h_block.data(), h_ref_row_ids.data(), h_grp_row_ids.data(), \ + h_grp_offsets.data(), h_stats_codes.data(), \ + d_rank_sums.data(), d_tie_corr.data(), d_group_sums.data(), \ + d_group_sq_sums.data(), d_group_nnz.data(), n_ref, n_all_grp, \ + n_rows, n_cols, n_groups, n_groups_stats, compute_tie_corr, \ + compute_sq_sums, compute_nnz, sub_batch_cols); \ + }, \ + "h_block"_a, "h_ref_row_ids"_a, "h_grp_row_ids"_a, "h_grp_offsets"_a, \ + "h_stats_codes"_a, "d_rank_sums"_a, "d_tie_corr"_a, "d_group_sums"_a, \ + "d_group_sq_sums"_a, "d_group_nnz"_a, nb::kw_only(), "n_ref"_a, \ + "n_all_grp"_a, "n_rows"_a, "n_cols"_a, "n_groups"_a, \ + "n_groups_stats"_a, "compute_tie_corr"_a, "compute_sq_sums"_a = true, \ + "compute_nnz"_a = true, "sub_batch_cols"_a = SUB_BATCH_COLS) + + RSC_OVO_DENSE_HOST_BINDING("ovo_streaming_dense_host", float); + RSC_OVO_DENSE_HOST_BINDING("ovo_streaming_dense_host_f64", double); +#undef RSC_OVO_DENSE_HOST_BINDING +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_dense.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_dense.cuh new file mode 100644 index 00000000..261f95b4 --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_dense.cuh @@ -0,0 +1,190 @@ +#pragma once + +/** + * Streaming OVO pipeline. + * + * Takes pre-sorted reference (float32 F-order), unsorted group data (float32 + * F-order with group offsets), and produces rank_sums + tie_corr. + * + * For each sub-batch of columns: + * 1. CUB segmented sort-keys of group data (one segment per group per col) + * 2. batched_rank_sums_presorted_kernel (binary search in sorted ref) + */ +static void ovo_streaming_impl(const float* ref_sorted, const float* grp_data, + const int* grp_offsets, double* rank_sums, + double* tie_corr, int n_ref, int n_all_grp, + int n_cols, int n_groups, bool compute_tie_corr, + int sub_batch_cols) { + if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; + + // ---- Tier dispatch ---- + std::vector h_off(n_groups + 1); + cudaMemcpy(h_off.data(), grp_offsets, (n_groups + 1) * sizeof(int), + cudaMemcpyDeviceToHost); + auto t1 = make_tier1_config(h_off.data(), n_groups); + int max_grp_size = t1.max_grp_size; + bool use_tier1 = t1.any_above_t2 && t1.use_tier1; + bool needs_tier3 = t1.any_above_t2 && !use_tier1; + std::vector h_sort_group_ids; + int n_sort_groups = n_groups; + if (needs_tier3) { + h_sort_group_ids = + make_sort_group_ids(h_off.data(), n_groups, TIER2_GROUP_THRESHOLD); + n_sort_groups = (int)h_sort_group_ids.size(); + } + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) + n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + + size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols; + size_t cub_temp_bytes = 0; + if (needs_tier3) { + int max_n_seg = n_sort_groups * sub_batch_cols; + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, cub_temp_bytes, fk, fk, (int)sub_grp_items, max_n_seg, + doff, doff + 1, BEGIN_BIT, END_BIT); + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + // Allocate per-stream buffers via RMM pool + RmmPool pool; + int* d_sort_group_ids = nullptr; + if (needs_tier3) { + d_sort_group_ids = pool.alloc(h_sort_group_ids.size()); + cudaMemcpy(d_sort_group_ids, h_sort_group_ids.data(), + h_sort_group_ids.size() * sizeof(int), + cudaMemcpyHostToDevice); + } + struct StreamBuf { + float* grp_sorted; + int* seg_offsets; + int* seg_ends; + uint8_t* cub_temp; + double* ref_tie_sums; + double* sub_rank_sums; + double* sub_tie_corr; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + if (needs_tier3) { + bufs[s].grp_sorted = pool.alloc(sub_grp_items); + int max_n_seg = n_sort_groups * sub_batch_cols; + bufs[s].seg_offsets = pool.alloc(max_n_seg); + bufs[s].seg_ends = pool.alloc(max_n_seg); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + } else { + bufs[s].grp_sorted = nullptr; + bufs[s].seg_offsets = nullptr; + bufs[s].seg_ends = nullptr; + bufs[s].cub_temp = nullptr; + } + bufs[s].ref_tie_sums = (t1.any_tier2 && compute_tie_corr) + ? pool.alloc(sub_batch_cols) + : nullptr; + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = + pool.alloc((size_t)n_groups * sub_batch_cols); + } + + int tpb_rank = + round_up_to_warp(std::min(max_grp_size, MAX_THREADS_PER_BLOCK)); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int sb_grp_items = n_all_grp * sb_cols; + int s = batch_idx % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + + const float* grp_in = grp_data + (long long)col * n_all_grp; + const float* ref_sub = ref_sorted + (long long)col * n_ref; + + int skip_le = 0; + if (t1.use_tier0) { + launch_tier0(ref_sub, grp_in, grp_offsets, buf.sub_rank_sums, + buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, + compute_tie_corr, stream); + if (t1.any_above_t0) skip_le = TIER0_GROUP_THRESHOLD; + } + if (t1.any_tier2) { + if (compute_tie_corr) { + launch_ref_tie_sums(ref_sub, buf.ref_tie_sums, n_ref, sb_cols, + stream); + } + launch_tier2_medium(ref_sub, grp_in, grp_offsets, buf.ref_tie_sums, + buf.sub_rank_sums, buf.sub_tie_corr, n_ref, + n_all_grp, sb_cols, n_groups, compute_tie_corr, + TIER0_GROUP_THRESHOLD, stream); + } + + int upper_skip_le = t1.any_above_t2 ? TIER2_GROUP_THRESHOLD : skip_le; + if (t1.any_above_t2 && use_tier1) { + dim3 grid(sb_cols, n_groups); + ovo_fused_sort_rank_kernel<<>>( + ref_sub, grp_in, grp_offsets, buf.sub_rank_sums, + buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, + compute_tie_corr, t1.padded_grp_size, upper_skip_le); + CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); + } else if (needs_tier3) { + int sb_n_seg = n_sort_groups * sb_cols; + { + int blk = (sb_n_seg + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + build_tier3_seg_begin_end_offsets_kernel<<>>( + grp_offsets, d_sort_group_ids, buf.seg_offsets, + buf.seg_ends, n_all_grp, n_sort_groups, sb_cols); + CUDA_CHECK_LAST_ERROR(build_tier3_seg_begin_end_offsets_kernel); + } + { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.cub_temp, temp, grp_in, buf.grp_sorted, sb_grp_items, + sb_n_seg, buf.seg_offsets, buf.seg_ends, BEGIN_BIT, END_BIT, + stream); + CUDA_CHECK_LAST_ERROR(DeviceSegmentedRadixSort); + } + { + dim3 grid(sb_cols, n_groups); + batched_rank_sums_presorted_kernel<<>>( + ref_sub, buf.grp_sorted, grp_offsets, buf.sub_rank_sums, + buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, + compute_tie_corr, upper_skip_le); + CUDA_CHECK_LAST_ERROR(batched_rank_sums_presorted_kernel); + } + } + + // Scatter sub-batch results to global output + cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpy2DAsync(tie_corr + col, n_cols * sizeof(double), + buf.sub_tie_corr, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + + col += sb_cols; + batch_idx++; + } + + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in wilcoxon streaming: ") + + cudaGetErrorString(err)); + } + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh new file mode 100644 index 00000000..9f6d89e2 --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh @@ -0,0 +1,474 @@ +#pragma once + +/** + * CSR-direct OVO streaming pipeline. + * + * One C++ call does everything: extract rows from CSR → sort → rank. + * Per sub-batch of columns: + * 1. Extract ref rows → dense f32 → CUB sort + * 2. Extract grp rows → dense f32 → CUB sort (segmented by group) + * 3. Binary search rank sums + * Only ~(n_ref + n_all_grp) × sub_batch × 4B on GPU at a time. + */ +static void ovo_streaming_csr_impl( + const float* csr_data, const int* csr_indices, const int* csr_indptr, + const int* ref_row_ids, const int* grp_row_ids, const int* grp_offsets, + double* rank_sums, double* tie_corr, int n_ref, int n_all_grp, int n_cols, + int n_groups, bool compute_tie_corr, int sub_batch_cols) { + if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; + + // ---- Tier dispatch: read group offsets to determine max group size ---- + std::vector h_offsets(n_groups + 1); + cudaMemcpy(h_offsets.data(), grp_offsets, (n_groups + 1) * sizeof(int), + cudaMemcpyDeviceToHost); + auto t1 = make_tier1_config(h_offsets.data(), n_groups); + int max_grp_size = t1.max_grp_size; + bool use_tier1 = t1.any_above_t2 && t1.use_tier1; + bool needs_tier3 = t1.any_above_t2 && !use_tier1; + int padded_grp_size = t1.padded_grp_size; + int tier1_tpb = t1.tier1_tpb; + size_t tier1_smem = t1.tier1_smem; + std::vector h_sort_group_ids; + int n_sort_groups = n_groups; + if (needs_tier3) { + h_sort_group_ids = make_sort_group_ids(h_offsets.data(), n_groups, + TIER2_GROUP_THRESHOLD); + n_sort_groups = (int)h_sort_group_ids.size(); + } + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) + n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + + size_t sub_ref_items = (size_t)n_ref * sub_batch_cols; + size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols; + + // CUB temp for ref sort (always needed) + grp sort (Tier 3 only) + size_t cub_ref_bytes = 0; + { + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, cub_ref_bytes, fk, fk, (int)sub_ref_items, sub_batch_cols, + doff, doff + 1, BEGIN_BIT, END_BIT); + } + size_t cub_temp_bytes = cub_ref_bytes; + + if (needs_tier3) { + size_t cub_grp_bytes = 0; + int max_grp_seg = n_sort_groups * sub_batch_cols; + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, cub_grp_bytes, fk, fk, (int)sub_grp_items, max_grp_seg, + doff, doff + 1, BEGIN_BIT, END_BIT); + cub_temp_bytes = std::max(cub_ref_bytes, cub_grp_bytes); + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + // Allocate per-stream buffers via RMM pool + RmmPool pool; + int* d_sort_group_ids = nullptr; + if (needs_tier3) { + d_sort_group_ids = pool.alloc(h_sort_group_ids.size()); + cudaMemcpy(d_sort_group_ids, h_sort_group_ids.data(), + h_sort_group_ids.size() * sizeof(int), + cudaMemcpyHostToDevice); + } + struct StreamBuf { + float* ref_dense; + float* ref_sorted; + float* grp_dense; + float* grp_sorted; + int* ref_seg_offsets; + int* grp_seg_offsets; + int* grp_seg_ends; + uint8_t* cub_temp; + double* ref_tie_sums; + double* sub_rank_sums; + double* sub_tie_corr; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].ref_dense = pool.alloc(sub_ref_items); + bufs[s].ref_sorted = pool.alloc(sub_ref_items); + bufs[s].grp_dense = pool.alloc(sub_grp_items); + bufs[s].ref_seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].ref_tie_sums = (t1.any_tier2 && compute_tie_corr) + ? pool.alloc(sub_batch_cols) + : nullptr; + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = + pool.alloc((size_t)n_groups * sub_batch_cols); + if (needs_tier3) { + bufs[s].grp_sorted = pool.alloc(sub_grp_items); + int max_seg = n_sort_groups * sub_batch_cols; + bufs[s].grp_seg_offsets = pool.alloc(max_seg); + bufs[s].grp_seg_ends = pool.alloc(max_seg); + } else { + bufs[s].grp_sorted = nullptr; + bufs[s].grp_seg_offsets = nullptr; + bufs[s].grp_seg_ends = nullptr; + } + } + + int tpb_extract = round_up_to_warp(std::max(n_ref, n_all_grp)); + int tpb_rank = + round_up_to_warp(std::min(max_grp_size, MAX_THREADS_PER_BLOCK)); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int sb_ref_items_actual = n_ref * sb_cols; + int sb_grp_items_actual = n_all_grp * sb_cols; + int s = batch_idx % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + + // ---- Extract + sort ref (always CUB) ---- + cudaMemsetAsync(buf.ref_dense, 0, sb_ref_items_actual * sizeof(float), + stream); + { + int blk = (n_ref + tpb_extract - 1) / tpb_extract; + csr_extract_dense_kernel<<>>( + csr_data, csr_indices, csr_indptr, ref_row_ids, buf.ref_dense, + n_ref, col, col + sb_cols); + CUDA_CHECK_LAST_ERROR(csr_extract_dense_kernel); + } + upload_linear_offsets(buf.ref_seg_offsets, sb_cols, n_ref, stream); + { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.cub_temp, temp, buf.ref_dense, buf.ref_sorted, + sb_ref_items_actual, sb_cols, buf.ref_seg_offsets, + buf.ref_seg_offsets + 1, BEGIN_BIT, END_BIT, stream); + } + + // ---- Extract grp rows ---- + cudaMemsetAsync(buf.grp_dense, 0, sb_grp_items_actual * sizeof(float), + stream); + { + int blk = (n_all_grp + tpb_extract - 1) / tpb_extract; + csr_extract_dense_kernel<<>>( + csr_data, csr_indices, csr_indptr, grp_row_ids, buf.grp_dense, + n_all_grp, col, col + sb_cols); + CUDA_CHECK_LAST_ERROR(csr_extract_dense_kernel); + } + + // Tier 0 handles groups ≤ TIER0_GROUP_THRESHOLD; Tier 1/3 handle + // the rest. Since each group owns its own rank_sums / tie_corr + // row, the two kernels' writes interlace without conflict. + int skip_le = 0; + if (t1.use_tier0) { + launch_tier0(buf.ref_sorted, buf.grp_dense, grp_offsets, + buf.sub_rank_sums, buf.sub_tie_corr, n_ref, n_all_grp, + sb_cols, n_groups, compute_tie_corr, stream); + if (t1.any_above_t0) skip_le = TIER0_GROUP_THRESHOLD; + } + if (t1.any_tier2) { + if (compute_tie_corr) { + launch_ref_tie_sums(buf.ref_sorted, buf.ref_tie_sums, n_ref, + sb_cols, stream); + } + launch_tier2_medium( + buf.ref_sorted, buf.grp_dense, grp_offsets, buf.ref_tie_sums, + buf.sub_rank_sums, buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, + n_groups, compute_tie_corr, TIER0_GROUP_THRESHOLD, stream); + } + + int upper_skip_le = t1.any_above_t2 ? TIER2_GROUP_THRESHOLD : skip_le; + if (t1.any_above_t2 && use_tier1) { + // ---- Tier 1: fused smem sort + binary search ---- + dim3 grid(sb_cols, n_groups); + ovo_fused_sort_rank_kernel<<>>( + buf.ref_sorted, buf.grp_dense, grp_offsets, buf.sub_rank_sums, + buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, + compute_tie_corr, padded_grp_size, upper_skip_le); + CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); + } else if (needs_tier3) { + // ---- Tier 3: CUB segmented sort + binary search ---- + int sb_grp_seg = n_sort_groups * sb_cols; + { + int blk = (sb_grp_seg + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + build_tier3_seg_begin_end_offsets_kernel<<>>( + grp_offsets, d_sort_group_ids, buf.grp_seg_offsets, + buf.grp_seg_ends, n_all_grp, n_sort_groups, sb_cols); + CUDA_CHECK_LAST_ERROR(build_tier3_seg_begin_end_offsets_kernel); + } + { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.cub_temp, temp, buf.grp_dense, buf.grp_sorted, + sb_grp_items_actual, sb_grp_seg, buf.grp_seg_offsets, + buf.grp_seg_ends, BEGIN_BIT, END_BIT, stream); + } + { + dim3 grid(sb_cols, n_groups); + batched_rank_sums_presorted_kernel<<>>( + buf.ref_sorted, buf.grp_sorted, grp_offsets, + buf.sub_rank_sums, buf.sub_tie_corr, n_ref, n_all_grp, + sb_cols, n_groups, compute_tie_corr, upper_skip_le); + CUDA_CHECK_LAST_ERROR(batched_rank_sums_presorted_kernel); + } + } + + // ---- Scatter to global output ---- + cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpy2DAsync(tie_corr + col, n_cols * sizeof(double), + buf.sub_tie_corr, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + + col += sb_cols; + batch_idx++; + } + + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in wilcoxon streaming: ") + + cudaGetErrorString(err)); + } + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} + +/** + * CSC-direct OVO streaming pipeline. + * + * Like the CSR variant but extracts rows via a row-lookup map, avoiding + * CSC→CSR conversion. row_map_ref[row] = output index in ref block (-1 if + * not a ref row); likewise for row_map_grp. + */ +static void ovo_streaming_csc_impl( + const float* csc_data, const int* csc_indices, const int* csc_indptr, + const int* ref_row_map, const int* grp_row_map, const int* grp_offsets, + double* rank_sums, double* tie_corr, int n_ref, int n_all_grp, int n_cols, + int n_groups, bool compute_tie_corr, int sub_batch_cols) { + if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; + + // ---- Tier dispatch ---- + std::vector h_offsets(n_groups + 1); + cudaMemcpy(h_offsets.data(), grp_offsets, (n_groups + 1) * sizeof(int), + cudaMemcpyDeviceToHost); + auto t1 = make_tier1_config(h_offsets.data(), n_groups); + int max_grp_size = t1.max_grp_size; + bool use_tier1 = t1.any_above_t2 && t1.use_tier1; + bool needs_tier3 = t1.any_above_t2 && !use_tier1; + int padded_grp_size = t1.padded_grp_size; + int tier1_tpb = t1.tier1_tpb; + size_t tier1_smem = t1.tier1_smem; + std::vector h_sort_group_ids; + int n_sort_groups = n_groups; + if (needs_tier3) { + h_sort_group_ids = make_sort_group_ids(h_offsets.data(), n_groups, + TIER2_GROUP_THRESHOLD); + n_sort_groups = (int)h_sort_group_ids.size(); + } + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) + n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + + size_t sub_ref_items = (size_t)n_ref * sub_batch_cols; + size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols; + + // CUB temp + size_t cub_ref_bytes = 0; + { + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, cub_ref_bytes, fk, fk, (int)sub_ref_items, sub_batch_cols, + doff, doff + 1, BEGIN_BIT, END_BIT); + } + size_t cub_temp_bytes = cub_ref_bytes; + if (needs_tier3) { + size_t cub_grp_bytes = 0; + int max_grp_seg = n_sort_groups * sub_batch_cols; + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, cub_grp_bytes, fk, fk, (int)sub_grp_items, max_grp_seg, + doff, doff + 1, BEGIN_BIT, END_BIT); + cub_temp_bytes = std::max(cub_ref_bytes, cub_grp_bytes); + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + RmmPool pool; + int* d_sort_group_ids = nullptr; + if (needs_tier3) { + d_sort_group_ids = pool.alloc(h_sort_group_ids.size()); + cudaMemcpy(d_sort_group_ids, h_sort_group_ids.data(), + h_sort_group_ids.size() * sizeof(int), + cudaMemcpyHostToDevice); + } + struct StreamBuf { + float* ref_dense; + float* ref_sorted; + float* grp_dense; + float* grp_sorted; + int* ref_seg_offsets; + int* grp_seg_offsets; + int* grp_seg_ends; + uint8_t* cub_temp; + double* ref_tie_sums; + double* sub_rank_sums; + double* sub_tie_corr; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].ref_dense = pool.alloc(sub_ref_items); + bufs[s].ref_sorted = pool.alloc(sub_ref_items); + bufs[s].grp_dense = pool.alloc(sub_grp_items); + bufs[s].ref_seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].ref_tie_sums = (t1.any_tier2 && compute_tie_corr) + ? pool.alloc(sub_batch_cols) + : nullptr; + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = + pool.alloc((size_t)n_groups * sub_batch_cols); + if (needs_tier3) { + bufs[s].grp_sorted = pool.alloc(sub_grp_items); + int max_grp_seg = n_sort_groups * sub_batch_cols; + bufs[s].grp_seg_offsets = pool.alloc(max_grp_seg); + bufs[s].grp_seg_ends = pool.alloc(max_grp_seg); + } else { + bufs[s].grp_sorted = nullptr; + bufs[s].grp_seg_offsets = nullptr; + bufs[s].grp_seg_ends = nullptr; + } + } + + int tpb_rank = + round_up_to_warp(std::min(max_grp_size, MAX_THREADS_PER_BLOCK)); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int sb_ref_items_actual = n_ref * sb_cols; + int sb_grp_items_actual = n_all_grp * sb_cols; + int s = batch_idx % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + + // ---- Extract ref from CSC via row_map, then sort ---- + cudaMemsetAsync(buf.ref_dense, 0, sb_ref_items_actual * sizeof(float), + stream); + csc_extract_mapped_kernel<<>>( + csc_data, csc_indices, csc_indptr, ref_row_map, buf.ref_dense, + n_ref, col); + CUDA_CHECK_LAST_ERROR(csc_extract_mapped_kernel); + upload_linear_offsets(buf.ref_seg_offsets, sb_cols, n_ref, stream); + { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.cub_temp, temp, buf.ref_dense, buf.ref_sorted, + sb_ref_items_actual, sb_cols, buf.ref_seg_offsets, + buf.ref_seg_offsets + 1, BEGIN_BIT, END_BIT, stream); + } + + // ---- Extract grp from CSC via row_map ---- + cudaMemsetAsync(buf.grp_dense, 0, sb_grp_items_actual * sizeof(float), + stream); + csc_extract_mapped_kernel<<>>( + csc_data, csc_indices, csc_indptr, grp_row_map, buf.grp_dense, + n_all_grp, col); + CUDA_CHECK_LAST_ERROR(csc_extract_mapped_kernel); + + int skip_le = 0; + if (t1.use_tier0) { + launch_tier0(buf.ref_sorted, buf.grp_dense, grp_offsets, + buf.sub_rank_sums, buf.sub_tie_corr, n_ref, n_all_grp, + sb_cols, n_groups, compute_tie_corr, stream); + if (t1.any_above_t0) skip_le = TIER0_GROUP_THRESHOLD; + } + if (t1.any_tier2) { + if (compute_tie_corr) { + launch_ref_tie_sums(buf.ref_sorted, buf.ref_tie_sums, n_ref, + sb_cols, stream); + } + launch_tier2_medium( + buf.ref_sorted, buf.grp_dense, grp_offsets, buf.ref_tie_sums, + buf.sub_rank_sums, buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, + n_groups, compute_tie_corr, TIER0_GROUP_THRESHOLD, stream); + } + + int upper_skip_le = t1.any_above_t2 ? TIER2_GROUP_THRESHOLD : skip_le; + if (t1.any_above_t2 && use_tier1) { + dim3 grid(sb_cols, n_groups); + ovo_fused_sort_rank_kernel<<>>( + buf.ref_sorted, buf.grp_dense, grp_offsets, buf.sub_rank_sums, + buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, + compute_tie_corr, padded_grp_size, upper_skip_le); + CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); + } else if (needs_tier3) { + int sb_grp_seg = n_sort_groups * sb_cols; + { + int blk = (sb_grp_seg + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + build_tier3_seg_begin_end_offsets_kernel<<>>( + grp_offsets, d_sort_group_ids, buf.grp_seg_offsets, + buf.grp_seg_ends, n_all_grp, n_sort_groups, sb_cols); + CUDA_CHECK_LAST_ERROR(build_tier3_seg_begin_end_offsets_kernel); + } + { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.cub_temp, temp, buf.grp_dense, buf.grp_sorted, + sb_grp_items_actual, sb_grp_seg, buf.grp_seg_offsets, + buf.grp_seg_ends, BEGIN_BIT, END_BIT, stream); + } + { + dim3 grid(sb_cols, n_groups); + batched_rank_sums_presorted_kernel<<>>( + buf.ref_sorted, buf.grp_sorted, grp_offsets, + buf.sub_rank_sums, buf.sub_tie_corr, n_ref, n_all_grp, + sb_cols, n_groups, compute_tie_corr, upper_skip_le); + CUDA_CHECK_LAST_ERROR(batched_rank_sums_presorted_kernel); + } + } + + // ---- Scatter to global output ---- + cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpy2DAsync(tie_corr + col, n_cols * sizeof(double), + buf.sub_tie_corr, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + + col += sb_cols; + batch_idx++; + } + + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in wilcoxon streaming: ") + + cudaGetErrorString(err)); + } + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_dense.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_dense.cuh new file mode 100644 index 00000000..458d6667 --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_dense.cuh @@ -0,0 +1,313 @@ +#pragma once + +/** + * Gather specific rows from a dense F-order block into a smaller dense block. + * Grid: (n_cols,), Block: 256. + * row_ids[i] = original row index → output row i. + */ +__global__ void dense_gather_rows_kernel(const float* __restrict__ in, + const int* __restrict__ row_ids, + float* __restrict__ out, int n_rows_in, + int n_target, int n_cols) { + int col = blockIdx.x; + if (col >= n_cols) return; + const float* in_col = in + (long long)col * n_rows_in; + float* out_col = out + (long long)col * n_target; + for (int i = threadIdx.x; i < n_target; i += blockDim.x) { + out_col[i] = in_col[row_ids[i]]; + } +} + +/** + * Host-streaming dense OVO pipeline. + * + * Dense F-order float32 lives on host. Sub-batches of columns are H2D + * transferred, then ref/grp rows are gathered, sorted, and ranked. + * Results D2H per sub-batch. + */ +template +static void ovo_streaming_dense_host_impl( + const InT* h_block, const int* h_ref_row_ids, const int* h_grp_row_ids, + const int* h_grp_offsets, const int* h_stats_codes, double* d_rank_sums, + double* d_tie_corr, double* d_group_sums, double* d_group_sq_sums, + double* d_group_nnz, int n_ref, int n_all_grp, int n_rows, int n_cols, + int n_groups, int n_groups_stats, bool compute_tie_corr, + bool compute_sq_sums, bool compute_nnz, int sub_batch_cols) { + if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; + + // ---- Tier dispatch from host offsets ---- + auto t1 = make_tier1_config(h_grp_offsets, n_groups); + int max_grp_size = t1.max_grp_size; + bool use_tier1 = t1.any_above_t2 && t1.use_tier1; + bool needs_tier3 = t1.any_above_t2 && !use_tier1; + int padded_grp_size = t1.padded_grp_size; + int tier1_tpb = t1.tier1_tpb; + size_t tier1_smem = t1.tier1_smem; + std::vector h_sort_group_ids; + int n_sort_groups = n_groups; + if (needs_tier3) { + h_sort_group_ids = + make_sort_group_ids(h_grp_offsets, n_groups, TIER2_GROUP_THRESHOLD); + n_sort_groups = (int)h_sort_group_ids.size(); + } + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) + n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + + size_t sub_dense = (size_t)n_rows * sub_batch_cols; + size_t sub_ref_items = (size_t)n_ref * sub_batch_cols; + size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols; + + // CUB temp + size_t cub_ref_bytes = 0; + { + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, cub_ref_bytes, fk, fk, (int)sub_ref_items, sub_batch_cols, + doff, doff + 1, BEGIN_BIT, END_BIT); + } + size_t cub_temp_bytes = cub_ref_bytes; + if (needs_tier3) { + size_t cub_grp_bytes = 0; + int max_grp_seg = n_sort_groups * sub_batch_cols; + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, cub_grp_bytes, fk, fk, (int)sub_grp_items, max_grp_seg, + doff, doff + 1, BEGIN_BIT, END_BIT); + cub_temp_bytes = std::max(cub_ref_bytes, cub_grp_bytes); + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + RmmPool pool; + + // GPU copies of row_ids + group offsets + stats codes (uploaded once) + int* d_ref_row_ids = pool.alloc(n_ref); + int* d_grp_row_ids = pool.alloc(n_all_grp); + int* d_grp_offsets = pool.alloc(n_groups + 1); + int* d_stats_codes = pool.alloc(n_rows); + int* d_sort_group_ids = nullptr; + cudaMemcpy(d_ref_row_ids, h_ref_row_ids, n_ref * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_grp_row_ids, h_grp_row_ids, n_all_grp * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_grp_offsets, h_grp_offsets, (n_groups + 1) * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_stats_codes, h_stats_codes, n_rows * sizeof(int), + cudaMemcpyHostToDevice); + if (needs_tier3) { + d_sort_group_ids = pool.alloc(h_sort_group_ids.size()); + cudaMemcpy(d_sort_group_ids, h_sort_group_ids.data(), + h_sort_group_ids.size() * sizeof(int), + cudaMemcpyHostToDevice); + } + + struct StreamBuf { + InT* d_block_orig; + float* d_block_f32; + float* ref_dense; + float* ref_sorted; + float* grp_dense; + float* grp_sorted; + int* ref_seg_offsets; + int* grp_seg_offsets; + int* grp_seg_ends; + uint8_t* cub_temp; + double* ref_tie_sums; + double* d_rank_sums; + double* d_tie_corr; + double* d_group_sums; + double* d_group_sq_sums; + double* d_group_nnz; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].d_block_orig = pool.alloc(sub_dense); + bufs[s].d_block_f32 = pool.alloc(sub_dense); + bufs[s].ref_dense = pool.alloc(sub_ref_items); + bufs[s].ref_sorted = pool.alloc(sub_ref_items); + bufs[s].grp_dense = pool.alloc(sub_grp_items); + bufs[s].ref_seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].ref_tie_sums = (t1.any_tier2 && compute_tie_corr) + ? pool.alloc(sub_batch_cols) + : nullptr; + bufs[s].d_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].d_tie_corr = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].d_group_sums = + pool.alloc((size_t)n_groups_stats * sub_batch_cols); + bufs[s].d_group_sq_sums = pool.alloc( + compute_sq_sums ? (size_t)n_groups_stats * sub_batch_cols : 1); + bufs[s].d_group_nnz = pool.alloc( + compute_nnz ? (size_t)n_groups_stats * sub_batch_cols : 1); + if (needs_tier3) { + bufs[s].grp_sorted = pool.alloc(sub_grp_items); + int max_grp_seg = n_sort_groups * sub_batch_cols; + bufs[s].grp_seg_offsets = pool.alloc(max_grp_seg); + bufs[s].grp_seg_ends = pool.alloc(max_grp_seg); + } else { + bufs[s].grp_sorted = nullptr; + bufs[s].grp_seg_offsets = nullptr; + bufs[s].grp_seg_ends = nullptr; + } + } + + int tpb_rank = + round_up_to_warp(std::min(max_grp_size, MAX_THREADS_PER_BLOCK)); + bool cast_use_gmem = false; + size_t smem_cast = cast_accumulate_smem_config( + n_groups_stats, compute_sq_sums, compute_nnz, cast_use_gmem); + + // Pin only the host input; outputs live on the device. + HostRegisterGuard _pin_block(const_cast(h_block), + (size_t)n_rows * n_cols * sizeof(InT)); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int sb_dense = n_rows * sb_cols; + int sb_ref_actual = n_ref * sb_cols; + int sb_grp_actual = n_all_grp * sb_cols; + int s = batch_idx % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + + // ---- H2D: dense column sub-batch (F-order, native dtype) ---- + cudaMemcpyAsync(buf.d_block_orig, h_block + (long long)col * n_rows, + sb_dense * sizeof(InT), cudaMemcpyHostToDevice, stream); + + // ---- Cast to float32 for sort + accumulate stats in float64 ---- + launch_ovr_cast_and_accumulate_dense( + buf.d_block_orig, buf.d_block_f32, d_stats_codes, buf.d_group_sums, + buf.d_group_sq_sums, buf.d_group_nnz, n_rows, sb_cols, + n_groups_stats, compute_sq_sums, compute_nnz, UTIL_BLOCK_SIZE, + smem_cast, cast_use_gmem, stream); + + // ---- Gather ref rows, sort ---- + dense_gather_rows_kernel<<>>( + buf.d_block_f32, d_ref_row_ids, buf.ref_dense, n_rows, n_ref, + sb_cols); + CUDA_CHECK_LAST_ERROR(dense_gather_rows_kernel); + upload_linear_offsets(buf.ref_seg_offsets, sb_cols, n_ref, stream); + { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.cub_temp, temp, buf.ref_dense, buf.ref_sorted, + sb_ref_actual, sb_cols, buf.ref_seg_offsets, + buf.ref_seg_offsets + 1, BEGIN_BIT, END_BIT, stream); + } + + // ---- Gather grp rows ---- + dense_gather_rows_kernel<<>>( + buf.d_block_f32, d_grp_row_ids, buf.grp_dense, n_rows, n_all_grp, + sb_cols); + CUDA_CHECK_LAST_ERROR(dense_gather_rows_kernel); + + // ---- Tier dispatch: sort grp + rank ---- + int skip_le = 0; + if (t1.use_tier0) { + launch_tier0(buf.ref_sorted, buf.grp_dense, d_grp_offsets, + buf.d_rank_sums, buf.d_tie_corr, n_ref, n_all_grp, + sb_cols, n_groups, compute_tie_corr, stream); + if (t1.any_above_t0) skip_le = TIER0_GROUP_THRESHOLD; + } + if (t1.any_tier2) { + if (compute_tie_corr) { + launch_ref_tie_sums(buf.ref_sorted, buf.ref_tie_sums, n_ref, + sb_cols, stream); + } + launch_tier2_medium( + buf.ref_sorted, buf.grp_dense, d_grp_offsets, buf.ref_tie_sums, + buf.d_rank_sums, buf.d_tie_corr, n_ref, n_all_grp, sb_cols, + n_groups, compute_tie_corr, TIER0_GROUP_THRESHOLD, stream); + } + + int upper_skip_le = t1.any_above_t2 ? TIER2_GROUP_THRESHOLD : skip_le; + if (t1.any_above_t2 && use_tier1) { + dim3 grid(sb_cols, n_groups); + ovo_fused_sort_rank_kernel<<>>( + buf.ref_sorted, buf.grp_dense, d_grp_offsets, buf.d_rank_sums, + buf.d_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, + compute_tie_corr, padded_grp_size, upper_skip_le); + CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); + } else if (needs_tier3) { + int sb_grp_seg = n_sort_groups * sb_cols; + { + int blk = (sb_grp_seg + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + build_tier3_seg_begin_end_offsets_kernel<<>>( + d_grp_offsets, d_sort_group_ids, buf.grp_seg_offsets, + buf.grp_seg_ends, n_all_grp, n_sort_groups, sb_cols); + CUDA_CHECK_LAST_ERROR(build_tier3_seg_begin_end_offsets_kernel); + } + { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.cub_temp, temp, buf.grp_dense, buf.grp_sorted, + sb_grp_actual, sb_grp_seg, buf.grp_seg_offsets, + buf.grp_seg_ends, BEGIN_BIT, END_BIT, stream); + } + { + dim3 grid(sb_cols, n_groups); + batched_rank_sums_presorted_kernel<<>>( + buf.ref_sorted, buf.grp_sorted, d_grp_offsets, + buf.d_rank_sums, buf.d_tie_corr, n_ref, n_all_grp, sb_cols, + n_groups, compute_tie_corr, upper_skip_le); + CUDA_CHECK_LAST_ERROR(batched_rank_sums_presorted_kernel); + } + } + + // ---- D2D: scatter sub-batch results into caller's GPU buffers ---- + cudaMemcpy2DAsync(d_rank_sums + col, n_cols * sizeof(double), + buf.d_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpy2DAsync(d_tie_corr + col, n_cols * sizeof(double), + buf.d_tie_corr, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + cudaMemcpy2DAsync(d_group_sums + col, n_cols * sizeof(double), + buf.d_group_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups_stats, + cudaMemcpyDeviceToDevice, stream); + if (compute_sq_sums) { + cudaMemcpy2DAsync(d_group_sq_sums + col, n_cols * sizeof(double), + buf.d_group_sq_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups_stats, + cudaMemcpyDeviceToDevice, stream); + } + if (compute_nnz) { + cudaMemcpy2DAsync(d_group_nnz + col, n_cols * sizeof(double), + buf.d_group_nnz, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups_stats, + cudaMemcpyDeviceToDevice, stream); + } + + col += sb_cols; + batch_idx++; + } + + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in wilcoxon streaming: ") + + cudaGetErrorString(err)); + } + + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} + +// ============================================================================ +// Nanobind module +// ============================================================================ diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh new file mode 100644 index 00000000..32721a41 --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh @@ -0,0 +1,848 @@ +#pragma once + +/** + * Host-streaming CSC OVO pipeline. + * + * CSC arrays live on host. Only the sparse data for each sub-batch of + * columns is transferred to GPU. Row maps + group offsets are uploaded once. + * Results are written back to host per sub-batch. + */ +template +static void ovo_streaming_csc_host_impl( + const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, + const int* h_ref_row_map, const int* h_grp_row_map, + const int* h_grp_offsets, const int* h_stats_codes, double* d_rank_sums, + double* d_tie_corr, double* d_group_sums, double* d_group_sq_sums, + double* d_group_nnz, int n_ref, int n_all_grp, int n_rows, int n_cols, + int n_groups, int n_groups_stats, bool compute_tie_corr, + bool compute_sq_sums, bool compute_nnz, int sub_batch_cols) { + if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; + + // ---- Tier dispatch from host offsets ---- + auto t1 = make_tier1_config(h_grp_offsets, n_groups); + int max_grp_size = t1.max_grp_size; + bool use_tier1 = t1.any_above_t2 && t1.use_tier1; + bool needs_tier3 = t1.any_above_t2 && !use_tier1; + int padded_grp_size = t1.padded_grp_size; + int tier1_tpb = t1.tier1_tpb; + size_t tier1_smem = t1.tier1_smem; + std::vector h_sort_group_ids; + int n_sort_groups = n_groups; + if (needs_tier3) { + h_sort_group_ids = + make_sort_group_ids(h_grp_offsets, n_groups, TIER2_GROUP_THRESHOLD); + n_sort_groups = (int)h_sort_group_ids.size(); + } + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) + n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + + size_t sub_ref_items = (size_t)n_ref * sub_batch_cols; + size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols; + + // CUB temp + size_t cub_ref_bytes = 0; + { + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, cub_ref_bytes, fk, fk, (int)sub_ref_items, sub_batch_cols, + doff, doff + 1, BEGIN_BIT, END_BIT); + } + size_t cub_temp_bytes = cub_ref_bytes; + if (needs_tier3) { + size_t cub_grp_bytes = 0; + int max_grp_seg = n_sort_groups * sub_batch_cols; + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, cub_grp_bytes, fk, fk, (int)sub_grp_items, max_grp_seg, + doff, doff + 1, BEGIN_BIT, END_BIT); + cub_temp_bytes = std::max(cub_ref_bytes, cub_grp_bytes); + } + + // Max nnz across any sub-batch for sparse transfer buffer sizing + size_t max_nnz = 0; + for (int c = 0; c < n_cols; c += sub_batch_cols) { + int sb = std::min(sub_batch_cols, n_cols - c); + size_t nnz = (size_t)(h_indptr[c + sb] - h_indptr[c]); + if (nnz > max_nnz) max_nnz = nnz; + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + RmmPool pool; + + int n_batches = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + std::vector h_all_offsets((size_t)n_batches * (sub_batch_cols + 1), 0); + for (int b = 0; b < n_batches; b++) { + int col_start = b * sub_batch_cols; + int sb = std::min(sub_batch_cols, n_cols - col_start); + IndptrT ptr_start = h_indptr[col_start]; + int* off = &h_all_offsets[(size_t)b * (sub_batch_cols + 1)]; + for (int i = 0; i <= sb; i++) + off[i] = (int)(h_indptr[col_start + i] - ptr_start); + } + int* d_all_offsets = + pool.alloc((size_t)n_batches * (sub_batch_cols + 1)); + cudaMemcpy(d_all_offsets, h_all_offsets.data(), + h_all_offsets.size() * sizeof(int), cudaMemcpyHostToDevice); + + // GPU copies of row maps + group offsets + stats codes (uploaded once) + int* d_ref_row_map = pool.alloc(n_rows); + int* d_grp_row_map = pool.alloc(n_rows); + int* d_grp_offsets = pool.alloc(n_groups + 1); + int* d_stats_codes = pool.alloc(n_rows); + int* d_sort_group_ids = nullptr; + cudaMemcpy(d_ref_row_map, h_ref_row_map, n_rows * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_grp_row_map, h_grp_row_map, n_rows * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_grp_offsets, h_grp_offsets, (n_groups + 1) * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_stats_codes, h_stats_codes, n_rows * sizeof(int), + cudaMemcpyHostToDevice); + if (needs_tier3) { + d_sort_group_ids = pool.alloc(h_sort_group_ids.size()); + cudaMemcpy(d_sort_group_ids, h_sort_group_ids.data(), + h_sort_group_ids.size() * sizeof(int), + cudaMemcpyHostToDevice); + } + + struct StreamBuf { + InT* d_sparse_data_orig; + float* d_sparse_data_f32; + IndexT* d_sparse_indices; + int* d_indptr; + float* ref_dense; + float* ref_sorted; + float* grp_dense; + float* grp_sorted; + int* ref_seg_offsets; + int* grp_seg_offsets; + int* grp_seg_ends; + uint8_t* cub_temp; + double* ref_tie_sums; + double* d_rank_sums; + double* d_tie_corr; + double* d_group_sums; + double* d_group_sq_sums; + double* d_group_nnz; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].d_sparse_data_orig = pool.alloc(max_nnz); + bufs[s].d_sparse_data_f32 = pool.alloc(max_nnz); + bufs[s].d_sparse_indices = pool.alloc(max_nnz); + bufs[s].d_indptr = pool.alloc(sub_batch_cols + 1); + bufs[s].ref_dense = pool.alloc(sub_ref_items); + bufs[s].ref_sorted = pool.alloc(sub_ref_items); + bufs[s].grp_dense = pool.alloc(sub_grp_items); + bufs[s].ref_seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].ref_tie_sums = (t1.any_tier2 && compute_tie_corr) + ? pool.alloc(sub_batch_cols) + : nullptr; + bufs[s].d_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].d_tie_corr = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].d_group_sums = + pool.alloc((size_t)n_groups_stats * sub_batch_cols); + bufs[s].d_group_sq_sums = pool.alloc( + compute_sq_sums ? (size_t)n_groups_stats * sub_batch_cols : 1); + bufs[s].d_group_nnz = pool.alloc( + compute_nnz ? (size_t)n_groups_stats * sub_batch_cols : 1); + if (needs_tier3) { + bufs[s].grp_sorted = pool.alloc(sub_grp_items); + int max_grp_seg = n_sort_groups * sub_batch_cols; + bufs[s].grp_seg_offsets = pool.alloc(max_grp_seg); + bufs[s].grp_seg_ends = pool.alloc(max_grp_seg); + } else { + bufs[s].grp_sorted = nullptr; + bufs[s].grp_seg_offsets = nullptr; + bufs[s].grp_seg_ends = nullptr; + } + } + + int tpb_rank = + round_up_to_warp(std::min(max_grp_size, MAX_THREADS_PER_BLOCK)); + bool cast_use_gmem = false; + size_t smem_cast = cast_accumulate_smem_config( + n_groups_stats, compute_sq_sums, compute_nnz, cast_use_gmem); + + // Pin only the sparse input arrays; outputs live on the device. + size_t total_nnz = (size_t)h_indptr[n_cols]; + HostRegisterGuard _pin_data(const_cast(h_data), + total_nnz * sizeof(InT)); + HostRegisterGuard _pin_indices(const_cast(h_indices), + total_nnz * sizeof(IndexT)); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int sb_ref_actual = n_ref * sb_cols; + int sb_grp_actual = n_all_grp * sb_cols; + int s = batch_idx % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + + // ---- H2D: sparse data for this column range (native dtype) ---- + IndptrT ptr_start = h_indptr[col]; + IndptrT ptr_end = h_indptr[col + sb_cols]; + size_t nnz = (size_t)(ptr_end - ptr_start); + cudaMemcpyAsync(buf.d_sparse_data_orig, h_data + ptr_start, + nnz * sizeof(InT), cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(buf.d_sparse_indices, h_indices + ptr_start, + nnz * sizeof(IndexT), cudaMemcpyHostToDevice, stream); + int* src = d_all_offsets + (size_t)batch_idx * (sub_batch_cols + 1); + cudaMemcpyAsync(buf.d_indptr, src, (sb_cols + 1) * sizeof(int), + cudaMemcpyDeviceToDevice, stream); + + // ---- Cast to float32 for sort + accumulate stats in float64 ---- + launch_ovr_cast_and_accumulate_sparse( + buf.d_sparse_data_orig, buf.d_sparse_data_f32, buf.d_sparse_indices, + buf.d_indptr, d_stats_codes, buf.d_group_sums, buf.d_group_sq_sums, + buf.d_group_nnz, sb_cols, n_groups_stats, compute_sq_sums, + compute_nnz, UTIL_BLOCK_SIZE, smem_cast, cast_use_gmem, stream); + + // ---- Extract ref from CSC via row_map, sort ---- + cudaMemsetAsync(buf.ref_dense, 0, sb_ref_actual * sizeof(float), + stream); + csc_extract_mapped_kernel<<>>( + buf.d_sparse_data_f32, buf.d_sparse_indices, buf.d_indptr, + d_ref_row_map, buf.ref_dense, n_ref, 0); + CUDA_CHECK_LAST_ERROR(csc_extract_mapped_kernel); + upload_linear_offsets(buf.ref_seg_offsets, sb_cols, n_ref, stream); + { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.cub_temp, temp, buf.ref_dense, buf.ref_sorted, + sb_ref_actual, sb_cols, buf.ref_seg_offsets, + buf.ref_seg_offsets + 1, BEGIN_BIT, END_BIT, stream); + } + + // ---- Extract grp from CSC via row_map ---- + cudaMemsetAsync(buf.grp_dense, 0, sb_grp_actual * sizeof(float), + stream); + csc_extract_mapped_kernel<<>>( + buf.d_sparse_data_f32, buf.d_sparse_indices, buf.d_indptr, + d_grp_row_map, buf.grp_dense, n_all_grp, 0); + CUDA_CHECK_LAST_ERROR(csc_extract_mapped_kernel); + + // ---- Tier dispatch: sort grp + rank ---- + int skip_le = 0; + if (t1.use_tier0) { + launch_tier0(buf.ref_sorted, buf.grp_dense, d_grp_offsets, + buf.d_rank_sums, buf.d_tie_corr, n_ref, n_all_grp, + sb_cols, n_groups, compute_tie_corr, stream); + if (t1.any_above_t0) skip_le = TIER0_GROUP_THRESHOLD; + } + if (t1.any_tier2) { + if (compute_tie_corr) { + launch_ref_tie_sums(buf.ref_sorted, buf.ref_tie_sums, n_ref, + sb_cols, stream); + } + launch_tier2_medium( + buf.ref_sorted, buf.grp_dense, d_grp_offsets, buf.ref_tie_sums, + buf.d_rank_sums, buf.d_tie_corr, n_ref, n_all_grp, sb_cols, + n_groups, compute_tie_corr, TIER0_GROUP_THRESHOLD, stream); + } + + int upper_skip_le = t1.any_above_t2 ? TIER2_GROUP_THRESHOLD : skip_le; + if (t1.any_above_t2 && use_tier1) { + dim3 grid(sb_cols, n_groups); + ovo_fused_sort_rank_kernel<<>>( + buf.ref_sorted, buf.grp_dense, d_grp_offsets, buf.d_rank_sums, + buf.d_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, + compute_tie_corr, padded_grp_size, upper_skip_le); + CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); + } else if (needs_tier3) { + int sb_grp_seg = n_sort_groups * sb_cols; + { + int blk = (sb_grp_seg + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + build_tier3_seg_begin_end_offsets_kernel<<>>( + d_grp_offsets, d_sort_group_ids, buf.grp_seg_offsets, + buf.grp_seg_ends, n_all_grp, n_sort_groups, sb_cols); + CUDA_CHECK_LAST_ERROR(build_tier3_seg_begin_end_offsets_kernel); + } + { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.cub_temp, temp, buf.grp_dense, buf.grp_sorted, + sb_grp_actual, sb_grp_seg, buf.grp_seg_offsets, + buf.grp_seg_ends, BEGIN_BIT, END_BIT, stream); + } + { + dim3 grid(sb_cols, n_groups); + batched_rank_sums_presorted_kernel<<>>( + buf.ref_sorted, buf.grp_sorted, d_grp_offsets, + buf.d_rank_sums, buf.d_tie_corr, n_ref, n_all_grp, sb_cols, + n_groups, compute_tie_corr, upper_skip_le); + CUDA_CHECK_LAST_ERROR(batched_rank_sums_presorted_kernel); + } + } + + // ---- D2D: scatter sub-batch results into caller's GPU buffers ---- + cudaMemcpy2DAsync(d_rank_sums + col, n_cols * sizeof(double), + buf.d_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpy2DAsync(d_tie_corr + col, n_cols * sizeof(double), + buf.d_tie_corr, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + cudaMemcpy2DAsync(d_group_sums + col, n_cols * sizeof(double), + buf.d_group_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups_stats, + cudaMemcpyDeviceToDevice, stream); + if (compute_sq_sums) { + cudaMemcpy2DAsync(d_group_sq_sums + col, n_cols * sizeof(double), + buf.d_group_sq_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups_stats, + cudaMemcpyDeviceToDevice, stream); + } + if (compute_nnz) { + cudaMemcpy2DAsync(d_group_nnz + col, n_cols * sizeof(double), + buf.d_group_nnz, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups_stats, + cudaMemcpyDeviceToDevice, stream); + } + + col += sb_cols; + batch_idx++; + } + + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in wilcoxon streaming: ") + + cudaGetErrorString(err)); + } + + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} + +/** + * Host CSR OVO pipeline — zero-copy mapped full-CSR with GPU-side row gather. + * + * Setup: pin the full host CSR with cudaHostRegisterMapped, upload the full + * indptr (small) + row_ids + pre-computed compacted indptrs. Each pack + * gathers only its rows over PCIe via a UVA kernel — the full matrix is never + * transferred to GPU. + * + * Phase 1 (Ref): fused gather + cast + stats over ref rows; segmented sort + * to d_ref_sorted (cached for the whole run). + * Phase 2 (per pack, round-robin across N_STREAMS): + * 1. rebase per-pack output indptr from the pre-uploaded global compacted + * indptr. + * 2. rebase per-pack group offsets + build per-row stats codes. + * 3. csr_gather_cast_accumulate_mapped_kernel — one PCIe pass, writes + * compacted f32 data + indices and accumulates per-group stats. + * 4. Per sub-batch: extract dense → sort → rank vs ref_sorted → scatter. + * + * Memory: d_ref_sorted (n_ref × n_cols × 4B) + N_STREAMS pack buffers sized + * for max_pack_rows × sb_cols (dense) and max_pack_nnz (compacted CSR). + * Full CSR stays on host (pinned-mapped). + */ +template +static void ovo_streaming_csr_host_impl( + const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, + int n_full_rows, const int* h_ref_row_ids, int n_ref, + const int* h_grp_row_ids, const int* h_grp_offsets, int n_all_grp, + int n_test, double* d_rank_sums, double* d_tie_corr, double* d_group_sums, + double* d_group_sq_sums, double* d_group_nnz, int n_cols, + int n_groups_stats, bool compute_tie_corr, bool compute_sq_sums, + bool compute_nnz, int sub_batch_cols) { + if (n_cols == 0 || n_ref == 0 || n_test == 0 || n_all_grp == 0) return; + + // ---- Pre-compute compacted indptrs on host (O(n_ref + n_all_grp)) ---- + // Use IndptrT for the global compacted indptr because the grp side can + // exceed 2^31 nnz on very large / dense matrices. Ref always fits in + // int32 since n_ref × n_cols ≪ 2B; keeping int32 there matches the + // downstream CUB segmented-sort temp sizing. + std::vector h_ref_indptr_compact(n_ref + 1); + h_ref_indptr_compact[0] = 0; + for (int i = 0; i < n_ref; i++) { + int r = h_ref_row_ids[i]; + int nnz_i = (int)(h_indptr[r + 1] - h_indptr[r]); + h_ref_indptr_compact[i + 1] = h_ref_indptr_compact[i] + nnz_i; + } + int ref_nnz = h_ref_indptr_compact[n_ref]; + + // grp: compacted indptr over concatenated test-group rows (IndptrT). + std::vector h_grp_indptr_compact(n_all_grp + 1); + h_grp_indptr_compact[0] = 0; + for (int i = 0; i < n_all_grp; i++) { + int r = h_grp_row_ids[i]; + IndptrT nnz_i = h_indptr[r + 1] - h_indptr[r]; + h_grp_indptr_compact[i + 1] = h_grp_indptr_compact[i] + nnz_i; + } + + // ---- Build packs (same rule as grp_impl, but uses compacted indptr) ---- + struct Pack { + int first; + int end; + int n_rows; + size_t nnz; + int sb_cols; + }; + std::vector packs; + int max_pack_rows = 0; + size_t max_pack_nnz = 0; + int max_pack_K = 0; + int max_pack_items = 0; + int max_pack_sb_cols = sub_batch_cols; + { + int target_packs = N_STREAMS; + int target_rows = (n_all_grp + target_packs - 1) / target_packs; + if (target_rows < 1) target_rows = 1; + size_t budget_cap_rows = + GROUP_DENSE_BUDGET_ITEMS / (size_t)sub_batch_cols; + if ((size_t)target_rows > budget_cap_rows) + target_rows = (int)budget_cap_rows; + + int cur_first = 0; + int cur_rows = 0; + size_t cur_nnz = 0; + for (int g = 0; g < n_test; g++) { + int n_g = h_grp_offsets[g + 1] - h_grp_offsets[g]; + size_t nnz_g = (size_t)(h_grp_indptr_compact[h_grp_offsets[g + 1]] - + h_grp_indptr_compact[h_grp_offsets[g]]); + int new_rows = cur_rows + n_g; + bool can_add = (cur_rows == 0) || (new_rows <= target_rows); + if (!can_add) { + size_t sb_size = + std::min((size_t)n_cols, + GROUP_DENSE_BUDGET_ITEMS / (size_t)cur_rows); + if (sb_size < (size_t)sub_batch_cols) sb_size = sub_batch_cols; + packs.push_back( + {cur_first, g, cur_rows, cur_nnz, (int)sb_size}); + cur_first = g; + cur_rows = n_g; + cur_nnz = nnz_g; + } else { + cur_rows = new_rows; + cur_nnz += nnz_g; + } + } + if (cur_rows > 0) { + size_t sb_size = std::min( + (size_t)n_cols, GROUP_DENSE_BUDGET_ITEMS / (size_t)cur_rows); + if (sb_size < (size_t)sub_batch_cols) sb_size = sub_batch_cols; + packs.push_back( + {cur_first, n_test, cur_rows, cur_nnz, (int)sb_size}); + } + } + for (const Pack& pk : packs) { + int K = pk.end - pk.first; + if (pk.n_rows > max_pack_rows) max_pack_rows = pk.n_rows; + if (pk.nnz > max_pack_nnz) max_pack_nnz = pk.nnz; + if (K > max_pack_K) max_pack_K = K; + int pack_items = pk.n_rows * pk.sb_cols; + if (pack_items > max_pack_items) max_pack_items = pack_items; + if (pk.sb_cols > max_pack_sb_cols) max_pack_sb_cols = pk.sb_cols; + } + int max_group_rows = max_pack_rows; + size_t max_sub_items = (size_t)max_pack_items; + if (max_pack_rows == 0) return; + + RmmPool pool; + + // Zero stats outputs. + cudaMemsetAsync(d_group_sums, 0, + (size_t)n_groups_stats * n_cols * sizeof(double)); + if (compute_sq_sums) { + cudaMemsetAsync(d_group_sq_sums, 0, + (size_t)n_groups_stats * n_cols * sizeof(double)); + } + if (compute_nnz) { + cudaMemsetAsync(d_group_nnz, 0, + (size_t)n_groups_stats * n_cols * sizeof(double)); + } + + // ---- Pin full host data + indices as MAPPED (zero-copy accessible) ---- + size_t full_nnz = (size_t)h_indptr[n_full_rows]; + HostRegisterGuard _pin_data(const_cast(h_data), + full_nnz * sizeof(InT), cudaHostRegisterMapped); + HostRegisterGuard _pin_indices(const_cast(h_indices), + full_nnz * sizeof(IndexT), + cudaHostRegisterMapped); + + // Get device-accessible pointers (UVA makes these equal to host ptrs on + // Linux x86-64, but the API is the safe/portable way). + InT* d_data_zc = nullptr; + IndexT* d_indices_zc = nullptr; + if (full_nnz > 0) { + cudaError_t e1 = cudaHostGetDevicePointer((void**)&d_data_zc, + const_cast(h_data), 0); + cudaError_t e2 = cudaHostGetDevicePointer( + (void**)&d_indices_zc, const_cast(h_indices), 0); + if (e1 != cudaSuccess || e2 != cudaSuccess) { + throw std::runtime_error( + std::string("cudaHostGetDevicePointer failed: ") + + cudaGetErrorString(e1 != cudaSuccess ? e1 : e2)); + } + } + + // ---- Upload full indptr (keep native IndptrT — can exceed int32) ---- + IndptrT* d_indptr_full = pool.alloc(n_full_rows + 1); + cudaMemcpy(d_indptr_full, h_indptr, (n_full_rows + 1) * sizeof(IndptrT), + cudaMemcpyHostToDevice); + + // ---- Upload row_ids + compacted indptrs + group boundaries ---- + int* d_ref_row_ids = pool.alloc(n_ref); + int* d_grp_row_ids = pool.alloc(n_all_grp); + IndptrT* d_grp_indptr_compact = pool.alloc(n_all_grp + 1); + int* d_grp_offsets_full = pool.alloc(n_test + 1); + cudaMemcpy(d_ref_row_ids, h_ref_row_ids, n_ref * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_grp_row_ids, h_grp_row_ids, n_all_grp * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_grp_indptr_compact, h_grp_indptr_compact.data(), + (n_all_grp + 1) * sizeof(IndptrT), cudaMemcpyHostToDevice); + cudaMemcpy(d_grp_offsets_full, h_grp_offsets, (n_test + 1) * sizeof(int), + cudaMemcpyHostToDevice); + + // ---- Phase 1: Ref setup (scoped scratch, ref_sorted persists) ---- + float* d_ref_sorted = pool.alloc((size_t)n_ref * n_cols); + { + rmm::device_buffer ref_data_f32_buf(ref_nnz * sizeof(float), + rmm::cuda_stream_default, pool.mr); + rmm::device_buffer ref_indices_buf(ref_nnz * sizeof(int), + rmm::cuda_stream_default, pool.mr); + rmm::device_buffer ref_indptr_buf((n_ref + 1) * sizeof(int), + rmm::cuda_stream_default, pool.mr); + rmm::device_buffer ref_dense_buf((size_t)n_ref * n_cols * sizeof(float), + rmm::cuda_stream_default, pool.mr); + rmm::device_buffer ref_seg_buf((n_cols + 1) * sizeof(int), + rmm::cuda_stream_default, pool.mr); + + float* d_ref_data_f32 = (float*)ref_data_f32_buf.data(); + int* d_ref_indices = (int*)ref_indices_buf.data(); + int* d_ref_indptr = (int*)ref_indptr_buf.data(); + float* d_ref_dense = (float*)ref_dense_buf.data(); + int* d_ref_seg = (int*)ref_seg_buf.data(); + + // Upload ref compacted indptr + cudaMemcpy(d_ref_indptr, h_ref_indptr_compact.data(), + (n_ref + 1) * sizeof(int), cudaMemcpyHostToDevice); + + // Fused gather + cast + stats for ref (fixed slot = n_test). One + // pass over PCIe, no intermediate native-dtype GPU buffer. + if (n_ref > 0 && ref_nnz > 0) { + csr_gather_cast_accumulate_mapped_kernel + <<>>( + d_data_zc, d_indices_zc, d_indptr_full, d_ref_row_ids, + d_ref_indptr, /*d_stats_codes=*/nullptr, + /*fixed_slot=*/n_test, d_ref_data_f32, d_ref_indices, + d_group_sums, d_group_sq_sums, d_group_nnz, n_ref, n_cols, + n_groups_stats, compute_sq_sums, compute_nnz); + CUDA_CHECK_LAST_ERROR(csr_gather_cast_accumulate_mapped_kernel); + } + + // Extract ref dense (F-order) — identity row-ids + cudaMemsetAsync(d_ref_dense, 0, (size_t)n_ref * n_cols * sizeof(float)); + { + rmm::device_buffer ref_row_ids_buf( + n_ref * sizeof(int), rmm::cuda_stream_default, pool.mr); + int* d_ref_identity = (int*)ref_row_ids_buf.data(); + fill_linear_offsets_kernel<<<(n_ref + UTIL_BLOCK_SIZE - 1) / + UTIL_BLOCK_SIZE, + UTIL_BLOCK_SIZE>>>(d_ref_identity, + n_ref - 1, 1); + CUDA_CHECK_LAST_ERROR(fill_linear_offsets_kernel); + int tpb = round_up_to_warp(n_ref); + int blk = (n_ref + tpb - 1) / tpb; + csr_extract_dense_kernel + <<>>(d_ref_data_f32, d_ref_indices, d_ref_indptr, + d_ref_identity, d_ref_dense, n_ref, 0, n_cols); + CUDA_CHECK_LAST_ERROR(csr_extract_dense_kernel); + } + + // Segmented sort ref_dense by column → ref_sorted + size_t ref_cub_bytes = 0; + { + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, ref_cub_bytes, fk, fk, (int)((size_t)n_ref * n_cols), + n_cols, doff, doff + 1, BEGIN_BIT, END_BIT); + } + rmm::device_buffer cub_temp_buf(ref_cub_bytes, rmm::cuda_stream_default, + pool.mr); + upload_linear_offsets(d_ref_seg, n_cols, n_ref, 0); + size_t temp = ref_cub_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + cub_temp_buf.data(), temp, d_ref_dense, d_ref_sorted, + (int)((size_t)n_ref * n_cols), n_cols, d_ref_seg, d_ref_seg + 1, + BEGIN_BIT, END_BIT); + cudaDeviceSynchronize(); + } // ref scratch drops here + + // Identity row-ids [0..max_group_rows-1] — shared across all packs. + int* d_identity_row_ids = pool.alloc(max_group_rows); + fill_linear_offsets_kernel<<<(max_group_rows + UTIL_BLOCK_SIZE - 1) / + UTIL_BLOCK_SIZE, + UTIL_BLOCK_SIZE>>>(d_identity_row_ids, + max_group_rows - 1, 1); + CUDA_CHECK_LAST_ERROR(fill_linear_offsets_kernel); + + // ---- Phase 2: Per-pack streaming ---- + auto t1 = make_tier1_config(h_grp_offsets, n_test); + bool may_need_cub = (t1.max_grp_size > TIER1_GROUP_THRESHOLD); + + constexpr int MAX_GROUP_STREAMS = 4; + int n_streams = MAX_GROUP_STREAMS; + if (n_test < n_streams) n_streams = n_test; + if (n_streams < 1) n_streams = 1; + if ((int)packs.size() < n_streams) n_streams = (int)packs.size(); + if (n_streams < 1) n_streams = 1; + + size_t cub_grp_bytes = 0; + if (may_need_cub && max_sub_items > 0) { + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + int max_segments = max_pack_K * max_pack_sb_cols; + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, cub_grp_bytes, fk, fk, (int)max_sub_items, max_segments, + doff, doff + 1, BEGIN_BIT, END_BIT); + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + struct StreamBuf { + float* d_grp_data_f32; + int* d_grp_indices; + int* d_grp_indptr; + int* d_pack_grp_offsets; + int* d_pack_stats_codes; + float* d_grp_dense; + float* d_grp_sorted; + double* d_ref_tie_sums; + int* d_sort_group_ids; + int* d_grp_seg_offsets; + int* d_grp_seg_ends; + uint8_t* cub_temp; + double* d_rank_sums; + double* d_tie_corr; + }; + std::vector bufs(n_streams); + int max_pack_kernel_seg = max_pack_K * max_pack_sb_cols; + for (int s = 0; s < n_streams; s++) { + bufs[s].d_grp_data_f32 = pool.alloc(max_pack_nnz); + bufs[s].d_grp_indices = pool.alloc(max_pack_nnz); + bufs[s].d_grp_indptr = pool.alloc(max_pack_rows + 1); + bufs[s].d_pack_grp_offsets = pool.alloc(max_pack_K + 1); + bufs[s].d_pack_stats_codes = pool.alloc(max_pack_rows); + bufs[s].d_grp_dense = pool.alloc(max_sub_items); + bufs[s].d_ref_tie_sums = pool.alloc(max_pack_sb_cols); + bufs[s].d_rank_sums = + pool.alloc((size_t)max_pack_K * max_pack_sb_cols); + bufs[s].d_tie_corr = + pool.alloc((size_t)max_pack_K * max_pack_sb_cols); + if (may_need_cub) { + bufs[s].d_grp_sorted = pool.alloc(max_sub_items); + bufs[s].d_sort_group_ids = pool.alloc(max_pack_K); + bufs[s].d_grp_seg_offsets = pool.alloc(max_pack_kernel_seg); + bufs[s].d_grp_seg_ends = pool.alloc(max_pack_kernel_seg); + bufs[s].cub_temp = pool.alloc(cub_grp_bytes); + } else { + bufs[s].d_grp_sorted = nullptr; + bufs[s].d_sort_group_ids = nullptr; + bufs[s].d_grp_seg_offsets = nullptr; + bufs[s].d_grp_seg_ends = nullptr; + bufs[s].cub_temp = nullptr; + } + } + + cudaDeviceSynchronize(); // ensure Phase 1 done before Phase 2 streams + + for (int p = 0; p < (int)packs.size(); p++) { + const Pack& pack = packs[p]; + int K = pack.end - pack.first; + if (K == 0 || pack.n_rows == 0) continue; + Tier1Config pack_t1 = make_tier1_config(h_grp_offsets + pack.first, K); + int pack_tpb_rank = round_up_to_warp( + std::min(pack_t1.max_grp_size, MAX_THREADS_PER_BLOCK)); + bool pack_has_above_t2 = pack_t1.max_grp_size > TIER2_GROUP_THRESHOLD; + int pack_tier3_skip_le = + pack_has_above_t2 ? TIER2_GROUP_THRESHOLD : TIER0_GROUP_THRESHOLD; + std::vector h_sort_group_ids; + int pack_n_sort_groups = K; + if (pack_t1.any_above_t0 && !pack_t1.use_tier1) { + h_sort_group_ids = make_sort_group_ids(h_grp_offsets + pack.first, + K, pack_tier3_skip_le); + pack_n_sort_groups = (int)h_sort_group_ids.size(); + } + + int s = p % n_streams; + cudaStream_t stream = streams[s]; + auto& buf = bufs[s]; + + if (pack_t1.any_above_t0 && !pack_t1.use_tier1) { + cudaMemcpyAsync(buf.d_sort_group_ids, h_sort_group_ids.data(), + h_sort_group_ids.size() * sizeof(int), + cudaMemcpyHostToDevice, stream); + } + + int row_start = h_grp_offsets[pack.first]; + int pack_rows = pack.n_rows; + int pack_sb = pack.sb_cols; + + // Rebase pack's output indptr from pre-uploaded global compacted indptr + // (IndptrT → int32: pack nnz is bounded by GROUP_DENSE_BUDGET so fits). + { + int count = pack_rows + 1; + int blk = (count + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + rebase_indptr_kernel + <<>>( + d_grp_indptr_compact, buf.d_grp_indptr, row_start, count); + CUDA_CHECK_LAST_ERROR(rebase_indptr_kernel); + } + + // Build per-pack group offsets on GPU (on this stream) — needed to + // compute stats codes before the fused gather kernel can run. + { + int count = K + 1; + int blk = (count + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + rebase_indptr_kernel<<>>( + d_grp_offsets_full, buf.d_pack_grp_offsets, pack.first, count); + CUDA_CHECK_LAST_ERROR(rebase_indptr_kernel); + } + + // Fill per-row stats codes for this pack + { + int blk = (pack_rows + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + fill_pack_stats_codes_kernel<<>>( + buf.d_pack_grp_offsets, buf.d_pack_stats_codes, K, pack.first); + CUDA_CHECK_LAST_ERROR(fill_pack_stats_codes_kernel); + } + + // Fused gather + cast + stats for the pack. One pass over PCIe + // (reads mapped host via UVA), no intermediate native-dtype GPU + // buffer, writes f32 + indices + atomics. + if (pack.nnz > 0) { + csr_gather_cast_accumulate_mapped_kernel + <<>>( + d_data_zc, d_indices_zc, d_indptr_full, + d_grp_row_ids + row_start, buf.d_grp_indptr, + buf.d_pack_stats_codes, /*fixed_slot=*/-1, + buf.d_grp_data_f32, buf.d_grp_indices, d_group_sums, + d_group_sq_sums, d_group_nnz, pack_rows, n_cols, + n_groups_stats, compute_sq_sums, compute_nnz); + CUDA_CHECK_LAST_ERROR(csr_gather_cast_accumulate_mapped_kernel); + } + + // Per col sub-batch + int col = 0; + while (col < n_cols) { + int sb_cols = std::min(pack_sb, n_cols - col); + int sb_items = pack_rows * sb_cols; + + cudaMemsetAsync(buf.d_grp_dense, 0, sb_items * sizeof(float), + stream); + { + int tpb = round_up_to_warp(pack_rows); + int blk = (pack_rows + tpb - 1) / tpb; + csr_extract_dense_kernel<<>>( + buf.d_grp_data_f32, buf.d_grp_indices, buf.d_grp_indptr, + d_identity_row_ids, buf.d_grp_dense, pack_rows, col, + col + sb_cols); + CUDA_CHECK_LAST_ERROR(csr_extract_dense_kernel); + } + + const float* ref_sub = d_ref_sorted + (size_t)col * n_ref; + + int skip_le = 0; + if (pack_t1.use_tier0) { + launch_tier0(ref_sub, buf.d_grp_dense, buf.d_pack_grp_offsets, + buf.d_rank_sums, buf.d_tie_corr, n_ref, pack_rows, + sb_cols, K, compute_tie_corr, stream); + if (pack_t1.any_above_t0) skip_le = TIER0_GROUP_THRESHOLD; + } + if (pack_t1.any_above_t0) { + if (compute_tie_corr) { + launch_ref_tie_sums(ref_sub, buf.d_ref_tie_sums, n_ref, + sb_cols, stream); + } + launch_tier2_medium( + ref_sub, buf.d_grp_dense, buf.d_pack_grp_offsets, + buf.d_ref_tie_sums, buf.d_rank_sums, buf.d_tie_corr, n_ref, + pack_rows, sb_cols, K, compute_tie_corr, skip_le, stream); + } + + int upper_skip_le = + pack_has_above_t2 ? TIER2_GROUP_THRESHOLD : skip_le; + if (pack_has_above_t2 && pack_t1.use_tier1) { + dim3 grid(sb_cols, K); + ovo_fused_sort_rank_kernel<<>>( + ref_sub, buf.d_grp_dense, buf.d_pack_grp_offsets, + buf.d_rank_sums, buf.d_tie_corr, n_ref, pack_rows, sb_cols, + K, compute_tie_corr, pack_t1.padded_grp_size, + upper_skip_le); + CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); + } else if (pack_has_above_t2) { + int n_seg = pack_n_sort_groups * sb_cols; + { + int blk = (n_seg + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + build_tier3_seg_begin_end_offsets_kernel<<< + blk, UTIL_BLOCK_SIZE, 0, stream>>>( + buf.d_pack_grp_offsets, buf.d_sort_group_ids, + buf.d_grp_seg_offsets, buf.d_grp_seg_ends, pack_rows, + pack_n_sort_groups, sb_cols); + CUDA_CHECK_LAST_ERROR( + build_tier3_seg_begin_end_offsets_kernel); + } + { + size_t temp = cub_grp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.cub_temp, temp, buf.d_grp_dense, buf.d_grp_sorted, + sb_items, n_seg, buf.d_grp_seg_offsets, + buf.d_grp_seg_ends, BEGIN_BIT, END_BIT, stream); + } + dim3 grid(sb_cols, K); + batched_rank_sums_presorted_kernel<<>>( + ref_sub, buf.d_grp_sorted, buf.d_pack_grp_offsets, + buf.d_rank_sums, buf.d_tie_corr, n_ref, pack_rows, sb_cols, + K, compute_tie_corr, upper_skip_le); + CUDA_CHECK_LAST_ERROR(batched_rank_sums_presorted_kernel); + } + + cudaMemcpy2DAsync(d_rank_sums + (size_t)pack.first * n_cols + col, + n_cols * sizeof(double), buf.d_rank_sums, + sb_cols * sizeof(double), + sb_cols * sizeof(double), K, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpy2DAsync( + d_tie_corr + (size_t)pack.first * n_cols + col, + n_cols * sizeof(double), buf.d_tie_corr, + sb_cols * sizeof(double), sb_cols * sizeof(double), K, + cudaMemcpyDeviceToDevice, stream); + } + + col += sb_cols; + } + } + + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in ovo csr host streaming: ") + + cudaGetErrorString(err)); + } + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh new file mode 100644 index 00000000..d508b9c2 --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh @@ -0,0 +1,165 @@ +#pragma once + +/** + * Build CUB segmented-sort ranges only for groups that Tier 3 will rank. + * Group ids are relative to grp_offsets, and ranges still point into the + * original dense group layout so the presorted rank kernel can read from the + * normal per-group positions. + */ +__global__ void build_tier3_seg_begin_end_offsets_kernel( + const int* __restrict__ grp_offsets, const int* __restrict__ group_ids, + int* __restrict__ begins, int* __restrict__ ends, int n_all_grp, + int n_sort_groups, int sb_cols) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int total = sb_cols * n_sort_groups; + if (idx >= total) return; + + int c = idx / n_sort_groups; + int local = idx % n_sort_groups; + int g = group_ids[local]; + int base = c * n_all_grp; + begins[idx] = base + grp_offsets[g]; + ends[idx] = base + grp_offsets[g + 1]; +} + +/** + * Extract specific rows from CSC into dense F-order, using a row lookup map. + * row_map[original_row] = output_row_index (or -1 to skip). + * One block per column, threads scatter matching nonzeros. + * Output must be pre-zeroed. + */ +template +__global__ void csc_extract_mapped_kernel(const float* __restrict__ data, + const IndexT* __restrict__ indices, + const int* __restrict__ indptr, + const int* __restrict__ row_map, + float* __restrict__ out, int n_target, + int col_start) { + int col_local = blockIdx.x; + int col = col_start + col_local; + + int start = indptr[col]; + int end = indptr[col + 1]; + + for (int p = start + threadIdx.x; p < end; p += blockDim.x) { + int out_row = row_map[(int)indices[p]]; + if (out_row >= 0) { + out[(long long)col_local * n_target + out_row] = data[p]; + } + } +} + +static size_t get_seg_sort_temp_bytes(int n_items, int n_segments) { + size_t bytes = 0; + auto* dk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys(nullptr, bytes, dk, dk, n_items, + n_segments, doff, doff + 1, 0, 32); + return bytes; +} + +/** + * Tier 1 dispatch: when the largest group fits in shared memory, a fused + * bitonic-sort + binary-search kernel handles the whole group per block. + * Otherwise we fall back to CUB segmented sort plus the pre-sorted rank + * kernel. This struct bundles the sizing knobs derived from the host-side + * group offsets so each streaming impl can drop a 15-line prep block. + */ +struct Tier1Config { + int max_grp_size = 0; + int min_grp_size = 0; + bool use_tier0 = + false; // any group fits in one warp (≤ TIER0_GROUP_THRESHOLD) + bool use_tier1 = + false; // any group needs > tier0 but fits in tier1 smem sort + bool any_above_t0 = + false; // at least one group exceeds TIER0_GROUP_THRESHOLD + bool any_tier2 = false; // any group needs Tier 2: (T0, T2] + bool any_above_t2 = + false; // at least one group exceeds TIER2_GROUP_THRESHOLD + int padded_grp_size = 0; + int tier1_tpb = 0; + size_t tier1_smem = 0; +}; + +static Tier1Config make_tier1_config(const int* h_grp_offsets, int n_groups) { + Tier1Config c; + c.min_grp_size = INT_MAX; + for (int g = 0; g < n_groups; g++) { + int sz = h_grp_offsets[g + 1] - h_grp_offsets[g]; + if (sz > c.max_grp_size) c.max_grp_size = sz; + if (sz < c.min_grp_size) c.min_grp_size = sz; + if (sz > TIER0_GROUP_THRESHOLD && sz <= TIER2_GROUP_THRESHOLD) { + c.any_tier2 = true; + } + if (sz > TIER2_GROUP_THRESHOLD) c.any_above_t2 = true; + } + if (n_groups == 0) c.min_grp_size = 0; + + // use_tier0: Tier 0 kernel is worth running (at least one group small + // enough to benefit from the warp path). + c.use_tier0 = (c.min_grp_size <= TIER0_GROUP_THRESHOLD); + // any_above_t0: at least one group needs a non-Tier-0 kernel. + c.any_above_t0 = (c.max_grp_size > TIER0_GROUP_THRESHOLD); + // use_tier1: the fused smem-sort fast path (for groups > T0 but ≤ T1). + c.use_tier1 = c.any_above_t0 && (c.max_grp_size <= TIER1_GROUP_THRESHOLD); + if (c.use_tier1) { + c.padded_grp_size = 1; + while (c.padded_grp_size < c.max_grp_size) c.padded_grp_size <<= 1; + c.tier1_tpb = std::min(c.padded_grp_size, MAX_THREADS_PER_BLOCK); + c.tier1_smem = (size_t)c.padded_grp_size * sizeof(float) + + WARP_REDUCE_BUF * sizeof(double); + } + return c; +} + +static std::vector make_sort_group_ids(const int* h_grp_offsets, + int n_groups, int skip_n_grp_le) { + std::vector ids; + ids.reserve(n_groups); + for (int g = 0; g < n_groups; ++g) { + int sz = h_grp_offsets[g + 1] - h_grp_offsets[g]; + if (skip_n_grp_le > 0 && sz <= skip_n_grp_le) continue; + ids.push_back(g); + } + return ids; +} + +// Tier 0 kernel launcher: 8 warps × 32 threads per block, one (col, group) +// pair per warp. grid.y covers ceil(K/8) pair rows. +static inline void launch_tier0(const float* ref_sorted, const float* grp_dense, + const int* grp_offsets, double* rank_sums, + double* tie_corr, int n_ref, int n_all_grp, + int sb_cols, int K, bool compute_tie_corr, + cudaStream_t stream) { + constexpr int WARPS_PER_BLOCK = 8; + dim3 grid(sb_cols, (K + WARPS_PER_BLOCK - 1) / WARPS_PER_BLOCK); + ovo_warp_sort_rank_kernel<<>>( + ref_sorted, grp_dense, grp_offsets, rank_sums, tie_corr, n_ref, + n_all_grp, sb_cols, K, compute_tie_corr); + CUDA_CHECK_LAST_ERROR(ovo_warp_sort_rank_kernel); +} + +static inline void launch_ref_tie_sums(const float* ref_sorted, + double* ref_tie_sums, int n_ref, + int sb_cols, cudaStream_t stream) { + ref_tie_sum_kernel<<>>( + ref_sorted, ref_tie_sums, n_ref, sb_cols); + CUDA_CHECK_LAST_ERROR(ref_tie_sum_kernel); +} + +static inline void launch_tier2_medium( + const float* ref_sorted, const float* grp_dense, const int* grp_offsets, + const double* ref_tie_sums, double* rank_sums, double* tie_corr, int n_ref, + int n_all_grp, int sb_cols, int K, bool compute_tie_corr, int skip_n_grp_le, + cudaStream_t stream) { + constexpr int tpb = 256; + size_t smem = (size_t)TIER2_GROUP_THRESHOLD * sizeof(float) + + WARP_REDUCE_BUF * sizeof(double); + dim3 grid(sb_cols, K); + ovo_medium_unsorted_rank_kernel<<>>( + ref_sorted, grp_dense, grp_offsets, ref_tie_sums, rank_sums, tie_corr, + n_ref, n_all_grp, sb_cols, K, compute_tie_corr, skip_n_grp_le, + TIER2_GROUP_THRESHOLD); + CUDA_CHECK_LAST_ERROR(ovo_medium_unsorted_rank_kernel); +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr.cu new file mode 100644 index 00000000..e5661126 --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr.cu @@ -0,0 +1,19 @@ +#include +#include + +#include + +#include "../nb_types.h" +#include "wilcoxon_common.cuh" +#include "kernels_wilcoxon.cuh" + +using namespace nb::literals; + +#include "wilcoxon_ovr_kernels.cuh" +#include "wilcoxon_ovr_dense.cuh" +#include "wilcoxon_ovr_sparse.cuh" +#include "wilcoxon_ovr_bindings.cuh" + +NB_MODULE(_wilcoxon_ovr_cuda, m) { + REGISTER_GPU_BINDINGS(register_bindings, m); +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_bindings.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_bindings.cuh new file mode 100644 index 00000000..72ae1938 --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_bindings.cuh @@ -0,0 +1,183 @@ +#pragma once + +// ============================================================================ +// Nanobind module +// ============================================================================ + +template +void register_bindings(nb::module_& m) { + m.doc() = "CUDA kernels for Wilcoxon rank-sum test (OVR)"; + + // ---- Streaming pipelines ---- + + m.def( + "ovr_streaming", + [](gpu_array_f block, + gpu_array_c group_codes, + gpu_array_c rank_sums, + gpu_array_c tie_corr, int n_rows, int n_cols, + int n_groups, bool compute_tie_corr, int sub_batch_cols) { + ovr_streaming_impl(block.data(), group_codes.data(), + rank_sums.data(), tie_corr.data(), n_rows, + n_cols, n_groups, compute_tie_corr, + sub_batch_cols); + }, + "block"_a, "group_codes"_a, "rank_sums"_a, "tie_corr"_a, nb::kw_only(), + "n_rows"_a, "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, + "sub_batch_cols"_a = SUB_BATCH_COLS); + + m.def( + "ovr_sparse_csr", + [](gpu_array_c csr_data, + gpu_array_c csr_indices, + gpu_array_c csr_indptr, + gpu_array_c group_codes, + gpu_array_c group_sizes, + gpu_array_c rank_sums, + gpu_array_c tie_corr, int n_rows, int n_cols, + int n_groups, bool compute_tie_corr, int sub_batch_cols) { + ovr_sparse_csr_streaming_impl( + csr_data.data(), csr_indices.data(), csr_indptr.data(), + group_codes.data(), group_sizes.data(), rank_sums.data(), + tie_corr.data(), n_rows, n_cols, n_groups, compute_tie_corr, + sub_batch_cols); + }, + "csr_data"_a, "csr_indices"_a, "csr_indptr"_a, "group_codes"_a, + "group_sizes"_a, "rank_sums"_a, "tie_corr"_a, nb::kw_only(), "n_rows"_a, + "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, + "sub_batch_cols"_a = SUB_BATCH_COLS); + + m.def( + "ovr_sparse_csc", + [](gpu_array_c csc_data, + gpu_array_c csc_indices, + gpu_array_c csc_indptr, + gpu_array_c group_codes, + gpu_array_c group_sizes, + gpu_array_c rank_sums, + gpu_array_c tie_corr, int n_rows, int n_cols, + int n_groups, bool compute_tie_corr, int sub_batch_cols) { + ovr_sparse_csc_streaming_impl( + csc_data.data(), csc_indices.data(), csc_indptr.data(), + group_codes.data(), group_sizes.data(), rank_sums.data(), + tie_corr.data(), n_rows, n_cols, n_groups, compute_tie_corr, + sub_batch_cols); + }, + "csc_data"_a, "csc_indices"_a, "csc_indptr"_a, "group_codes"_a, + "group_sizes"_a, "rank_sums"_a, "tie_corr"_a, nb::kw_only(), "n_rows"_a, + "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, + "sub_batch_cols"_a = SUB_BATCH_COLS); + + // ---- Host-streaming pipelines (host inputs, device outputs) ---- + +#define RSC_OVR_SPARSE_CSC_HOST_BINDING(NAME, InT, IndptrT) \ + m.def( \ + NAME, \ + [](host_array h_data, host_array h_indices, \ + host_array h_indptr, \ + host_array h_group_codes, \ + host_array h_group_sizes, \ + gpu_array_c d_rank_sums, \ + gpu_array_c d_tie_corr, \ + gpu_array_c d_group_sums, \ + gpu_array_c d_group_sq_sums, \ + gpu_array_c d_group_nnz, int n_rows, int n_cols, \ + int n_groups, bool compute_tie_corr, bool compute_sq_sums, \ + bool compute_nnz, int sub_batch_cols) { \ + ovr_sparse_csc_host_streaming_impl( \ + h_data.data(), h_indices.data(), h_indptr.data(), \ + h_group_codes.data(), h_group_sizes.data(), \ + d_rank_sums.data(), d_tie_corr.data(), d_group_sums.data(), \ + d_group_sq_sums.data(), d_group_nnz.data(), n_rows, n_cols, \ + n_groups, compute_tie_corr, compute_sq_sums, compute_nnz, \ + sub_batch_cols); \ + }, \ + "h_data"_a, "h_indices"_a, "h_indptr"_a, "h_group_codes"_a, \ + "h_group_sizes"_a, "d_rank_sums"_a, "d_tie_corr"_a, "d_group_sums"_a, \ + "d_group_sq_sums"_a, "d_group_nnz"_a, nb::kw_only(), "n_rows"_a, \ + "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, \ + "compute_sq_sums"_a = true, "compute_nnz"_a = true, \ + "sub_batch_cols"_a = SUB_BATCH_COLS) + + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host", float, int); + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_i64", float, int64_t); + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_f64", double, int); + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_f64_i64", double, + int64_t); +#undef RSC_OVR_SPARSE_CSC_HOST_BINDING + +#define RSC_OVR_SPARSE_CSR_HOST_BINDING(NAME, InT, IndexT, IndptrT) \ + m.def( \ + NAME, \ + [](host_array h_data, host_array h_indices, \ + host_array h_indptr, \ + host_array h_group_codes, \ + host_array h_group_sizes, \ + gpu_array_c d_rank_sums, \ + gpu_array_c d_tie_corr, \ + gpu_array_c d_group_sums, \ + gpu_array_c d_group_sq_sums, \ + gpu_array_c d_group_nnz, int n_rows, int n_cols, \ + int n_groups, bool compute_tie_corr, bool compute_sq_sums, \ + bool compute_nnz, int sub_batch_cols) { \ + ovr_sparse_csr_host_streaming_impl( \ + h_data.data(), h_indices.data(), h_indptr.data(), \ + h_group_codes.data(), h_group_sizes.data(), \ + d_rank_sums.data(), d_tie_corr.data(), d_group_sums.data(), \ + d_group_sq_sums.data(), d_group_nnz.data(), n_rows, n_cols, \ + n_groups, compute_tie_corr, compute_sq_sums, compute_nnz, \ + sub_batch_cols); \ + }, \ + "h_data"_a, "h_indices"_a, "h_indptr"_a, "h_group_codes"_a, \ + "h_group_sizes"_a, "d_rank_sums"_a, "d_tie_corr"_a, "d_group_sums"_a, \ + "d_group_sq_sums"_a, "d_group_nnz"_a, nb::kw_only(), "n_rows"_a, \ + "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, \ + "compute_sq_sums"_a = true, "compute_nnz"_a = true, \ + "sub_batch_cols"_a = SUB_BATCH_COLS) + + RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host", float, int, int); + RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_i64", float, int, + int64_t); + RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_idx64", float, int64_t, + int); + RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_idx64_i64", float, + int64_t, int64_t); + RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_f64", double, int, + int); + RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_f64_i64", double, int, + int64_t); + RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_f64_idx64", double, + int64_t, int); + RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_f64_idx64_i64", double, + int64_t, int64_t); +#undef RSC_OVR_SPARSE_CSR_HOST_BINDING + +#define RSC_OVR_DENSE_HOST_BINDING(NAME, InT) \ + m.def( \ + NAME, \ + [](host_array_2d h_block, \ + host_array h_group_codes, \ + gpu_array_c d_rank_sums, \ + gpu_array_c d_tie_corr, \ + gpu_array_c d_group_sums, \ + gpu_array_c d_group_sq_sums, \ + gpu_array_c d_group_nnz, int n_rows, int n_cols, \ + int n_groups, bool compute_tie_corr, bool compute_sq_sums, \ + bool compute_nnz, int sub_batch_cols) { \ + ovr_streaming_dense_host_impl( \ + h_block.data(), h_group_codes.data(), d_rank_sums.data(), \ + d_tie_corr.data(), d_group_sums.data(), \ + d_group_sq_sums.data(), d_group_nnz.data(), n_rows, n_cols, \ + n_groups, compute_tie_corr, compute_sq_sums, compute_nnz, \ + sub_batch_cols); \ + }, \ + "h_block"_a, "h_group_codes"_a, "d_rank_sums"_a, "d_tie_corr"_a, \ + "d_group_sums"_a, "d_group_sq_sums"_a, "d_group_nnz"_a, nb::kw_only(), \ + "n_rows"_a, "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, \ + "compute_sq_sums"_a = true, "compute_nnz"_a = true, \ + "sub_batch_cols"_a = SUB_BATCH_COLS) + + RSC_OVR_DENSE_HOST_BINDING("ovr_streaming_dense_host", float); + RSC_OVR_DENSE_HOST_BINDING("ovr_streaming_dense_host_f64", double); +#undef RSC_OVR_DENSE_HOST_BINDING +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_dense.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_dense.cuh new file mode 100644 index 00000000..2adb5b7b --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_dense.cuh @@ -0,0 +1,313 @@ +#pragma once + +/** + * Streaming OVR pipeline. + * + * Takes a dense F-order float32 block (n_rows, n_cols) + int32 group_codes, + * splits columns into sub-batches across multiple CUDA streams, and for each: + * 1. CUB SortPairs (float32 keys + int32 row indices) + * 2. Fused rank_sums_from_sorted_kernel + * + * Output: rank_sums (n_groups, n_cols) + tie_corr (n_cols), both float64. + */ +static void ovr_streaming_impl(const float* block, const int* group_codes, + double* rank_sums, double* tie_corr, int n_rows, + int n_cols, int n_groups, bool compute_tie_corr, + int sub_batch_cols) { + if (n_rows == 0 || n_cols == 0) return; + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) + n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + + size_t sub_items = (size_t)n_rows * sub_batch_cols; + size_t cub_temp_bytes = 0; + { + auto* fk = reinterpret_cast(1); + auto* iv = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortPairs( + nullptr, cub_temp_bytes, fk, fk, iv, iv, (int)sub_items, + sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); + } + + // Create streams + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + // Allocate per-stream buffers via RMM pool + RmmPool pool; + struct StreamBuf { + float* keys_out; + int* vals_in; + int* vals_out; + int* seg_offsets; + uint8_t* cub_temp; + double* sub_rank_sums; + double* sub_tie_corr; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].keys_out = pool.alloc(sub_items); + bufs[s].vals_in = pool.alloc(sub_items); + bufs[s].vals_out = pool.alloc(sub_items); + bufs[s].seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = pool.alloc(sub_batch_cols); + } + + int tpb_rank = round_up_to_warp(n_rows); + bool use_gmem = false; + size_t smem_rank = ovr_smem_config(n_groups, use_gmem); + + // Process sub-batches round-robin across streams + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int sb_items = n_rows * sb_cols; + int s = batch_idx % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + + // Fill segment offsets + row indices + upload_linear_offsets(buf.seg_offsets, sb_cols, n_rows, stream); + fill_row_indices_kernel<<>>( + buf.vals_in, n_rows, sb_cols); + CUDA_CHECK_LAST_ERROR(fill_row_indices_kernel); + + // Sort: keys = block columns [col, col+sb_cols), already F-order + const float* keys_in = block + (long long)col * n_rows; + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortPairs( + buf.cub_temp, temp, keys_in, buf.keys_out, buf.vals_in, + buf.vals_out, sb_items, sb_cols, buf.seg_offsets, + buf.seg_offsets + 1, BEGIN_BIT, END_BIT, stream); + + // Fused rank sums into sub-batch buffer + if (use_gmem) { + cudaMemsetAsync(buf.sub_rank_sums, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + } + rank_sums_from_sorted_kernel<<>>( + buf.keys_out, buf.vals_out, group_codes, buf.sub_rank_sums, + buf.sub_tie_corr, n_rows, sb_cols, n_groups, compute_tie_corr, + use_gmem); + CUDA_CHECK_LAST_ERROR(rank_sums_from_sorted_kernel); + + // Copy sub-batch results to global output (row-major scatter) + // rank_sums is (n_groups, n_cols) row-major: group g, col c → + // [g*n_cols+c] sub output is (n_groups, sb_cols): group g, local col lc + // → [g*sb_cols+lc] + cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpyAsync(tie_corr + col, buf.sub_tie_corr, + sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, + stream); + } + + col += sb_cols; + batch_idx++; + } + + // Sync all streams + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in wilcoxon streaming: ") + + cudaGetErrorString(err)); + } + + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} + +/** + * Host-streaming dense OVR pipeline. + * + * Templated on the host dtype (InT = float or double). Each sub-batch is + * copied to the device in its native dtype once; a fused cast+accumulate + * kernel writes a float32 view for the sort and accumulates per-group + * sum/sum-sq/nnz in float64 from the original-precision values. The + * existing sort + rank pipeline then runs on the float32 keys. + * + * Output pointers ({d_rank_sums, d_tie_corr, d_group_sums, d_group_sq_sums, + * d_group_nnz}) point to caller-provided CuPy memory of the full output + * shape; sub-batch kernels scatter directly into them via D2D. + * + * GPU memory stays at O(sub_batch * n_rows), now with a small extra + * InT-sized sub-batch buffer per stream. + */ +template +static void ovr_streaming_dense_host_impl( + const InT* h_block, const int* h_group_codes, double* d_rank_sums, + double* d_tie_corr, double* d_group_sums, double* d_group_sq_sums, + double* d_group_nnz, int n_rows, int n_cols, int n_groups, + bool compute_tie_corr, bool compute_sq_sums, bool compute_nnz, + int sub_batch_cols) { + if (n_rows == 0 || n_cols == 0) return; + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) + n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + + size_t sub_items = (size_t)n_rows * sub_batch_cols; + size_t cub_temp_bytes = 0; + { + auto* fk = reinterpret_cast(1); + auto* iv = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortPairs( + nullptr, cub_temp_bytes, fk, fk, iv, iv, (int)sub_items, + sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + // Allocate per-stream buffers via RMM pool + RmmPool pool; + int* d_group_codes = pool.alloc(n_rows); + struct StreamBuf { + InT* d_block_orig; + float* d_block_f32; + float* keys_out; + int* vals_in; + int* vals_out; + int* seg_offsets; + uint8_t* cub_temp; + double* d_rank_sums; + double* d_tie_corr; + double* d_group_sums; + double* d_group_sq_sums; + double* d_group_nnz; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].d_block_orig = pool.alloc(sub_items); + bufs[s].d_block_f32 = pool.alloc(sub_items); + bufs[s].keys_out = pool.alloc(sub_items); + bufs[s].vals_in = pool.alloc(sub_items); + bufs[s].vals_out = pool.alloc(sub_items); + bufs[s].seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].d_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].d_tie_corr = pool.alloc(sub_batch_cols); + bufs[s].d_group_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].d_group_sq_sums = + compute_sq_sums + ? pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + bufs[s].d_group_nnz = + compute_nnz ? pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + } + + // Group codes on GPU (transferred once) + cudaMemcpy(d_group_codes, h_group_codes, n_rows * sizeof(int), + cudaMemcpyHostToDevice); + + int tpb_rank = round_up_to_warp(n_rows); + bool use_gmem = false; + size_t smem_rank = ovr_smem_config(n_groups, use_gmem); + int tpb_cast = UTIL_BLOCK_SIZE; + bool cast_use_gmem = false; + size_t smem_cast = cast_accumulate_smem_config(n_groups, compute_sq_sums, + compute_nnz, cast_use_gmem); + + // Pin only the host input. Outputs live on the device (caller-owned). + HostRegisterGuard _pin_block(const_cast(h_block), + (size_t)n_rows * n_cols * sizeof(InT)); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int sb_items = n_rows * sb_cols; + int s = batch_idx % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + + // H2D: column sub-batch in native dtype (F-order → contiguous) + cudaMemcpyAsync(buf.d_block_orig, h_block + (long long)col * n_rows, + sb_items * sizeof(InT), cudaMemcpyHostToDevice, stream); + + // Cast to float32 for sort + accumulate stats in float64 + launch_ovr_cast_and_accumulate_dense( + buf.d_block_orig, buf.d_block_f32, d_group_codes, buf.d_group_sums, + buf.d_group_sq_sums, buf.d_group_nnz, n_rows, sb_cols, n_groups, + compute_sq_sums, compute_nnz, tpb_cast, smem_cast, cast_use_gmem, + stream); + + // Fill segment offsets + row indices + upload_linear_offsets(buf.seg_offsets, sb_cols, n_rows, stream); + fill_row_indices_kernel<<>>( + buf.vals_in, n_rows, sb_cols); + CUDA_CHECK_LAST_ERROR(fill_row_indices_kernel); + + // Sort + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortPairs( + buf.cub_temp, temp, buf.d_block_f32, buf.keys_out, buf.vals_in, + buf.vals_out, sb_items, sb_cols, buf.seg_offsets, + buf.seg_offsets + 1, BEGIN_BIT, END_BIT, stream); + + // Fused rank sums (stats already captured by the cast kernel) + if (use_gmem) { + cudaMemsetAsync(buf.d_rank_sums, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + } + rank_sums_from_sorted_kernel<<>>( + buf.keys_out, buf.vals_out, d_group_codes, buf.d_rank_sums, + buf.d_tie_corr, n_rows, sb_cols, n_groups, compute_tie_corr, + use_gmem); + CUDA_CHECK_LAST_ERROR(rank_sums_from_sorted_kernel); + + // D2D: scatter sub-batch results into the caller's GPU buffers + cudaMemcpy2DAsync(d_rank_sums + col, n_cols * sizeof(double), + buf.d_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpyAsync(d_tie_corr + col, buf.d_tie_corr, + sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, + stream); + } + cudaMemcpy2DAsync(d_group_sums + col, n_cols * sizeof(double), + buf.d_group_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_sq_sums) { + cudaMemcpy2DAsync(d_group_sq_sums + col, n_cols * sizeof(double), + buf.d_group_sq_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + if (compute_nnz) { + cudaMemcpy2DAsync(d_group_nnz + col, n_cols * sizeof(double), + buf.d_group_nnz, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + + col += sb_cols; + batch_idx++; + } + + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in wilcoxon streaming: ") + + cudaGetErrorString(err)); + } + + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh new file mode 100644 index 00000000..006002b9 --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh @@ -0,0 +1,104 @@ +#pragma once + +/** Count nonzeros per column from CSR. One thread per row. */ +template +__global__ void csr_col_histogram_kernel(const IndexT* __restrict__ indices, + const IndptrT* __restrict__ indptr, + int* __restrict__ col_counts, + int n_rows, int n_cols) { + int row = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= n_rows) return; + IndptrT rs = indptr[row]; + IndptrT re = indptr[row + 1]; + for (IndptrT p = rs; p < re; ++p) { + int c = (int)indices[p]; + if (c < n_cols) atomicAdd(&col_counts[c], 1); + } +} + +/** + * Scatter CSR nonzeros into CSC layout for columns [col_start, col_stop). + * write_pos[c - col_start] must be initialized to the prefix-sum offset + * for column c. Each thread atomically claims a unique destination slot. + */ +template +__global__ void csr_scatter_to_csc_kernel( + const InT* __restrict__ data, const IndexT* __restrict__ indices, + const IndptrT* __restrict__ indptr, int* __restrict__ write_pos, + InT* __restrict__ csc_vals, int* __restrict__ csc_row_idx, int n_rows, + int col_start, int col_stop) { + int row = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= n_rows) return; + IndptrT rs = indptr[row]; + IndptrT re = indptr[row + 1]; + // Binary search for col_start (overflow-safe midpoint) + IndptrT lo = rs, hi = re; + while (lo < hi) { + IndptrT m = lo + ((hi - lo) >> 1); + if (indices[m] < col_start) + lo = m + 1; + else + hi = m; + } + for (IndptrT p = lo; p < re; ++p) { + int c = (int)indices[p]; + if (c >= col_stop) break; + int dest = atomicAdd(&write_pos[c - col_start], 1); + csc_vals[dest] = data[p]; + csc_row_idx[dest] = row; + } +} + +/** + * Decide whether to use shared or global memory for OVR rank accumulators. + * Returns the smem size to request and sets use_gmem accordingly. + */ +static int query_max_smem_per_block() { + static int cached = -1; + if (cached < 0) { + int device; + cudaGetDevice(&device); + cudaDeviceGetAttribute(&cached, cudaDevAttrMaxSharedMemoryPerBlock, + device); + } + return cached; +} + +static size_t ovr_smem_config(int n_groups, bool& use_gmem) { + size_t need = (size_t)(n_groups + 32) * sizeof(double); + if ((int)need <= query_max_smem_per_block()) { + use_gmem = false; + return need; + } + // Fall back to global memory accumulators; only need warp buf in smem + use_gmem = true; + return 32 * sizeof(double); +} + +/** + * Decide smem-vs-gmem for the sparse OVR rank kernel. Two accumulator + * arrays (grp_sums + grp_nz_count) of size n_groups each plus warp buf. + */ +static size_t sparse_ovr_smem_config(int n_groups, bool& use_gmem) { + size_t need = (size_t)(2 * n_groups + 32) * sizeof(double); + if ((int)need <= query_max_smem_per_block()) { + use_gmem = false; + return need; + } + use_gmem = true; + return 32 * sizeof(double); +} + +/** + * Fill sort values with row indices [0,1,...,n_rows-1] per column. + * Grid: (n_cols,), block: 256 threads. + */ +__global__ void fill_row_indices_kernel(int* __restrict__ vals, int n_rows, + int n_cols) { + int col = blockIdx.x; + if (col >= n_cols) return; + int* out = vals + (long long)col * n_rows; + for (int i = threadIdx.x; i < n_rows; i += blockDim.x) { + out[i] = i; + } +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh new file mode 100644 index 00000000..0f74a2c8 --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh @@ -0,0 +1,861 @@ +#pragma once + +/** + * Sparse-aware host-streaming CSC OVR pipeline. + * + * Like ovr_streaming_csc_host_impl but sorts only stored nonzeros per column + * instead of extracting dense blocks. GPU memory is O(max_batch_nnz) instead + * of O(sub_batch * n_rows), and sort work is proportional to nnz, not n_rows. + */ +template +static void ovr_sparse_csc_host_streaming_impl( + const InT* h_data, const int* h_indices, const IndptrT* h_indptr, + const int* h_group_codes, const double* h_group_sizes, double* d_rank_sums, + double* d_tie_corr, double* d_group_sums, double* d_group_sq_sums, + double* d_group_nnz, int n_rows, int n_cols, int n_groups, + bool compute_tie_corr, bool compute_sq_sums, bool compute_nnz, + int sub_batch_cols) { + if (n_rows == 0 || n_cols == 0) return; + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) + n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + + // Find max nnz across any sub-batch + size_t max_nnz = 0; + for (int col = 0; col < n_cols; col += sub_batch_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + size_t nnz = (size_t)(h_indptr[col + sb_cols] - h_indptr[col]); + if (nnz > max_nnz) max_nnz = nnz; + } + + // CUB temp size for max_nnz items + size_t cub_temp_bytes = 0; + if (max_nnz > 0) { + auto* fk = reinterpret_cast(1); + auto* iv = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortPairs( + nullptr, cub_temp_bytes, fk, fk, iv, iv, (int)max_nnz, + sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + RmmPool pool; + int* d_group_codes = pool.alloc(n_rows); + double* d_group_sizes = pool.alloc(n_groups); + struct StreamBuf { + InT* d_sparse_data_orig; + float* d_sparse_data_f32; + int* d_sparse_indices; + int* d_seg_offsets; + float* keys_out; + int* vals_out; + uint8_t* cub_temp; + double* d_rank_sums; + double* d_tie_corr; + double* d_group_sums; + double* d_group_sq_sums; + double* d_group_nnz; + double* d_nz_scratch; // gmem-only; non-null when rank_use_gmem + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].d_sparse_data_orig = pool.alloc(max_nnz); + bufs[s].d_sparse_data_f32 = pool.alloc(max_nnz); + bufs[s].d_sparse_indices = pool.alloc(max_nnz); + bufs[s].d_seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].keys_out = pool.alloc(max_nnz); + bufs[s].vals_out = pool.alloc(max_nnz); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].d_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].d_tie_corr = pool.alloc(sub_batch_cols); + bufs[s].d_group_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].d_group_sq_sums = + compute_sq_sums + ? pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + bufs[s].d_group_nnz = + compute_nnz ? pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + } + + // Transfer group codes + sizes once + cudaMemcpy(d_group_codes, h_group_codes, n_rows * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_group_sizes, h_group_sizes, n_groups * sizeof(double), + cudaMemcpyHostToDevice); + + // Pre-compute rebased per-batch offsets and upload once (avoids per-batch + // H2D copy from a transient host buffer). + int n_batches = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + std::vector h_all_offsets((size_t)n_batches * (sub_batch_cols + 1), 0); + for (int b = 0; b < n_batches; b++) { + int col_start = b * sub_batch_cols; + int sb = std::min(sub_batch_cols, n_cols - col_start); + IndptrT ptr_start = h_indptr[col_start]; + int* off = &h_all_offsets[(size_t)b * (sub_batch_cols + 1)]; + for (int i = 0; i <= sb; i++) + off[i] = (int)(h_indptr[col_start + i] - ptr_start); + } + int* d_all_offsets = + pool.alloc((size_t)n_batches * (sub_batch_cols + 1)); + cudaMemcpy(d_all_offsets, h_all_offsets.data(), + h_all_offsets.size() * sizeof(int), cudaMemcpyHostToDevice); + + int tpb = UTIL_BLOCK_SIZE; + bool rank_use_gmem = false; + size_t smem_bytes = sparse_ovr_smem_config(n_groups, rank_use_gmem); + bool cast_use_gmem = false; + size_t smem_cast = cast_accumulate_smem_config(n_groups, compute_sq_sums, + compute_nnz, cast_use_gmem); + + // In gmem mode the sparse rank kernel accumulates into rank_sums directly + // and needs a per-stream nz_count scratch buffer sized (n_groups, sb_cols). + for (int s = 0; s < n_streams; s++) { + if (rank_use_gmem) { + bufs[s].d_nz_scratch = + pool.alloc((size_t)n_groups * sub_batch_cols); + } else { + bufs[s].d_nz_scratch = nullptr; + } + } + + // Pin only the host input arrays; outputs live on the device. + size_t total_nnz = (size_t)h_indptr[n_cols]; + HostRegisterGuard _pin_data(const_cast(h_data), + total_nnz * sizeof(InT)); + HostRegisterGuard _pin_indices(const_cast(h_indices), + total_nnz * sizeof(int)); + + cudaDeviceSynchronize(); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int s = batch_idx % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + + IndptrT ptr_start = h_indptr[col]; + IndptrT ptr_end = h_indptr[col + sb_cols]; + int batch_nnz = (int)(ptr_end - ptr_start); + + // H2D: transfer sparse data for this column range (native dtype) + if (batch_nnz > 0) { + cudaMemcpyAsync(buf.d_sparse_data_orig, h_data + ptr_start, + (size_t)batch_nnz * sizeof(InT), + cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(buf.d_sparse_indices, h_indices + ptr_start, + (size_t)batch_nnz * sizeof(int), + cudaMemcpyHostToDevice, stream); + } + + // D2D: copy this batch's rebased offsets from the pre-uploaded buffer + int* src = d_all_offsets + (size_t)batch_idx * (sub_batch_cols + 1); + cudaMemcpyAsync(buf.d_seg_offsets, src, (sb_cols + 1) * sizeof(int), + cudaMemcpyDeviceToDevice, stream); + + // Cast to float32 for sort + accumulate stats in float64 + launch_ovr_cast_and_accumulate_sparse( + buf.d_sparse_data_orig, buf.d_sparse_data_f32, buf.d_sparse_indices, + buf.d_seg_offsets, d_group_codes, buf.d_group_sums, + buf.d_group_sq_sums, buf.d_group_nnz, sb_cols, n_groups, + compute_sq_sums, compute_nnz, tpb, smem_cast, cast_use_gmem, + stream); + + // CUB sort only stored nonzeros (float32 keys) + if (batch_nnz > 0) { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortPairs( + buf.cub_temp, temp, buf.d_sparse_data_f32, buf.keys_out, + buf.d_sparse_indices, buf.vals_out, batch_nnz, sb_cols, + buf.d_seg_offsets, buf.d_seg_offsets + 1, BEGIN_BIT, END_BIT, + stream); + } + + // Sparse rank kernel (stats already captured above) + if (rank_use_gmem) { + cudaMemsetAsync(buf.d_rank_sums, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + cudaMemsetAsync(buf.d_nz_scratch, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + } + rank_sums_sparse_ovr_kernel<<>>( + buf.keys_out, buf.vals_out, buf.d_seg_offsets, d_group_codes, + d_group_sizes, buf.d_rank_sums, buf.d_tie_corr, buf.d_nz_scratch, + n_rows, sb_cols, n_groups, compute_tie_corr, rank_use_gmem); + CUDA_CHECK_LAST_ERROR(rank_sums_sparse_ovr_kernel); + + // D2D: scatter sub-batch results into caller's GPU buffers + cudaMemcpy2DAsync(d_rank_sums + col, n_cols * sizeof(double), + buf.d_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpyAsync(d_tie_corr + col, buf.d_tie_corr, + sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, + stream); + } + cudaMemcpy2DAsync(d_group_sums + col, n_cols * sizeof(double), + buf.d_group_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_sq_sums) { + cudaMemcpy2DAsync(d_group_sq_sums + col, n_cols * sizeof(double), + buf.d_group_sq_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + if (compute_nnz) { + cudaMemcpy2DAsync(d_group_nnz + col, n_cols * sizeof(double), + buf.d_group_nnz, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + + col += sb_cols; + batch_idx++; + } + + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in sparse host CSC streaming: ") + + cudaGetErrorString(err)); + } + + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} + +// ============================================================================ +// Sparse-aware host-streaming CSR OVR pipeline. +// ============================================================================ + +/** + * Host CSR variant of the sparse OVR stream. + * + * The CSR input stays in host memory. We count columns once on the CPU, then + * use mapped pinned CSR arrays for bounded per-column-batch CSR->CSC scatter + * on the GPU. This avoids both a full host->device sparse upload and any + * whole-matrix CSR->CSC conversion. + */ +template +static void ovr_sparse_csr_host_streaming_impl( + const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, + const int* h_group_codes, const double* h_group_sizes, double* d_rank_sums, + double* d_tie_corr, double* d_group_sums, double* d_group_sq_sums, + double* d_group_nnz, int n_rows, int n_cols, int n_groups, + bool compute_tie_corr, bool compute_sq_sums, bool compute_nnz, + int sub_batch_cols) { + if (n_rows == 0 || n_cols == 0) return; + + RmmPool pool; + size_t total_nnz = (size_t)h_indptr[n_rows]; + + // ---- Phase 0: CPU planning in native CSR order ---- + std::vector h_col_counts(n_cols, 0); + for (int row = 0; row < n_rows; row++) { + IndptrT rs = h_indptr[row]; + IndptrT re = h_indptr[row + 1]; + for (IndptrT p = rs; p < re; ++p) { + int c = (int)h_indices[p]; + if (c >= 0 && c < n_cols) h_col_counts[c]++; + } + } + + int n_batches = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + size_t max_batch_nnz = 0; + std::vector h_all_offsets((size_t)n_batches * (sub_batch_cols + 1), 0); + std::vector h_batch_nnz(n_batches); + for (int b = 0; b < n_batches; b++) { + int col_start = b * sub_batch_cols; + int sb_cols = std::min(sub_batch_cols, n_cols - col_start); + int* off = &h_all_offsets[(size_t)b * (sub_batch_cols + 1)]; + for (int i = 0; i < sb_cols; i++) + off[i + 1] = off[i] + h_col_counts[col_start + i]; + h_batch_nnz[b] = (size_t)off[sb_cols]; + if (h_batch_nnz[b] > max_batch_nnz) max_batch_nnz = h_batch_nnz[b]; + } + + int* d_all_offsets = + pool.alloc((size_t)n_batches * (sub_batch_cols + 1)); + cudaMemcpy(d_all_offsets, h_all_offsets.data(), + h_all_offsets.size() * sizeof(int), cudaMemcpyHostToDevice); + + // ---- Phase 1: allocate per-stream bounded work buffers ---- + size_t cub_temp_bytes = 0; + if (max_batch_nnz > 0) { + auto* fk = reinterpret_cast(1); + auto* iv = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortPairs( + nullptr, cub_temp_bytes, fk, fk, iv, iv, (int)max_batch_nnz, + sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); + } + + int tpb = UTIL_BLOCK_SIZE; + bool rank_use_gmem = false; + size_t smem_bytes = sparse_ovr_smem_config(n_groups, rank_use_gmem); + bool cast_use_gmem = false; + size_t smem_cast = cast_accumulate_smem_config(n_groups, compute_sq_sums, + compute_nnz, cast_use_gmem); + + int n_streams = N_STREAMS; + if (n_batches < n_streams) n_streams = n_batches; + + size_t per_stream_bytes = + max_batch_nnz * (sizeof(InT) + sizeof(float) + 2 * sizeof(int)) + + (sub_batch_cols + 1 + sub_batch_cols) * sizeof(int) + cub_temp_bytes + + 2 * (size_t)n_groups * sub_batch_cols * sizeof(double) + + sub_batch_cols * sizeof(double); + if (compute_sq_sums) { + per_stream_bytes += (size_t)n_groups * sub_batch_cols * sizeof(double); + } + if (compute_nnz) { + per_stream_bytes += (size_t)n_groups * sub_batch_cols * sizeof(double); + } + if (rank_use_gmem) { + per_stream_bytes += (size_t)n_groups * sub_batch_cols * sizeof(double); + } + + size_t free_mem = 0, total_mem = 0; + cudaMemGetInfo(&free_mem, &total_mem); + constexpr double MEM_BUDGET_FRAC = 0.8; + size_t budget = (size_t)(free_mem * MEM_BUDGET_FRAC); + while (n_streams > 1 && (size_t)n_streams * per_stream_bytes > budget) + n_streams--; + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + // Pin the source CSR arrays as mapped memory. The scatter kernel reads + // only the requested column window from each row. + HostRegisterGuard pin_data; + HostRegisterGuard pin_indices; + InT* d_data_zc = nullptr; + IndexT* d_indices_zc = nullptr; + if (total_nnz > 0) { + pin_data = + HostRegisterGuard(const_cast(h_data), total_nnz * sizeof(InT), + cudaHostRegisterMapped); + pin_indices = HostRegisterGuard(const_cast(h_indices), + total_nnz * sizeof(IndexT), + cudaHostRegisterMapped); + cudaError_t e1 = cudaHostGetDevicePointer((void**)&d_data_zc, + const_cast(h_data), 0); + cudaError_t e2 = cudaHostGetDevicePointer( + (void**)&d_indices_zc, const_cast(h_indices), 0); + if (e1 != cudaSuccess || e2 != cudaSuccess) { + throw std::runtime_error( + std::string("cudaHostGetDevicePointer failed: ") + + cudaGetErrorString(e1 != cudaSuccess ? e1 : e2)); + } + } + + IndptrT* d_indptr_full = pool.alloc(n_rows + 1); + cudaMemcpy(d_indptr_full, h_indptr, (n_rows + 1) * sizeof(IndptrT), + cudaMemcpyHostToDevice); + + int* d_group_codes = pool.alloc(n_rows); + double* d_group_sizes = pool.alloc(n_groups); + cudaMemcpy(d_group_codes, h_group_codes, n_rows * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_group_sizes, h_group_sizes, n_groups * sizeof(double), + cudaMemcpyHostToDevice); + + int scatter_blocks = (n_rows + tpb - 1) / tpb; + + struct StreamBuf { + int* col_offsets; + int* write_pos; + InT* csc_vals_orig; + float* csc_vals_f32; + int* csc_row_idx; + float* keys_out; + int* vals_out; + uint8_t* cub_temp; + double* sub_rank_sums; + double* sub_tie_corr; + double* sub_group_sums; + double* sub_group_sq_sums; + double* sub_group_nnz; + double* d_nz_scratch; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].col_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].write_pos = pool.alloc(sub_batch_cols); + bufs[s].csc_vals_orig = pool.alloc(max_batch_nnz); + bufs[s].csc_vals_f32 = pool.alloc(max_batch_nnz); + bufs[s].csc_row_idx = pool.alloc(max_batch_nnz); + bufs[s].keys_out = pool.alloc(max_batch_nnz); + bufs[s].vals_out = pool.alloc(max_batch_nnz); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = pool.alloc(sub_batch_cols); + bufs[s].sub_group_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_group_sq_sums = + compute_sq_sums + ? pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + bufs[s].sub_group_nnz = + compute_nnz ? pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + bufs[s].d_nz_scratch = + rank_use_gmem + ? pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + } + + cudaDeviceSynchronize(); + + // ---- Phase 2: bounded CSR->CSC scatter + GPU rank batches ---- + int col = 0; + for (int b = 0; b < n_batches; b++) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int s = b % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + int batch_nnz = (int)h_batch_nnz[b]; + + int* src = d_all_offsets + (size_t)b * (sub_batch_cols + 1); + cudaMemcpyAsync(buf.col_offsets, src, (sb_cols + 1) * sizeof(int), + cudaMemcpyDeviceToDevice, stream); + cudaMemcpyAsync(buf.write_pos, src, sb_cols * sizeof(int), + cudaMemcpyDeviceToDevice, stream); + + if (batch_nnz > 0) { + csr_scatter_to_csc_kernel + <<>>( + d_data_zc, d_indices_zc, d_indptr_full, buf.write_pos, + buf.csc_vals_orig, buf.csc_row_idx, n_rows, col, + col + sb_cols); + CUDA_CHECK_LAST_ERROR(csr_scatter_to_csc_kernel); + } + + launch_ovr_cast_and_accumulate_sparse( + buf.csc_vals_orig, buf.csc_vals_f32, buf.csc_row_idx, + buf.col_offsets, d_group_codes, buf.sub_group_sums, + buf.sub_group_sq_sums, buf.sub_group_nnz, sb_cols, n_groups, + compute_sq_sums, compute_nnz, tpb, smem_cast, cast_use_gmem, + stream); + + if (batch_nnz > 0) { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortPairs( + buf.cub_temp, temp, buf.csc_vals_f32, buf.keys_out, + buf.csc_row_idx, buf.vals_out, batch_nnz, sb_cols, + buf.col_offsets, buf.col_offsets + 1, BEGIN_BIT, END_BIT, + stream); + } + + if (rank_use_gmem) { + cudaMemsetAsync(buf.sub_rank_sums, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + cudaMemsetAsync(buf.d_nz_scratch, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + } + rank_sums_sparse_ovr_kernel<<>>( + buf.keys_out, buf.vals_out, buf.col_offsets, d_group_codes, + d_group_sizes, buf.sub_rank_sums, buf.sub_tie_corr, + buf.d_nz_scratch, n_rows, sb_cols, n_groups, compute_tie_corr, + rank_use_gmem); + CUDA_CHECK_LAST_ERROR(rank_sums_sparse_ovr_kernel); + + cudaMemcpy2DAsync(d_rank_sums + col, n_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpyAsync(d_tie_corr + col, buf.sub_tie_corr, + sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, + stream); + } + cudaMemcpy2DAsync(d_group_sums + col, n_cols * sizeof(double), + buf.sub_group_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_sq_sums) { + cudaMemcpy2DAsync(d_group_sq_sums + col, n_cols * sizeof(double), + buf.sub_group_sq_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + if (compute_nnz) { + cudaMemcpy2DAsync(d_group_nnz + col, n_cols * sizeof(double), + buf.sub_group_nnz, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + + col += sb_cols; + } + + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in sparse host CSR streaming: ") + + cudaGetErrorString(err)); + } + + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} + +// ============================================================================ +// Sparse-aware CSC OVR streaming (sort only stored nonzeros) +// ============================================================================ + +static void ovr_sparse_csc_streaming_impl( + const float* csc_data, const int* csc_indices, const int* csc_indptr, + const int* group_codes, const double* group_sizes, double* rank_sums, + double* tie_corr, int n_rows, int n_cols, int n_groups, + bool compute_tie_corr, int sub_batch_cols) { + if (n_rows == 0 || n_cols == 0) return; + + // Read indptr to host for batch planning + std::vector h_indptr(n_cols + 1); + cudaMemcpy(h_indptr.data(), csc_indptr, (n_cols + 1) * sizeof(int), + cudaMemcpyDeviceToHost); + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) + n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + + // Find max nnz across any sub-batch for buffer sizing + size_t max_nnz = 0; + for (int col = 0; col < n_cols; col += sub_batch_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + size_t nnz = (size_t)(h_indptr[col + sb_cols] - h_indptr[col]); + if (nnz > max_nnz) max_nnz = nnz; + } + + // CUB temp size for max_nnz items + size_t cub_temp_bytes = 0; + if (max_nnz > 0) { + auto* fk = reinterpret_cast(1); + auto* iv = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortPairs( + nullptr, cub_temp_bytes, fk, fk, iv, iv, (int)max_nnz, + sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + int tpb = UTIL_BLOCK_SIZE; + bool rank_use_gmem = false; + size_t smem_bytes = sparse_ovr_smem_config(n_groups, rank_use_gmem); + + RmmPool pool; + struct StreamBuf { + float* keys_out; + int* vals_out; + int* seg_offsets; + uint8_t* cub_temp; + double* sub_rank_sums; + double* sub_tie_corr; + double* d_nz_scratch; // gmem-only + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].keys_out = pool.alloc(max_nnz); + bufs[s].vals_out = pool.alloc(max_nnz); + bufs[s].seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = pool.alloc(sub_batch_cols); + bufs[s].d_nz_scratch = + rank_use_gmem + ? pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + } + + cudaDeviceSynchronize(); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int s = batch_idx % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + + int ptr_start = h_indptr[col]; + int ptr_end = h_indptr[col + sb_cols]; + int batch_nnz = ptr_end - ptr_start; + + // Compute rebased segment offsets on GPU (avoids host pinned-buffer + // race) + { + int count = sb_cols + 1; + int blk = (count + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + rebase_indptr_kernel<<>>( + csc_indptr, buf.seg_offsets, col, count); + CUDA_CHECK_LAST_ERROR(rebase_indptr_kernel); + } + + // Sort only stored values (keys=data, vals=row_indices) + if (batch_nnz > 0) { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortPairs( + buf.cub_temp, temp, csc_data + ptr_start, buf.keys_out, + csc_indices + ptr_start, buf.vals_out, batch_nnz, sb_cols, + buf.seg_offsets, buf.seg_offsets + 1, BEGIN_BIT, END_BIT, + stream); + } + + // Sparse rank kernel (handles implicit zeros analytically) + if (rank_use_gmem) { + cudaMemsetAsync(buf.sub_rank_sums, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + cudaMemsetAsync(buf.d_nz_scratch, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + } + rank_sums_sparse_ovr_kernel<<>>( + buf.keys_out, buf.vals_out, buf.seg_offsets, group_codes, + group_sizes, buf.sub_rank_sums, buf.sub_tie_corr, buf.d_nz_scratch, + n_rows, sb_cols, n_groups, compute_tie_corr, rank_use_gmem); + CUDA_CHECK_LAST_ERROR(rank_sums_sparse_ovr_kernel); + + // Scatter results to global output + cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpyAsync(tie_corr + col, buf.sub_tie_corr, + sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, + stream); + } + + col += sb_cols; + batch_idx++; + } + + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in sparse ovr streaming: ") + + cudaGetErrorString(err)); + } + + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} + +// ============================================================================ +// Sparse-aware CSR OVR streaming (partial CSR→CSC transpose per sub-batch) +// ============================================================================ + +/** + * Sparse-aware OVR streaming pipeline for GPU CSR data. + * + * Phase 0: One histogram kernel counts nnz per column. D2H + host prefix sums + * give exact per-batch nnz and max_batch_nnz for buffer sizing. + * Phase 1: Allocate per-stream buffers sized to max_batch_nnz. + * Phase 2: For each sub-batch: scatter CSR→CSC (partial transpose via + * atomics) → CUB sort only nonzeros → sparse rank kernel. + * + * Compared to the dense CSR path, sort work drops by ~1/sparsity. + */ +static void ovr_sparse_csr_streaming_impl( + const float* csr_data, const int* csr_indices, const int* csr_indptr, + const int* group_codes, const double* group_sizes, double* rank_sums, + double* tie_corr, int n_rows, int n_cols, int n_groups, + bool compute_tie_corr, int sub_batch_cols) { + if (n_rows == 0 || n_cols == 0) return; + + // ---- Phase 0: Planning — count nnz per column via histogram ---- + RmmPool pool; + int* d_col_counts = pool.alloc(n_cols); + cudaMemset(d_col_counts, 0, n_cols * sizeof(int)); + { + int blocks = (n_rows + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + csr_col_histogram_kernel<<>>( + csr_indices, csr_indptr, d_col_counts, n_rows, n_cols); + CUDA_CHECK_LAST_ERROR(csr_col_histogram_kernel); + } + std::vector h_col_counts(n_cols); + cudaMemcpy(h_col_counts.data(), d_col_counts, n_cols * sizeof(int), + cudaMemcpyDeviceToHost); + + // Per-batch prefix sums on host + int n_batches = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + size_t max_batch_nnz = 0; + + // Flat array: n_batches × (sub_batch_cols + 1) offsets + std::vector h_all_offsets((size_t)n_batches * (sub_batch_cols + 1), 0); + std::vector h_batch_nnz(n_batches); + + for (int b = 0; b < n_batches; b++) { + int col_start = b * sub_batch_cols; + int sb_cols = std::min(sub_batch_cols, n_cols - col_start); + int* off = &h_all_offsets[(size_t)b * (sub_batch_cols + 1)]; + off[0] = 0; + for (int i = 0; i < sb_cols; i++) + off[i + 1] = off[i] + h_col_counts[col_start + i]; + h_batch_nnz[b] = (size_t)off[sb_cols]; + if (h_batch_nnz[b] > max_batch_nnz) max_batch_nnz = h_batch_nnz[b]; + } + + // Upload all batch offsets to GPU in one shot (~20 KB) + int* d_all_offsets = + pool.alloc((size_t)n_batches * (sub_batch_cols + 1)); + cudaMemcpy(d_all_offsets, h_all_offsets.data(), + h_all_offsets.size() * sizeof(int), cudaMemcpyHostToDevice); + + // ---- Phase 1: Allocate per-stream buffers ---- + size_t cub_temp_bytes = 0; + if (max_batch_nnz > 0) { + auto* fk = reinterpret_cast(1); + auto* iv = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortPairs( + nullptr, cub_temp_bytes, fk, fk, iv, iv, (int)max_batch_nnz, + sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); + } + + int n_streams = N_STREAMS; + if (n_batches < n_streams) n_streams = n_batches; + + // CSR path needs 4 sort arrays per stream (scatter intermediates + + // CUB output). Fit stream count to available GPU memory. + size_t per_stream_bytes = + max_batch_nnz * (2 * sizeof(float) + 2 * sizeof(int)) + + (sub_batch_cols + 1 + sub_batch_cols) * sizeof(int) + cub_temp_bytes + + (size_t)n_groups * sub_batch_cols * sizeof(double) + + sub_batch_cols * sizeof(double); + + size_t free_mem = 0, total_mem = 0; + cudaMemGetInfo(&free_mem, &total_mem); + constexpr double MEM_BUDGET_FRAC = 0.8; + size_t budget = (size_t)(free_mem * MEM_BUDGET_FRAC); + while (n_streams > 1 && (size_t)n_streams * per_stream_bytes > budget) + n_streams--; + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + int tpb = UTIL_BLOCK_SIZE; + bool rank_use_gmem = false; + size_t smem_bytes = sparse_ovr_smem_config(n_groups, rank_use_gmem); + int scatter_blocks = (n_rows + tpb - 1) / tpb; + + struct StreamBuf { + int* col_offsets; // [sub_batch_cols + 1] CSC-style offsets + int* write_pos; // [sub_batch_cols] atomic write counters + float* csc_vals; // [max_batch_nnz] transposed values + int* csc_row_idx; // [max_batch_nnz] transposed row indices + float* keys_out; // [max_batch_nnz] CUB sort output + int* vals_out; // [max_batch_nnz] CUB sort output + uint8_t* cub_temp; + double* sub_rank_sums; + double* sub_tie_corr; + double* d_nz_scratch; // gmem-only + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].col_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].write_pos = pool.alloc(sub_batch_cols); + bufs[s].csc_vals = pool.alloc(max_batch_nnz); + bufs[s].csc_row_idx = pool.alloc(max_batch_nnz); + bufs[s].keys_out = pool.alloc(max_batch_nnz); + bufs[s].vals_out = pool.alloc(max_batch_nnz); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = pool.alloc(sub_batch_cols); + bufs[s].d_nz_scratch = + rank_use_gmem + ? pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + } + + cudaDeviceSynchronize(); + + // ---- Phase 2: Stream loop ---- + int col = 0; + for (int b = 0; b < n_batches; b++) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int s = b % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + int batch_nnz = (int)h_batch_nnz[b]; + + // D2D copy pre-computed col_offsets for this batch + int* src = d_all_offsets + (size_t)b * (sub_batch_cols + 1); + cudaMemcpyAsync(buf.col_offsets, src, (sb_cols + 1) * sizeof(int), + cudaMemcpyDeviceToDevice, stream); + + // Initialize write_pos = col_offsets[0..sb_cols-1] (same D2D source) + cudaMemcpyAsync(buf.write_pos, src, sb_cols * sizeof(int), + cudaMemcpyDeviceToDevice, stream); + + if (batch_nnz > 0) { + // Scatter CSR → CSC layout for this sub-batch + csr_scatter_to_csc_kernel<<>>( + csr_data, csr_indices, csr_indptr, buf.write_pos, buf.csc_vals, + buf.csc_row_idx, n_rows, col, col + sb_cols); + CUDA_CHECK_LAST_ERROR(csr_scatter_to_csc_kernel); + + // CUB sort only the nonzeros + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortPairs( + buf.cub_temp, temp, buf.csc_vals, buf.keys_out, buf.csc_row_idx, + buf.vals_out, batch_nnz, sb_cols, buf.col_offsets, + buf.col_offsets + 1, BEGIN_BIT, END_BIT, stream); + } + + // Sparse rank kernel (handles implicit zeros analytically) + if (rank_use_gmem) { + cudaMemsetAsync(buf.sub_rank_sums, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + cudaMemsetAsync(buf.d_nz_scratch, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + } + rank_sums_sparse_ovr_kernel<<>>( + buf.keys_out, buf.vals_out, buf.col_offsets, group_codes, + group_sizes, buf.sub_rank_sums, buf.sub_tie_corr, buf.d_nz_scratch, + n_rows, sb_cols, n_groups, compute_tie_corr, rank_use_gmem); + CUDA_CHECK_LAST_ERROR(rank_sums_sparse_ovr_kernel); + + // Scatter results to global output + cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpyAsync(tie_corr + col, buf.sub_tie_corr, + sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, + stream); + } + + col += sb_cols; + } + + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in sparse CSR ovr streaming: ") + + cudaGetErrorString(err)); + } + + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} diff --git a/src/rapids_singlecell/_utils/_csr_to_csc.py b/src/rapids_singlecell/_utils/_csr_to_csc.py deleted file mode 100644 index 5f4700b0..00000000 --- a/src/rapids_singlecell/_utils/_csr_to_csc.py +++ /dev/null @@ -1,117 +0,0 @@ -"""Fast parallel CSR to CSC conversion using Numba.""" - -from __future__ import annotations - -import numpy as np -import scipy.sparse as sp -from numba import get_num_threads, njit, prange - - -@njit(parallel=True, boundscheck=False) -def _csr_to_csc_kernel(csr_data, csr_indices, csr_indptr, n_cols): - """ - Numba kernel for parallel CSR to CSC conversion. - - Uses a tiled approach with parallel histogram + scatter phases, - targeting ~256 MB per block buffer for L3 cache efficiency. - """ - num_threads = get_num_threads() - nnz = len(csr_data) - n_rows = len(csr_indptr) - 1 - - # Allocate output arrays (int64 for matrices > 2GB) - csc_data = np.empty(nnz, dtype=csr_data.dtype) - csc_indices = np.empty(nnz, dtype=np.int64) - csc_indptr = np.empty(n_cols + 1, dtype=np.int64) - csc_indptr[0] = 0 - - # Block size targeting 256 MB per block buffer (L3 cache target) - TARGET_MEM_BYTES = 256 * 1024 * 1024 - MIN_BLOCK_SIZE = 1000 - block_size = TARGET_MEM_BYTES // (num_threads * 8) - if block_size < MIN_BLOCK_SIZE: - block_size = MIN_BLOCK_SIZE - - # Workspace: threads x block_width - counts = np.zeros((num_threads, block_size), dtype=np.int64) - row_chunk_size = (n_rows + num_threads - 1) // num_threads - - current_global_offset = 0 - - # Tiled execution over column blocks - for col_start in range(0, n_cols, block_size): - col_end = min(col_start + block_size, n_cols) - current_block_width = col_end - col_start - - # 1. Zero counters for this block - counts[:, :current_block_width] = 0 - - # 2. Parallel histogram (count items per column) - for t in prange(num_threads): - r_start = t * row_chunk_size - r_end = min((t + 1) * row_chunk_size, n_rows) - for r in range(r_start, r_end): - for i in range(csr_indptr[r], csr_indptr[r + 1]): - c = csr_indices[i] - if c >= col_start and c < col_end: - counts[t, c - col_start] += 1 - - # 3. Compute write offsets (sequential, fast) - for c in range(current_block_width): - total_in_col = 0 - for t in range(num_threads): - count = counts[t, c] - counts[t, c] = current_global_offset + total_in_col - total_in_col += count - - current_global_offset += total_in_col - csc_indptr[col_start + c + 1] = current_global_offset - - # 4. Parallel scatter (write data) - for t in prange(num_threads): - r_start = t * row_chunk_size - r_end = min((t + 1) * row_chunk_size, n_rows) - for r in range(r_start, r_end): - for i in range(csr_indptr[r], csr_indptr[r + 1]): - c = csr_indices[i] - if c >= col_start and c < col_end: - local_c = c - col_start - dest = counts[t, local_c] - - csc_data[dest] = csr_data[i] - csc_indices[dest] = r - - counts[t, local_c] += 1 - - return csc_data, csc_indices, csc_indptr - - -def _fast_csr_to_csc(mat_csr: sp.csr_matrix) -> sp.csc_matrix: - """ - Convert a SciPy CSR matrix to CSC using parallel Numba kernel. - - Uses a tiled multi-threaded approach for better cache utilization - compared to scipy's default conversion. - - Parameters - ---------- - mat_csr - Input CSR matrix. - - Returns - ------- - CSC matrix with the same data. - """ - if not sp.issparse(mat_csr) or mat_csr.format != "csr": - raise TypeError("Input must be a SciPy CSR matrix") - - rows, cols = mat_csr.shape - - data, indices, indptr = _csr_to_csc_kernel( - mat_csr.data, - mat_csr.indices, - mat_csr.indptr, - cols, - ) - - return sp.csc_matrix((data, indices, indptr), shape=mat_csr.shape) diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py b/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py index 0b9753a3..75469fcb 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py @@ -1,6 +1,7 @@ from __future__ import annotations import sys +import warnings from functools import partial from typing import TYPE_CHECKING, Literal @@ -8,6 +9,7 @@ import pandas as pd from ._core import _RankGenes +from ._utils import NoTestGroupsError if TYPE_CHECKING: from collections.abc import Iterable @@ -42,6 +44,7 @@ def rank_genes_groups( pre_load: bool = False, n_bins: int | None = None, bin_range: Literal["log1p", "auto"] | None = None, + skip_empty_groups: bool = False, **kwds, ) -> None: """ @@ -104,10 +107,10 @@ def rank_genes_groups( layer Key from `adata.layers` whose value will be used to perform tests on. chunk_size - Number of genes to process at once for `'wilcoxon'` and - `'wilcoxon_binned'`. Default is 128 for `'wilcoxon'`. For - `'wilcoxon_binned'` the default is sized dynamically based on - ``n_groups`` and ``n_bins`` to keep histogram memory stable. + Number of genes to process at once for memory-bounded Wilcoxon + paths. For `'wilcoxon_binned'`, the default is sized dynamically + based on ``n_groups`` and ``n_bins`` to keep histogram memory + stable. pre_load Pre-load the data into GPU memory. Used only for `'wilcoxon'`. n_bins @@ -121,6 +124,14 @@ def rank_genes_groups( ``'log1p'`` uses a fixed [0, 15] range suitable for most log1p-normalized data. ``'auto'`` computes the actual data range. Use this for z-scored or unnormalized data. + skip_empty_groups + If ``True``, silently drop groups with fewer than 2 cells (issuing + a ``RuntimeWarning``) instead of raising a ``ValueError``. Useful + when iterating over data subsets (e.g. per cell type) where some + categorical levels may be empty. If no test groups remain after + filtering (only the reference has >=2 cells), the call returns + without updating ``adata.uns``. The reference group is never + dropped — if it has <2 cells the call still fails. **kwds Additional arguments passed to the method. For `'logreg'`, these are passed to :class:`cuml.linear_model.LogisticRegression`. @@ -187,17 +198,29 @@ def rank_genes_groups( msg = f"mask_var has wrong shape: {mask_var_array.shape[0]} != {adata.n_vars}" raise ValueError(msg) - test_obj = _RankGenes( - adata, - groups, - groupby, - mask_var=mask_var_array, - reference=reference, - use_raw=use_raw, - layer=layer, - comp_pts=pts, - pre_load=pre_load, - ) + try: + test_obj = _RankGenes( + adata, + groups, + groupby, + mask_var=mask_var_array, + reference=reference, + use_raw=use_raw, + layer=layer, + comp_pts=pts, + pre_load=pre_load, + skip_empty_groups=skip_empty_groups, + ) + except NoTestGroupsError as e: + # skip_empty_groups=True contract: no test groups left → no-op. + # Do not write to adata.uns so downstream loops can detect the + # missing key and skip this subset. + warnings.warn( + f"rank_genes_groups: skipping — {e}", + RuntimeWarning, + stacklevel=2, + ) + return # Determine n_genes_user n_genes_user = n_genes @@ -217,11 +240,21 @@ def rank_genes_groups( **kwds, ) - # Build output - test_obj.stats.columns = test_obj.stats.columns.swaplevel() - + # Use a U-width tight to the actual gene names rather than the scanpy + # default of U50. For a 1948-group × 18k-gene workload this cuts the + # names structured array from ~7 GB → ~3 GB. Must match the width that + # compute_statistics used when converting var_names (see _core.py), so + # the final stack → structured-array view is a pure memcpy. + _vn = np.asarray(test_obj.var_names) + if _vn.dtype.kind == "U": + max_name_len = _vn.dtype.itemsize // 4 + elif len(_vn): + max_name_len = max(len(str(n)) for n in _vn) + else: + max_name_len = 50 + names_dtype = f"U{max(max_name_len, 1)}" dtypes = { - "names": "U50", + "names": names_dtype, "scores": "float32", "logfoldchanges": "float32", "pvals": "float64", @@ -252,13 +285,47 @@ def rank_genes_groups( if method == "wilcoxon": adata.uns[key_added]["params"]["tie_correct"] = tie_correct - for col in test_obj.stats.columns.levels[0]: - if col in dtypes: - adata.uns[key_added][col] = test_obj.stats[col].to_records( - index=False, column_dtypes=dtypes[col] + # Assemble scanpy-compatible structured arrays directly from per-group + # arrays, without going through a wide pandas DataFrame + to_records — + # that pipeline was ~4 s of pure Python overhead on workloads with + # thousands of groups. + group_names = test_obj.group_names + if group_names: + for stat, per_group_arrays in test_obj.results.items(): + if per_group_arrays is None or stat not in dtypes: + continue + adata.uns[key_added][stat] = _build_structured( + group_names, per_group_arrays, dtypes[stat] ) +def _build_structured( + group_names: list[str], + per_group_arrays: list[np.ndarray], + field_dtype: str, +) -> np.ndarray: + """Build a scanpy-style structured array with one field per group. + + Equivalent to assigning ``sa[name] = arr`` in a loop, but ~20× faster + on wide workloads (1000+ groups): fills a contiguous 2-D (n, n_groups) + buffer row-by-row and reinterprets it as a structured view, so the + bulk write pattern matches the structured-array memory layout. + """ + n = per_group_arrays[0].shape[0] + # Stack directly into the target dtype — np.stack with `dtype=` avoids + # the separate astype pass. axis=1 produces (n, n_groups) C-contig + # exactly matching the structured memory layout below, so .view() is + # zero-copy. Use 'unsafe' casting for string targets (object→U50). + if np.dtype(field_dtype).kind == "U": + stacked = np.stack( + per_group_arrays, axis=1, casting="unsafe", dtype=field_dtype + ) + else: + stacked = np.stack(per_group_arrays, axis=1, dtype=field_dtype) + dtype = np.dtype([(name, field_dtype) for name in group_names]) + return stacked.view(dtype).reshape(n) + + if TYPE_CHECKING: from warnings import deprecated else: diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_core.py b/src/rapids_singlecell/tools/_rank_genes_groups/_core.py index c65bbf7c..ce663b80 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_core.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_core.py @@ -1,18 +1,28 @@ from __future__ import annotations +from concurrent.futures import ThreadPoolExecutor +from os import cpu_count from typing import TYPE_CHECKING, Literal, assert_never import cupy as cp import numpy as np import pandas as pd -from statsmodels.stats.multitest import multipletests from rapids_singlecell._compat import DaskArray from rapids_singlecell.get import X_to_GPU from rapids_singlecell.get._aggregated import Aggregate from rapids_singlecell.preprocessing._utils import _check_gpu_X -from ._utils import EPS, _select_groups, _select_top_n +from ._utils import ( + EPS, + _benjamini_hochberg, + _select_groups, + _select_top_n, +) + +POSTPROCESS_PARALLEL_GROUPS = 256 +POSTPROCESS_PARALLEL_GENES = 1024 +POSTPROCESS_MAX_WORKERS = 8 if TYPE_CHECKING: from collections.abc import Iterable @@ -38,6 +48,7 @@ def __init__( layer: str | None = None, comp_pts: bool = False, pre_load: bool = False, + skip_empty_groups: bool = False, ) -> None: # Handle groups parameter if groups == "all" or groups is None: @@ -63,7 +74,10 @@ def __init__( raise ValueError(msg) self.groups_order, self.group_codes, self.group_sizes = _select_groups( - self.labels, selected + self.labels, + selected, + skip_empty_groups=skip_empty_groups, + reference=reference, ) # Get data matrix @@ -114,9 +128,19 @@ def __init__( self.vars_rest: np.ndarray | None = None self.pts_rest: np.ndarray | None = None - self.stats: pd.DataFrame | None = None + # Per-stat × per-group arrays. results[stat] is either None (stat not + # computed) or a list of 1-D numpy arrays in group_names order. This + # replaces the old DataFrame-based pipeline, which burned ~4 s per call + # on wide workloads (1000+ groups) in pandas DataFrame + to_records. + self.results: dict[str, list[np.ndarray] | None] = { + "names": None, + "scores": None, + "pvals": None, + "pvals_adj": None, + "logfoldchanges": None, + } + self.group_names: list[str] = [] self._compute_stats_in_chunks: bool = False - self._ref_chunk_computed: set[int] = set() def _init_stats_arrays(self, n_genes: int) -> None: """Pre-allocate stats arrays before chunk loop.""" @@ -190,16 +214,18 @@ def _basic_stats(self) -> None: # Compute rest statistics if reference='rest' if self.ireference is None: - n_rest = n.sum() - n - means_rest = (sums.sum(axis=0) - sums) / n_rest - rest_ss = (sq_sums.sum(axis=0) - sq_sums) - n_rest * means_rest**2 + n_rest = cp.float64(self.X.shape[0]) - n + total_sums = result["sum"].sum(axis=0, keepdims=True) + total_sq_sums = result["sq_sum"].sum(axis=0, keepdims=True) + means_rest = (total_sums - sums) / n_rest + rest_ss = (total_sq_sums - sq_sums) - n_rest * means_rest**2 vars_rest = cp.maximum(rest_ss / cp.maximum(n_rest - 1, 1), 0) self.means_rest = cp.asnumpy(means_rest) self.vars_rest = cp.asnumpy(vars_rest) if self.comp_pts: - total_count = (pts * n).sum(axis=0) + total_count = result["count_nonzero"].sum(axis=0, keepdims=True) self.pts_rest = cp.asnumpy((total_count - pts * n) / n_rest) else: self.pts_rest = None @@ -213,104 +239,6 @@ def _basic_stats(self) -> None: self.vars = cp.asnumpy(vars_) self.pts = cp.asnumpy(pts) if pts is not None else None - def _accumulate_chunk_stats_vs_rest( - self, - block: cp.ndarray, - start: int, - stop: int, - *, - group_matrix: cp.ndarray, - group_sizes_dev: cp.ndarray, - n_cells: int, - ) -> None: - """Compute and store stats for one gene chunk (vs rest mode).""" - if not self._compute_stats_in_chunks: - return # Stats already computed via Aggregate - - rest_sizes = n_cells - group_sizes_dev - - # Group sums and sum of squares - group_sums = group_matrix.T @ block - group_sum_sq = group_matrix.T @ (block**2) - - # Means - chunk_means = group_sums / group_sizes_dev[:, None] - self.means[:, start:stop] = cp.asnumpy(chunk_means) - - # Variances (with Bessel correction) - chunk_vars = group_sum_sq / group_sizes_dev[:, None] - chunk_means**2 - chunk_vars *= group_sizes_dev[:, None] / (group_sizes_dev[:, None] - 1) - self.vars[:, start:stop] = cp.asnumpy(chunk_vars) - - # Pts (fraction expressing) - if self.comp_pts: - group_nnz = group_matrix.T @ (block != 0).astype(cp.float64) - self.pts[:, start:stop] = cp.asnumpy(group_nnz / group_sizes_dev[:, None]) - - # Rest statistics - if self.ireference is None: - total_sum = block.sum(axis=0) - total_sum_sq = (block**2).sum(axis=0) - - rest_sums = total_sum[None, :] - group_sums - rest_means = rest_sums / rest_sizes[:, None] - self.means_rest[:, start:stop] = cp.asnumpy(rest_means) - - rest_sum_sq = total_sum_sq[None, :] - group_sum_sq - rest_vars = rest_sum_sq / rest_sizes[:, None] - rest_means**2 - rest_vars *= rest_sizes[:, None] / (rest_sizes[:, None] - 1) - self.vars_rest[:, start:stop] = cp.asnumpy(rest_vars) - - if self.comp_pts: - total_nnz = (block != 0).sum(axis=0) - rest_nnz = total_nnz[None, :] - group_nnz - self.pts_rest[:, start:stop] = cp.asnumpy( - rest_nnz / rest_sizes[:, None] - ) - - def _accumulate_chunk_stats_with_ref( - self, - block: cp.ndarray, - start: int, - stop: int, - *, - group_index: int, - group_mask_gpu: cp.ndarray, - n_group: int, - n_ref: int, - ) -> None: - """Compute and store stats for one gene chunk (with reference mode).""" - if not self._compute_stats_in_chunks: - return # Stats already computed via Aggregate - - # Group stats - group_data = block[group_mask_gpu] - group_mean = group_data.mean(axis=0) - self.means[group_index, start:stop] = cp.asnumpy(group_mean) - - if n_group > 1: - group_var = group_data.var(axis=0, ddof=1) - self.vars[group_index, start:stop] = cp.asnumpy(group_var) - - if self.comp_pts: - group_nnz = (group_data != 0).sum(axis=0) - self.pts[group_index, start:stop] = cp.asnumpy(group_nnz / n_group) - - # Reference stats (only compute once, on first non-reference group) - if start not in self._ref_chunk_computed: - self._ref_chunk_computed.add(start) - ref_data = block[~group_mask_gpu] - ref_mean = ref_data.mean(axis=0) - self.means[self.ireference, start:stop] = cp.asnumpy(ref_mean) - - if n_ref > 1: - ref_var = ref_data.var(axis=0, ddof=1) - self.vars[self.ireference, start:stop] = cp.asnumpy(ref_var) - - if self.comp_pts: - ref_nnz = (ref_data != 0).sum(axis=0) - self.pts[self.ireference, start:stop] = cp.asnumpy(ref_nnz / n_ref) - def t_test( self, method: Literal["t-test", "t-test_overestim_var"] ) -> list[tuple[int, NDArray, NDArray]]: @@ -411,56 +339,136 @@ def compute_statistics( n_genes = self.X.shape[1] - # Collect all stats data first to avoid DataFrame fragmentation - stats_data: dict[tuple[str, str], np.ndarray] = {} - - for group_index, scores, pvals in test_results: - group_name = str(self.groups_order[group_index]) + if not test_results: + self.group_names = [] + return + group_indices = np.array([gi for gi, _, _ in test_results], dtype=np.int64) + self.group_names = [str(self.groups_order[gi]) for gi in group_indices] + has_lfc = self.means is not None + + # Vectorised log-fold-change across all test groups — avoids 1948 + # individual numpy ops in the hot path (was ~1.5 s of Python + # overhead on wide workloads). Cast to the output dtype here so + # _build_structured never needs a full-array astype pass. + lfc_all: np.ndarray | None = None + if has_lfc: + mean_groups = self.means[group_indices] + if self.ireference is None: + mean_rests = self.means_rest[group_indices] + else: + mean_rests = self.means[self.ireference][None, :] + foldchanges = (self.expm1_func(mean_groups) + EPS) / ( + self.expm1_func(mean_rests) + EPS + ) + lfc_all = np.log2(foldchanges).astype(np.float32, copy=False) + + names_list: list[np.ndarray] = [] + scores_list: list[np.ndarray] = [] + pvals_list: list[np.ndarray] = [] + pvals_adj_list: list[np.ndarray] = [] + lfc_list: list[np.ndarray] = [] + has_pvals = False + # Pre-convert var_names to fixed-width unicode ONCE. Without this, + # per-group indexing returns object arrays, and np.stack(..., dtype='U') + # ends up doing ~35 M object→string conversions inside the hot loop — + # that was ~1.5 s on the 1948-group / 18k-gene workload. Using the + # target width directly turns the final stack into a pure memcpy. + _vn = np.asarray(self.var_names) + if _vn.dtype.kind == "U": + var_names_arr = _vn + else: + max_len = max((len(str(n)) for n in _vn), default=1) + var_names_arr = _vn.astype(f"U{max_len}") + + def _process_result( + ti: int, + ) -> tuple[ + int, + np.ndarray | None, + np.ndarray, + np.ndarray | None, + np.ndarray | None, + np.ndarray | None, + ]: + _, scores, pvals = test_results[ti] if n_genes_user is not None: scores_sort = np.abs(scores) if rankby_abs else scores global_indices = _select_top_n(scores_sort, n_genes_user) + names = var_names_arr[global_indices] else: global_indices = slice(None) + names = None - if n_genes_user is not None: - stats_data[group_name, "names"] = np.asarray(self.var_names)[ - global_indices - ] - - stats_data[group_name, "scores"] = scores[global_indices] + scores_out = scores[global_indices] + pvals_out = None + pvals_adj_out = None if pvals is not None: - stats_data[group_name, "pvals"] = pvals[global_indices] + pvals_out = pvals[global_indices] if corr_method == "benjamini-hochberg": - pvals_clean = np.array(pvals, copy=True) - pvals_clean[np.isnan(pvals_clean)] = 1.0 - _, pvals_adj, _, _ = multipletests( - pvals_clean, alpha=0.05, method="fdr_bh" - ) + pvals_adj = _benjamini_hochberg(pvals) elif corr_method == "bonferroni": pvals_adj = np.minimum(pvals * n_genes, 1.0) - stats_data[group_name, "pvals_adj"] = pvals_adj[global_indices] - - # Compute logfoldchanges - if self.means is not None: - mean_group = self.means[group_index] - if self.ireference is None: - mean_rest = self.means_rest[group_index] - else: - mean_rest = self.means[self.ireference] - foldchanges = (self.expm1_func(mean_group) + EPS) / ( - self.expm1_func(mean_rest) + EPS - ) - stats_data[group_name, "logfoldchanges"] = np.log2( - foldchanges[global_indices] - ) - - # Create DataFrame all at once to avoid fragmentation - if stats_data: - self.stats = pd.DataFrame(stats_data) - self.stats.columns = pd.MultiIndex.from_tuples(self.stats.columns) - if n_genes_user is None: - self.stats.index = self.var_names + pvals_adj_out = pvals_adj[global_indices] + + lfc_out = None + if lfc_all is not None: + lfc_out = lfc_all[ti][global_indices] + + return ti, names, scores_out, pvals_out, pvals_adj_out, lfc_out + + def _process_range( + start: int, stop: int + ) -> list[ + tuple[ + int, + np.ndarray | None, + np.ndarray, + np.ndarray | None, + np.ndarray | None, + np.ndarray | None, + ] + ]: + return [_process_result(ti) for ti in range(start, stop)] + + n_results = len(test_results) + use_parallel_post = ( + n_results >= POSTPROCESS_PARALLEL_GROUPS + and n_genes >= POSTPROCESS_PARALLEL_GENES + ) + if use_parallel_post: + workers = min(POSTPROCESS_MAX_WORKERS, cpu_count() or 1, n_results) + chunk = (n_results + workers - 1) // workers + ranges = [ + (start, min(start + chunk, n_results)) + for start in range(0, n_results, chunk) + ] + with ThreadPoolExecutor(max_workers=workers) as executor: + processed_chunks = executor.map(lambda r: _process_range(*r), ranges) + processed = [ + item for chunk_out in processed_chunks for item in chunk_out + ] else: - self.stats = None + processed = _process_range(0, n_results) + + for _, names, scores_out, pvals_out, pvals_adj_out, lfc_out in processed: + if names is not None: + names_list.append(names) + scores_list.append(scores_out) + if pvals_out is not None and pvals_adj_out is not None: + has_pvals = True + pvals_list.append(pvals_out) + pvals_adj_list.append(pvals_adj_out) + if lfc_out is not None: + lfc_list.append(lfc_out) + + if self.group_names: + self.results["scores"] = scores_list + if n_genes_user is not None: + self.results["names"] = names_list + if has_pvals: + self.results["pvals"] = pvals_list + self.results["pvals_adj"] = pvals_adj_list + if lfc_all is not None: + self.results["logfoldchanges"] = lfc_list diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py b/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py index c4f2c601..c46ec26d 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py @@ -1,26 +1,30 @@ from __future__ import annotations +import warnings from typing import TYPE_CHECKING -import cupy as cp -import cupyx.scipy.sparse as cpsp import numpy as np -import scipy.sparse as sp - -from rapids_singlecell.preprocessing._utils import _sparse_to_dense if TYPE_CHECKING: import pandas as pd from numpy.typing import NDArray EPS = 1e-9 -WARP_SIZE = 32 -MAX_THREADS_PER_BLOCK = 512 + + +class NoTestGroupsError(ValueError): + """Raised when skip_empty_groups=True and no test groups remain after + filtering. The public ``rank_genes_groups`` catches this and returns + quietly (after emitting a ``RuntimeWarning``) so callers can iterate + over data subsets without wrapping each call in try/except.""" def _select_groups( labels: pd.Series, selected: list | None, + *, + skip_empty_groups: bool = False, + reference: str | None = None, ) -> tuple[NDArray, NDArray[np.int32], NDArray[np.int64]]: """Build integer group codes from a categorical Series. @@ -31,6 +35,17 @@ def _select_groups( selected Group names to keep, or ``None`` for all groups. Must already include the reference group if applicable. + skip_empty_groups + If ``True``, drop groups with fewer than 2 cells instead of raising, + emitting a ``RuntimeWarning`` that lists the dropped groups. Useful + when iterating over data subsets (e.g. per cell type) where some + categorical levels have no cells. The reference group is never + silently dropped — if it has <2 cells, a ``ValueError`` is raised. + reference + Name of the reference group (``None`` or ``"rest"`` if there is + no fixed reference). Only used when ``skip_empty_groups`` is + ``True`` to validate that the reference is not the one being + dropped. Returns ------- @@ -51,27 +66,65 @@ def _select_groups( cat_order = {str(c): i for i, c in enumerate(all_categories)} selected.sort(key=lambda x: cat_order.get(str(x), len(all_categories))) + # First pass: compute sizes for the currently-selected groups so we can + # optionally drop empty/singleton ones before assigning final codes. + orig_codes_all = labels.cat.codes.to_numpy() + + def _compute_codes_and_sizes( + sel: list, + ) -> tuple[NDArray[np.int32], NDArray[np.int64]]: + n = len(sel) + str_to_sel = {str(name): idx for idx, name in enumerate(sel)} + lookup = np.full(len(all_categories) + 1, n, dtype=np.int32) + for cat_idx, cat_name in enumerate(all_categories): + sel_idx = str_to_sel.get(str(cat_name)) + if sel_idx is not None: + lookup[cat_idx] = sel_idx + codes = lookup[orig_codes_all] + sizes = np.bincount(codes, minlength=n + 1)[:n].astype(np.int64) + return codes, sizes + + _, preview_sizes = _compute_codes_and_sizes(selected) + + if skip_empty_groups: + ref_str = str(reference) if reference not in (None, "rest") else None + empty_idx = [i for i, s in enumerate(preview_sizes) if s < 2] + if empty_idx: + empty_names = [str(selected[i]) for i in empty_idx] + if ref_str is not None and ref_str in empty_names: + msg = ( + f"Reference group {ref_str!r} has <2 cells; cannot run " + "with skip_empty_groups=True." + ) + raise ValueError(msg) + warnings.warn( + f"Dropping {len(empty_names)} group(s) with <2 cells: " + f"{', '.join(empty_names)}", + RuntimeWarning, + stacklevel=3, + ) + selected = [g for i, g in enumerate(selected) if i not in set(empty_idx)] + + # Need at least one test group once the reference is excluded. + ref_str = str(reference) if reference not in (None, "rest") else None + n_test = sum(1 for g in selected if str(g) != ref_str) + if n_test == 0: + msg = ( + "No test groups with >=2 cells remain after filtering " + "(only the reference has enough cells)." + ) + raise NoTestGroupsError(msg) + n_groups = len(selected) groups_order = np.array(selected) - # Map original category index → selected group index - str_to_sel = {str(name): idx for idx, name in enumerate(selected)} - orig_to_sel: dict[int, int] = {} - for cat_idx, cat_name in enumerate(all_categories): - sel_idx = str_to_sel.get(str(cat_name)) - if sel_idx is not None: - orig_to_sel[cat_idx] = sel_idx - - orig_codes = labels.cat.codes.to_numpy() - group_codes = np.full(len(orig_codes), n_groups, dtype=np.int32) - for orig_idx, sel_idx in orig_to_sel.items(): - group_codes[orig_codes == orig_idx] = sel_idx + if n_groups == 0: + msg = "No groups with >=2 cells remain after filtering." + raise ValueError(msg) - group_sizes = np.bincount(group_codes, minlength=n_groups + 1)[:n_groups].astype( - np.int64 - ) + group_codes, group_sizes = _compute_codes_and_sizes(selected) - # Validate singlet groups + # Validate singlet groups (only triggers when skip_empty_groups=False). invalid_groups = {str(selected[i]) for i in range(n_groups) if group_sizes[i] < 2} if invalid_groups: msg = ( @@ -83,63 +136,31 @@ def _select_groups( return groups_order, group_codes, group_sizes -def _round_up_to_warp(n: int) -> int: - """Round up to nearest multiple of WARP_SIZE, capped at MAX_THREADS_PER_BLOCK.""" - return min(MAX_THREADS_PER_BLOCK, ((n + WARP_SIZE - 1) // WARP_SIZE) * WARP_SIZE) - - def _select_top_n(scores: NDArray, n_top: int) -> NDArray: """Select indices of top n scores. Uses argpartition + argsort for O(n + k log k) complexity where k = n_top. This is faster than full sorting when k << n. """ - n_from = scores.shape[0] - reference_indices = np.arange(n_from, dtype=int) + if n_top >= scores.shape[0]: + return np.argsort(scores)[::-1] partition = np.argpartition(scores, -n_top)[-n_top:] - partial_indices = np.argsort(scores[partition])[::-1] - global_indices = reference_indices[partition][partial_indices] - return global_indices + return partition[np.argsort(scores[partition])[::-1]] -def _choose_chunk_size(requested: int | None) -> int: - """Choose chunk size for gene processing.""" - if requested is not None: - return int(requested) - return 128 +def _benjamini_hochberg(pvals: NDArray) -> NDArray: + """Adjust p-values with the Benjamini-Hochberg FDR procedure.""" + pvals_clean = np.array(pvals, copy=True) + pvals_clean[np.isnan(pvals_clean)] = 1.0 + n_tests = pvals_clean.size + order = np.argsort(pvals_clean) + ordered = pvals_clean[order] + ranks = np.arange(1, n_tests + 1, dtype=ordered.dtype) / n_tests + adjusted = ordered / ranks + np.minimum.accumulate(adjusted[::-1], out=adjusted[::-1]) + np.minimum(adjusted, 1.0, out=adjusted) -def _csc_columns_to_gpu(X_csc, start: int, stop: int, n_rows: int) -> cp.ndarray: - """ - Extract columns from a CSC matrix via direct indptr pointer slicing. - - Works for both scipy and CuPy CSC matrices. Much faster than - ``X[:, start:stop]`` which rebuilds index arrays internally. - """ - s_ptr = int(X_csc.indptr[start]) - e_ptr = int(X_csc.indptr[stop]) - chunk_data = cp.asarray(X_csc.data[s_ptr:e_ptr]) - chunk_indices = cp.asarray(X_csc.indices[s_ptr:e_ptr]) - chunk_indptr = cp.asarray(X_csc.indptr[start : stop + 1] - s_ptr) - csc_chunk = cpsp.csc_matrix( - (chunk_data, chunk_indices, chunk_indptr), shape=(n_rows, stop - start) - ) - return _sparse_to_dense(csc_chunk, order="F").astype(cp.float64) - - -def _get_column_block(X, start: int, stop: int) -> cp.ndarray: - """Extract a column block as a dense F-order float64 CuPy array.""" - match X: - case sp.csc_matrix() | sp.csc_array(): - return _csc_columns_to_gpu(X, start, stop, X.shape[0]) - case sp.spmatrix() | sp.sparray(): - chunk = cpsp.csc_matrix(X[:, start:stop].tocsc()) - return _sparse_to_dense(chunk, order="F").astype(cp.float64) - case cpsp.csc_matrix(): - return _csc_columns_to_gpu(X, start, stop, X.shape[0]) - case cpsp.spmatrix(): - return _sparse_to_dense(X[:, start:stop], order="F").astype(cp.float64) - case np.ndarray() | cp.ndarray(): - return cp.asarray(X[:, start:stop], dtype=cp.float64, order="F") - case _: - raise ValueError(f"Unsupported matrix type: {type(X)}") + out = np.empty_like(adjusted) + out[order] = adjusted + return out diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py index c14c760d..5fcee8d2 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py @@ -4,14 +4,12 @@ from typing import TYPE_CHECKING import cupy as cp +import cupyx.scipy.sparse as cpsp import cupyx.scipy.special as cupyx_special import numpy as np import scipy.sparse as sp -from rapids_singlecell._cuda import _wilcoxon_cuda as _wc -from rapids_singlecell._utils._csr_to_csc import _fast_csr_to_csc - -from ._utils import _choose_chunk_size, _get_column_block +from rapids_singlecell._cuda import _wilcoxon_ovo_cuda as _wc if TYPE_CHECKING: from numpy.typing import NDArray @@ -19,71 +17,424 @@ from ._core import _RankGenes MIN_GROUP_SIZE_WARNING = 25 +STREAMING_SUB_BATCH = 64 + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _resolve_chunk_size(chunk_size: int | None, n_total_genes: int) -> int: + if chunk_size is None: + return max(1, n_total_genes) + chunk_width = int(chunk_size) + if chunk_width <= 0: + msg = "`chunk_size` must be a positive integer." + raise ValueError(msg) + return min(chunk_width, max(1, n_total_genes)) -def _average_ranks( - matrix: cp.ndarray, *, return_sorted: bool = False -) -> cp.ndarray | tuple[cp.ndarray, cp.ndarray]: +def _fill_basic_stats_from_accumulators( + rg: _RankGenes, + group_sums: cp.ndarray, + group_sq_sums: cp.ndarray, + group_nnz: cp.ndarray, + group_sizes: np.ndarray, + *, + n_cells: int, + compute_vars: bool = True, + total_sums: cp.ndarray | None = None, + total_sq_sums: cp.ndarray | None = None, + total_nnz: cp.ndarray | None = None, +) -> None: + """Populate rg.means/vars/pts (+ *_rest) from streamed accumulators. + + Mirrors the Aggregate-based path in :meth:`_RankGenes._basic_stats` + but consumes per-group sums/sum-of-squares/nnz that the host-streaming + kernels write directly into caller-provided CuPy buffers. The math + runs on GPU and only the derived (means/vars/pts) arrays are + transferred to host, so the full matrix never round-trips. """ - Compute average ranks for each column using GPU kernel. + n = cp.asarray(group_sizes, dtype=cp.float64)[:, None] + means = group_sums / n - Uses scipy.stats.rankdata 'average' method: ties get the average - of the ranks they would span. + rg.means = cp.asnumpy(means) + if compute_vars: + group_ss = group_sq_sums - n * means**2 + vars_ = cp.maximum(group_ss / cp.maximum(n - 1, 1), 0) + rg.vars = cp.asnumpy(vars_) + else: + rg.vars = np.zeros_like(rg.means) + rg.pts = cp.asnumpy(group_nnz / n) if rg.comp_pts else None - Parameters - ---------- - matrix - Input matrix (n_rows, n_cols) - return_sorted - If True, also return sorted values (useful for tie correction) + if rg.ireference is None: + n_rest = cp.float64(n_cells) - n + if total_sums is None: + total_sums = group_sums.sum(axis=0, keepdims=True) + rest_sums = total_sums - group_sums + rest_means = rest_sums / n_rest + rg.means_rest = cp.asnumpy(rest_means) + if compute_vars: + if total_sq_sums is None: + total_sq_sums = group_sq_sums.sum(axis=0, keepdims=True) + rest_ss = (total_sq_sums - group_sq_sums) - n_rest * rest_means**2 + rg.vars_rest = cp.asnumpy( + cp.maximum(rest_ss / cp.maximum(n_rest - 1, 1), 0) + ) + else: + rg.vars_rest = np.zeros_like(rg.means_rest) + if rg.comp_pts: + if total_nnz is None: + total_nnz = group_nnz.sum(axis=0, keepdims=True) + rg.pts_rest = cp.asnumpy((total_nnz - group_nnz) / n_rest) + else: + rg.pts_rest = None + else: + rg.means_rest = None + rg.vars_rest = None + rg.pts_rest = None + + rg._compute_stats_in_chunks = False - Returns - ------- - ranks or (ranks, sorted_vals) + +def _fill_ovo_stats_from_accumulators( + rg: _RankGenes, + group_sums_slots: cp.ndarray, + group_sq_sums_slots: cp.ndarray, + group_nnz_slots: cp.ndarray, + *, + group_sizes: NDArray, + test_group_indices: list[int], + n_ref: int, + compute_vars: bool = True, +) -> None: + """Populate rg.means/vars/pts from OVO stats slots. + + Slot ordering: 0..n_test-1 are test groups (in ``test_group_indices`` + order); slot n_test is the reference group. Stats arrays arrive as + CuPy buffers; the math runs on GPU and only the per-group rows are + transferred to host as they're assigned onto rg. """ - n_rows, n_cols = matrix.shape + n_test = len(test_group_indices) + n_genes = int(group_sums_slots.shape[1]) + n_groups = len(rg.groups_order) + + rg.means = np.zeros((n_groups, n_genes), dtype=np.float64) + rg.vars = np.zeros((n_groups, n_genes), dtype=np.float64) + rg.pts = np.zeros((n_groups, n_genes), dtype=np.float64) if rg.comp_pts else None + + def _fill(slot: int, size: int, gi: int) -> None: + if size <= 0: + return + sums = group_sums_slots[slot] + mean = sums / size + rg.means[gi] = cp.asnumpy(mean) + if compute_vars and size > 1: + sq = group_sq_sums_slots[slot] + ss = sq - size * mean**2 + rg.vars[gi] = cp.asnumpy(cp.maximum(ss / max(size - 1, 1), 0)) + if rg.comp_pts: + rg.pts[gi] = cp.asnumpy(group_nnz_slots[slot] / size) + + for i, gi in enumerate(test_group_indices): + _fill(i, int(group_sizes[gi]), gi) + _fill(n_test, int(n_ref), rg.ireference) + + rg.means_rest = None + rg.vars_rest = None + rg.pts_rest = None + rg._compute_stats_in_chunks = False + + +def _to_gpu_native(X, n_rows: int, n_cols: int): + """Move *X* to GPU, preserving its format (CSR / CSC / dense).""" + # Already on GPU + if isinstance(X, cp.ndarray): + return X + if cpsp.issparse(X): + return X + + # Host sparse → GPU sparse, same format. Wilcoxon kernels are native CSR + # or CSC only; do not hide a whole-matrix sparse format conversion here. + # Downcast indices to int32 on host before transfer (column indices + # always fit in int32; scipy may use int64 when nnz > 2^31). + if isinstance(X, sp.spmatrix | sp.sparray): + if X.format == "csc": + return cpsp.csc_matrix( + ( + cp.asarray(X.data), + cp.asarray(X.indices.astype(np.int32, copy=False)), + cp.asarray(X.indptr), + ), + shape=(n_rows, n_cols), + ) + if X.format != "csr": + raise TypeError( + "Wilcoxon sparse input must be CSR or CSC; refusing hidden " + f"full-matrix conversion from {X.format!r}." + ) + csr = X + return cpsp.csr_matrix( + ( + cp.asarray(csr.data), + cp.asarray(csr.indices.astype(np.int32, copy=False)), + cp.asarray(csr.indptr), + ), + shape=(n_rows, n_cols), + ) + + # Host dense → GPU dense + if isinstance(X, np.ndarray): + return cp.asarray(X) + + raise TypeError(f"Unsupported matrix type: {type(X)}") + + +def _host_sparse_fn_and_arrays(module, base_name: str, X, *, support_idx64: bool): + """Select host sparse binding and dtype-normalized arrays.""" + is_f64 = X.data.dtype == np.float64 + is_idx64 = support_idx64 and X.indices.dtype == np.int64 + is_i64 = X.indptr.dtype == np.int64 + suffix = "" + if is_f64: + suffix += "_f64" + if is_idx64: + suffix += "_idx64" + if is_i64: + suffix += "_i64" + fn = getattr(module, base_name + suffix) + data_arr = X.data if is_f64 else X.data.astype(np.float32, copy=False) + indices_arr = X.indices if is_idx64 else X.indices.astype(np.int32, copy=False) + return fn, data_arr, indices_arr + + +def _column_totals_for_host_matrix( + X, *, compute_sq_sums: bool, compute_nnz: bool +) -> tuple[cp.ndarray, cp.ndarray | None, cp.ndarray | None]: + """Compute all-cell column totals without changing sparse format.""" + n_cols = X.shape[1] + + if isinstance(X, sp.spmatrix | sp.sparray): + data = np.asarray(X.data) + values = data.astype(np.float64, copy=False) + + if X.format == "csc": + indptr = np.asarray(X.indptr) + counts = np.diff(indptr) + nonempty = counts > 0 + starts = indptr[:-1][nonempty] - # Sort each column - sorter = cp.argsort(matrix, axis=0) - sorted_vals = cp.take_along_axis(matrix, sorter, axis=0) + sums = np.zeros(n_cols, dtype=np.float64) + if starts.size: + sums[nonempty] = np.add.reduceat(values, starts) - # Ensure F-order for kernel (columns contiguous in memory) - sorted_vals = cp.asfortranarray(sorted_vals) - sorter = cp.asfortranarray(sorter.astype(cp.int32)) + sq_sums = None + if compute_sq_sums: + sq_sums = np.zeros(n_cols, dtype=np.float64) + if starts.size: + sq_sums[nonempty] = np.add.reduceat(values * values, starts) - stream = cp.cuda.get_current_stream().ptr - _wc.average_rank( - sorted_vals, sorter, matrix, n_rows=n_rows, n_cols=n_cols, stream=stream + nnz = None + if compute_nnz: + nnz = np.zeros(n_cols, dtype=np.float64) + if starts.size: + nnz[nonempty] = np.add.reduceat( + (data != 0).astype(np.float64, copy=False), starts + ) + elif X.format == "csr": + indices = np.asarray(X.indices, dtype=np.intp) + sums = np.bincount(indices, weights=values, minlength=n_cols).astype( + np.float64, copy=False + ) + + sq_sums = ( + np.bincount(indices, weights=values * values, minlength=n_cols).astype( + np.float64, copy=False + ) + if compute_sq_sums + else None + ) + nnz = ( + np.bincount( + indices, + weights=(data != 0).astype(np.float64, copy=False), + minlength=n_cols, + ).astype(np.float64, copy=False) + if compute_nnz + else None + ) + else: + raise TypeError( + "Wilcoxon sparse input must be CSR or CSC; refusing hidden " + f"full-matrix conversion from {X.format!r}." + ) + elif isinstance(X, np.ndarray): + sums = np.asarray(X.sum(axis=0, dtype=np.float64), dtype=np.float64) + sq_sums = ( + np.asarray(np.square(X, dtype=np.float64).sum(axis=0), dtype=np.float64) + if compute_sq_sums + else None + ) + nnz = ( + np.asarray(np.count_nonzero(X, axis=0), dtype=np.float64) + if compute_nnz + else None + ) + else: + raise TypeError(f"Unsupported host matrix type: {type(X)}") + + total_sums = cp.asarray(sums.reshape(1, n_cols), dtype=cp.float64) + total_sq_sums = ( + cp.asarray(sq_sums.reshape(1, n_cols), dtype=cp.float64) + if sq_sums is not None + else None + ) + total_nnz = ( + cp.asarray(nnz.reshape(1, n_cols), dtype=cp.float64) + if nnz is not None + else None ) + return total_sums, total_sq_sums, total_nnz - if return_sorted: - return matrix, sorted_vals - return matrix +def _host_ovr_totals_if_needed( + X, + group_codes: np.ndarray, + n_groups: int, + *, + compute_sq_sums: bool, + compute_nnz: bool, +) -> tuple[cp.ndarray | None, cp.ndarray | None, cp.ndarray | None]: + if not np.any(group_codes == n_groups): + return None, None, None + return _column_totals_for_host_matrix( + X, compute_sq_sums=compute_sq_sums, compute_nnz=compute_nnz + ) + + +def _extract_dense_block( + X, + row_ids: cp.ndarray | None, + start: int, + stop: int, + *, + csr_arrays: tuple[cp.ndarray, cp.ndarray, cp.ndarray] | None = None, +) -> cp.ndarray: + """Extract ``X[row_ids, start:stop]`` as dense F-order on GPU. -def _tie_correction(sorted_vals: cp.ndarray) -> cp.ndarray: + CSR kernel path: outputs same dtype as CSR data (float32 or float64). + Other paths: preserve input dtype. """ - Compute tie correction factor for Wilcoxon test. + if csr_arrays is not None: + data, indices, indptr = csr_arrays + if row_ids is None: + n_target = int(indptr.shape[0] - 1) + row_ids = cp.arange(n_target, dtype=cp.int32) + n_target = row_ids.shape[0] + n_cols = stop - start + out = cp.zeros((n_target, n_cols), dtype=data.dtype, order="F") + if n_target > 0 and n_cols > 0: + stream = cp.cuda.get_current_stream().ptr + if data.dtype == cp.float32: + _wc.csr_extract_dense_f32( + data, + indices, + indptr, + row_ids, + out, + n_target=n_target, + col_start=start, + col_stop=stop, + stream=stream, + ) + else: + _wc.csr_extract_dense( + data, + indices, + indptr, + row_ids, + out, + n_target=n_target, + col_start=start, + col_stop=stop, + stream=stream, + ) + return out + + if isinstance(X, np.ndarray): + if row_ids is not None: + return cp.asarray(X[cp.asnumpy(row_ids), start:stop], order="F") + return cp.asarray(X[:, start:stop], order="F") + + if isinstance(X, cp.ndarray): + chunk = X[row_ids, start:stop] if row_ids is not None else X[:, start:stop] + return cp.asfortranarray(chunk) + + if isinstance(X, sp.spmatrix | sp.sparray): + if row_ids is not None: + idx = cp.asnumpy(row_ids) + chunk = X[idx][:, start:stop].toarray() + else: + chunk = X[:, start:stop].toarray() + return cp.asarray(chunk, order="F") + + if cpsp.issparse(X): + if row_ids is not None: + chunk = X[row_ids][:, start:stop].toarray() + else: + chunk = X[:, start:stop].toarray() + return cp.asfortranarray(chunk) + + raise TypeError(f"Unsupported matrix type: {type(X)}") - Takes pre-sorted values (column-wise) to avoid re-sorting. - Formula: tc = 1 - sum(t^3 - t) / (n^3 - n) - where t is the count of tied values. + +def _segmented_sort_columns( + data: cp.ndarray, + offsets_host: np.ndarray, + n_rows: int, + n_cols: int, + n_groups: int, +) -> cp.ndarray: + """Sort each group segment within each column using CUB radix sort. + + Sorts in float32 for half the bandwidth. Returns float32 F-order. """ - n_rows, n_cols = sorted_vals.shape - correction = cp.ones(n_cols, dtype=cp.float64) + n_items = n_rows * n_cols + n_segments = n_cols * n_groups + + col_bases = np.arange(n_cols, dtype=np.int32) * n_rows + seg_starts = col_bases[:, None] + offsets_host[None, :n_groups] + seg_arr = np.empty(n_segments + 1, dtype=np.int32) + seg_arr[:n_segments] = seg_starts.ravel() + seg_arr[n_segments] = n_items + seg_offsets_gpu = cp.asarray(seg_arr) - if n_rows < 2: - return correction + temp_bytes = _wc.get_seg_sort_temp_bytes(n_items=n_items, n_segments=n_segments) + cub_temp = cp.empty(temp_bytes, dtype=cp.uint8) - # Ensure F-order - sorted_vals = cp.asfortranarray(sorted_vals) + # data is F-order; ravel("F") gives a flat C-contiguous view (no copy) + keys_in = data.astype(cp.float32, copy=False).ravel(order="F") + keys_out = cp.empty_like(keys_in) + + _wc.segmented_sort( + keys_in, + keys_out, + seg_offsets_gpu, + cub_temp, + n_items=n_items, + n_segments=n_segments, + stream=cp.cuda.get_current_stream().ptr, + ) - stream = cp.cuda.get_current_stream().ptr - _wc.tie_correction( - sorted_vals, correction, n_rows=n_rows, n_cols=n_cols, stream=stream + return cp.ndarray( + (n_rows, n_cols), dtype=cp.float32, memptr=keys_out.data, order="F" ) - return correction + +# --------------------------------------------------------------------------- +# Entry point +# --------------------------------------------------------------------------- def wilcoxon( @@ -94,14 +445,15 @@ def wilcoxon( chunk_size: int | None = None, ) -> list[tuple[int, NDArray, NDArray]]: """Compute Wilcoxon rank-sum test statistics.""" - # Compute basic stats - uses Aggregate if on GPU, else defers to chunks - rg._basic_stats() X = rg.X - n_cells, n_total_genes = rg.X.shape + n_cells, n_total_genes = X.shape group_sizes = rg.group_sizes + # Stats via Aggregate for both OVR and OVO — decoupled from sort. + # Aggregate reads original dtype for precision, accumulates in float64. + rg._basic_stats() + if rg.ireference is not None: - # Compare each group against a specific reference group return _wilcoxon_with_reference( rg, X, @@ -111,7 +463,6 @@ def wilcoxon( use_continuity=use_continuity, chunk_size=chunk_size, ) - # Compare each group against "rest" (all other cells) return _wilcoxon_vs_rest( rg, X, @@ -120,10 +471,14 @@ def wilcoxon( group_sizes, tie_correct=tie_correct, use_continuity=use_continuity, - chunk_size=chunk_size, ) +# --------------------------------------------------------------------------- +# One-vs-rest +# --------------------------------------------------------------------------- + + def _wilcoxon_vs_rest( rg: _RankGenes, X, @@ -133,12 +488,16 @@ def _wilcoxon_vs_rest( *, tie_correct: bool, use_continuity: bool, - chunk_size: int | None, ) -> list[tuple[int, NDArray, NDArray]]: - """Wilcoxon test: each group vs rest of cells.""" + """Wilcoxon test: each group vs rest of cells. + + Dispatches to CSR, CSC, or dense streaming kernel based on input format. + No unnecessary format conversions. + """ + from rapids_singlecell._cuda import _wilcoxon_ovr_cuda as _ovr + n_groups = len(rg.groups_order) - # Warn for small groups for name, size in zip(rg.groups_order, group_sizes, strict=True): rest = n_cells - size if size <= MIN_GROUP_SIZE_WARNING or rest <= MIN_GROUP_SIZE_WARNING: @@ -149,73 +508,228 @@ def _wilcoxon_vs_rest( stacklevel=4, ) - # Build one-hot indicator matrix from group codes - codes_gpu = cp.asarray(rg.group_codes, dtype=cp.int64) - group_matrix = cp.zeros((n_cells, n_groups), dtype=cp.float64) - valid_idx = cp.where(codes_gpu < n_groups)[0] - group_matrix[valid_idx, codes_gpu[valid_idx]] = 1.0 - + group_codes = rg.group_codes.astype(np.int32, copy=False) group_sizes_dev = cp.asarray(group_sizes, dtype=cp.float64) rest_sizes = n_cells - group_sizes_dev - chunk_width = _choose_chunk_size(chunk_size) - - # Accumulate results per group - all_scores: dict[int, list] = {i: [] for i in range(n_groups)} - all_pvals: dict[int, list] = {i: [] for i in range(n_groups)} + # Determine host-streaming eligibility BEFORE transferring + host_csc = isinstance(X, sp.spmatrix | sp.sparray) and X.format == "csc" + host_csr = isinstance(X, sp.spmatrix | sp.sparray) and X.format == "csr" + host_dense = isinstance(X, np.ndarray) - # One-time CSR->CSC via fast parallel Numba kernel; _get_column_block - # then uses direct indptr pointer copy for each chunk. - if isinstance(X, sp.spmatrix | sp.sparray): - X = _fast_csr_to_csc(X) if X.format == "csr" else X.tocsc() - - for start in range(0, n_total_genes, chunk_width): - stop = min(start + chunk_width, n_total_genes) - - # Slice and convert to dense GPU array (F-order for column ops) - block = _get_column_block(X, start, stop) - - # Accumulate stats for this chunk - rg._accumulate_chunk_stats_vs_rest( - block, - start, - stop, - group_matrix=group_matrix, - group_sizes_dev=group_sizes_dev, - n_cells=n_cells, + if host_csc or host_csr or host_dense: + # Host-streaming: sort+rank stays on host→GPU per sub-batch. The + # kernel also emits per-group sum, sum-of-squares, and nonzero + # counts into caller-provided CuPy buffers, so means/vars/pts can + # be derived without uploading the full matrix. Outputs live on + # the GPU and feed directly into the z-score / p-value math below. + compute_vars = False + compute_nnz = rg.comp_pts + rank_sums = cp.empty((n_groups, n_total_genes), dtype=cp.float64) + tie_corr = cp.ones(n_total_genes, dtype=cp.float64) + group_sums = cp.empty((n_groups, n_total_genes), dtype=cp.float64) + group_sq_sums = cp.empty( + (n_groups, n_total_genes) if compute_vars else (1, 1), + dtype=cp.float64, + ) + group_nnz = cp.empty( + (n_groups, n_total_genes) if compute_nnz else (1, 1), + dtype=cp.float64, ) - if tie_correct: - ranks, sorted_vals = _average_ranks(block, return_sorted=True) - tie_corr = _tie_correction(sorted_vals) + if host_csc: + group_sizes_np = group_sizes.astype(np.float64, copy=False) + # Native host dtype is preserved and uploaded once per sub-batch; + # a pre-sort kernel casts to float32 for the sort keys while + # accumulating stats in float64 from the original values. + _csc_host_fn, data_arr, indices_arr = _host_sparse_fn_and_arrays( + _ovr, "ovr_sparse_csc_host", X, support_idx64=False + ) + _csc_host_fn( + data_arr, + indices_arr, + X.indptr, + group_codes, + group_sizes_np, + rank_sums, + tie_corr, + group_sums, + group_sq_sums, + group_nnz, + n_rows=n_cells, + n_cols=n_total_genes, + n_groups=n_groups, + compute_tie_corr=tie_correct, + compute_sq_sums=compute_vars, + compute_nnz=compute_nnz, + sub_batch_cols=STREAMING_SUB_BATCH, + ) + elif host_csr: + group_sizes_np = group_sizes.astype(np.float64, copy=False) + csr = X + if not csr.has_sorted_indices: + csr = csr.copy() + csr.sort_indices() + _csr_host_fn, data_arr, indices_arr = _host_sparse_fn_and_arrays( + _ovr, "ovr_sparse_csr_host", csr, support_idx64=True + ) + _csr_host_fn( + data_arr, + indices_arr, + csr.indptr, + group_codes, + group_sizes_np, + rank_sums, + tie_corr, + group_sums, + group_sq_sums, + group_nnz, + n_rows=n_cells, + n_cols=n_total_genes, + n_groups=n_groups, + compute_tie_corr=tie_correct, + compute_sq_sums=compute_vars, + compute_nnz=compute_nnz, + sub_batch_cols=STREAMING_SUB_BATCH, + ) else: - ranks = _average_ranks(block) - tie_corr = cp.ones(ranks.shape[1], dtype=cp.float64) - - rank_sums = group_matrix.T @ ranks - expected = group_sizes_dev[:, None] * (n_cells + 1) / 2.0 - variance = tie_corr[None, :] * group_sizes_dev[:, None] * rest_sizes[:, None] - variance *= (n_cells + 1) / 12.0 - std = cp.sqrt(variance) - diff = rank_sums - expected - if use_continuity: - diff = cp.sign(diff) * cp.maximum(cp.abs(diff) - 0.5, 0.0) - z = diff / std - cp.nan_to_num(z, copy=False) - p_values = cupyx_special.erfc(cp.abs(z) * cp.float64(cp.sqrt(0.5))) - - z_host = z.get() - p_host = p_values.get() - - for idx in range(n_groups): - all_scores[idx].append(z_host[idx]) - all_pvals[idx].append(p_host[idx]) - - # Collect results per group - return [ - (gi, np.concatenate(all_scores[gi]), np.concatenate(all_pvals[gi])) - for gi in range(n_groups) - ] + is_f64 = X.dtype == np.float64 + _dense_host_fn = ( + _ovr.ovr_streaming_dense_host_f64 + if is_f64 + else _ovr.ovr_streaming_dense_host + ) + block = X if is_f64 else X.astype(np.float32, copy=False) + _dense_host_fn( + np.asfortranarray(block), + group_codes, + rank_sums, + tie_corr, + group_sums, + group_sq_sums, + group_nnz, + n_rows=n_cells, + n_cols=n_total_genes, + n_groups=n_groups, + compute_tie_corr=tie_correct, + compute_sq_sums=compute_vars, + compute_nnz=compute_nnz, + sub_batch_cols=STREAMING_SUB_BATCH, + ) + + if rg._compute_stats_in_chunks: + total_sums, total_sq_sums, total_nnz = _host_ovr_totals_if_needed( + X, + group_codes, + n_groups, + compute_sq_sums=compute_vars, + compute_nnz=compute_nnz, + ) + _fill_basic_stats_from_accumulators( + rg, + group_sums, + group_sq_sums, + group_nnz, + group_sizes.astype(np.float64, copy=False), + n_cells=n_cells, + compute_vars=compute_vars, + total_sums=total_sums, + total_sq_sums=total_sq_sums, + total_nnz=total_nnz, + ) + else: + # GPU data → use native GPU kernels. + X_gpu = _to_gpu_native(X, n_cells, n_total_genes) + + if rg._compute_stats_in_chunks: + rg.X = X_gpu + rg._compute_stats_in_chunks = False + rg._basic_stats() + + group_codes_gpu = cp.asarray(group_codes) + rank_sums = cp.empty((n_groups, n_total_genes), dtype=cp.float64) + tie_corr = cp.ones(n_total_genes, dtype=cp.float64) + + if cpsp.isspmatrix_csc(X_gpu): + # Sparse-aware path: sort only stored nonzeros, + # handle zeros analytically. + csc_data = X_gpu.data.astype(cp.float32, copy=False) + csc_indices = X_gpu.indices.astype(cp.int32, copy=False) + csc_indptr = X_gpu.indptr.astype(cp.int32, copy=False) + cp.cuda.get_current_stream().synchronize() + _ovr.ovr_sparse_csc( + csc_data, + csc_indices, + csc_indptr, + group_codes_gpu, + group_sizes_dev, + rank_sums, + tie_corr, + n_rows=n_cells, + n_cols=n_total_genes, + n_groups=n_groups, + compute_tie_corr=tie_correct, + sub_batch_cols=STREAMING_SUB_BATCH, + ) + elif cpsp.isspmatrix_csr(X_gpu): + csr_gpu = X_gpu + if not csr_gpu.has_sorted_indices: + csr_gpu = csr_gpu.copy() + csr_gpu.sort_indices() + csr_data = csr_gpu.data.astype(cp.float32, copy=False) + csr_indices = csr_gpu.indices.astype(cp.int32, copy=False) + csr_indptr = csr_gpu.indptr.astype(cp.int32, copy=False) + cp.cuda.get_current_stream().synchronize() + _ovr.ovr_sparse_csr( + csr_data, + csr_indices, + csr_indptr, + group_codes_gpu, + group_sizes_dev, + rank_sums, + tie_corr, + n_rows=n_cells, + n_cols=n_total_genes, + n_groups=n_groups, + compute_tie_corr=tie_correct, + sub_batch_cols=STREAMING_SUB_BATCH, + ) + else: + dense_f32 = cp.asfortranarray(X_gpu.astype(cp.float32, copy=False)) + cp.cuda.get_current_stream().synchronize() + _ovr.ovr_streaming( + dense_f32, + group_codes_gpu, + rank_sums, + tie_corr, + n_rows=n_cells, + n_cols=n_total_genes, + n_groups=n_groups, + compute_tie_corr=tie_correct, + sub_batch_cols=STREAMING_SUB_BATCH, + ) + + # Z-scores + p-values (vectorised) + expected = group_sizes_dev[:, None] * (n_cells + 1) / 2.0 + variance = tie_corr[None, :] * group_sizes_dev[:, None] * rest_sizes[:, None] + variance *= (n_cells + 1) / 12.0 + std = cp.sqrt(variance) + diff = rank_sums - expected + if use_continuity: + diff = cp.sign(diff) * cp.maximum(cp.abs(diff) - 0.5, 0.0) + z = diff / std + cp.nan_to_num(z, copy=False) + p_values = cupyx_special.erfc(cp.abs(z) * cp.float64(cp.sqrt(0.5))) + + all_z = z.astype(cp.float32).get() + all_p = p_values.get() + + return [(gi, all_z[gi], all_p[gi]) for gi in range(n_groups)] + + +# --------------------------------------------------------------------------- +# One-vs-reference +# --------------------------------------------------------------------------- def _wilcoxon_with_reference( @@ -228,97 +742,365 @@ def _wilcoxon_with_reference( use_continuity: bool, chunk_size: int | None, ) -> list[tuple[int, NDArray, NDArray]]: - """Wilcoxon test: each group vs a specific reference group.""" + """Wilcoxon test: each group vs a specific reference group. + + All test groups are processed in a single batched streaming kernel, + eliminating per-group kernel launch overhead. + """ + + n_cells = X.shape[0] + n_groups = len(rg.groups_order) + ireference = rg.ireference + n_ref = int(group_sizes[ireference]) codes = rg.group_codes - n_ref = int(group_sizes[rg.ireference]) - mask_ref = codes == rg.ireference - results: list[tuple[int, NDArray, NDArray]] = [] + # ---- build row-index arrays via CSR-style cat offsets (O(n)) ---- + # scipy's coo→csr conversion is a vectorised counting sort — ~15× faster + # than np.argsort(stable) on this shape (1.8 ms vs 27 ms for 534 k cells). + # indptr[g] .. indptr[g+1] bracket the rows of group g in indices. + n_rows_incl_sentinel = n_groups + 1 # last slot holds "unselected" rows + _csr = sp.coo_matrix( + ( + np.ones(n_cells, dtype=np.int8), + (codes, np.arange(n_cells, dtype=np.int32)), + ), + shape=(n_rows_incl_sentinel, n_cells), + ).tocsr() + offsets_full = _csr.indptr # int32, length n_groups + 2 + sorted_rows = _csr.indices # int32, length n_cells - for group_index in range(len(rg.groups_order)): - if group_index == rg.ireference: - continue + test_group_indices: list[int] = [gi for gi in range(n_groups) if gi != ireference] + if not test_group_indices: + return [] + test_group_indices_np = np.asarray(test_group_indices, dtype=np.intp) - n_group = int(group_sizes[group_index]) - n_combined = n_group + n_ref + offsets = [0] + all_grp_rows_parts: list[np.ndarray] = [] + for gi in test_group_indices: + g_start = int(offsets_full[gi]) + g_end = int(offsets_full[gi + 1]) + all_grp_rows_parts.append(sorted_rows[g_start:g_end]) + offsets.append(offsets[-1] + (g_end - g_start)) - # Warn for small groups - if n_group <= MIN_GROUP_SIZE_WARNING or n_ref <= MIN_GROUP_SIZE_WARNING: - warnings.warn( - f"Group {rg.groups_order[group_index]} has size {n_group} " - f"(reference {n_ref}); normal approximation " - "of the Wilcoxon statistic may be inaccurate.", - RuntimeWarning, - stacklevel=4, + all_grp_row_ids_np = np.concatenate(all_grp_rows_parts) + n_test = len(test_group_indices) + n_all_grp = len(all_grp_row_ids_np) + ref_row_ids_np = sorted_rows[ + int(offsets_full[ireference]) : int(offsets_full[ireference + 1]) + ] + + # ---- warn for small groups (single aggregated warning for the batch + # rather than one per group — emitting a warning per group is O(n_test) + # Python overhead that dominates on workloads with thousands of groups). + small_test = [ + str(rg.groups_order[gi]) + for gi in test_group_indices + if int(group_sizes[gi]) <= MIN_GROUP_SIZE_WARNING + ] + ref_small = n_ref <= MIN_GROUP_SIZE_WARNING + if small_test or ref_small: + parts = [] + if small_test: + parts.append( + f"{len(small_test)} test group(s) have size " + f"<= {MIN_GROUP_SIZE_WARNING} (first few: " + f"{', '.join(small_test[:5])}{'...' if len(small_test) > 5 else ''})" ) + if ref_small: + parts.append(f"reference has size {n_ref}") + warnings.warn( + f"Small groups detected: {'; '.join(parts)}. normal " + "approximation of the Wilcoxon statistic may be inaccurate.", + RuntimeWarning, + stacklevel=4, + ) + + test_sizes = cp.asarray( + group_sizes[test_group_indices_np].astype(np.float64, copy=False) + ) - # Combined mask: group + reference - mask_obs = codes == group_index - mask_combined = mask_obs | mask_ref + offsets_np = np.asarray(offsets, dtype=np.int32) - # Subset matrix ONCE before chunking (10x faster than filtering each chunk) - X_subset = X[mask_combined, :] + # ---- host-streaming paths: skip bulk transfer ---- + host_sparse = isinstance(X, sp.spmatrix | sp.sparray) + host_dense = isinstance(X, np.ndarray) + + # ---- build row maps only for paths that need original-row lookup ---- + ref_row_map_np = grp_row_map_np = None + if (host_sparse and X.format == "csc") or (not host_sparse and not host_dense): + ref_row_map_np = np.full(n_cells, -1, dtype=np.int32) + ref_row_map_np[ref_row_ids_np] = np.arange(n_ref, dtype=np.int32) + grp_row_map_np = np.full(n_cells, -1, dtype=np.int32) + grp_row_map_np[all_grp_row_ids_np] = np.arange(n_all_grp, dtype=np.int32) + + if host_sparse or host_dense: + # Output buffers live on the GPU (caller-provided CuPy memory); + # kernels write directly into them, and rank_sums / tie_corr + # feed the z-score math below without any H2D → H2D round-trip. + rank_sums = cp.empty((n_test, n_total_genes), dtype=cp.float64) + tie_corr_arr = cp.ones((n_test, n_total_genes), dtype=cp.float64) + + # Stats slots: 0..n_test-1 = test groups, slot n_test = reference. + # Unselected cells carry the sentinel (n_groups_stats) which the + # kernel skips. + n_groups_stats = n_test + 1 + compute_vars = False + compute_nnz = rg.comp_pts + stats_code_lookup = np.full(n_groups + 1, n_groups_stats, dtype=np.int32) + stats_code_lookup[test_group_indices_np] = np.arange(n_test, dtype=np.int32) + stats_code_lookup[ireference] = n_test + stats_codes_np = stats_code_lookup[codes] + + group_sums = cp.empty((n_groups_stats, n_total_genes), dtype=cp.float64) + group_sq_sums = cp.empty( + (n_groups_stats, n_total_genes) if compute_vars else (1,), + dtype=cp.float64, + ) + group_nnz = cp.empty( + (n_groups_stats, n_total_genes) if compute_nnz else (1,), + dtype=cp.float64, + ) + + if host_sparse and X.format == "csc": + _csc_host_fn, data_arr, indices_arr = _host_sparse_fn_and_arrays( + _wc, "ovo_streaming_csc_host", X, support_idx64=True + ) + _csc_host_fn( + data_arr, + indices_arr, + X.indptr, + ref_row_map_np, + grp_row_map_np, + offsets_np, + stats_codes_np, + rank_sums, + tie_corr_arr, + group_sums, + group_sq_sums, + group_nnz, + n_ref=n_ref, + n_all_grp=n_all_grp, + n_rows=n_cells, + n_cols=n_total_genes, + n_groups=n_test, + n_groups_stats=n_groups_stats, + compute_tie_corr=tie_correct, + compute_sq_sums=compute_vars, + compute_nnz=compute_nnz, + sub_batch_cols=STREAMING_SUB_BATCH, + ) + elif host_sparse and X.format == "csr": + csr = X + if not csr.has_sorted_indices: + csr = csr.copy() + csr.sort_indices() - # One-time CSR->CSC via fast parallel Numba kernel - if isinstance(X_subset, sp.spmatrix | sp.sparray): - X_subset = ( - _fast_csr_to_csc(X_subset) - if X_subset.format == "csr" - else X_subset.tocsc() + # Zero-copy mapped: pin full CSR, upload indptr + row_ids, GPU + # kernels gather per-pack rows via UVA reads. + _csr_host_fn, data_arr, indices_arr = _host_sparse_fn_and_arrays( + _wc, "ovo_streaming_csr_host", csr, support_idx64=True + ) + _csr_host_fn( + data_arr, + indices_arr, + csr.indptr, + ref_row_ids_np.astype(np.int32, copy=False), + all_grp_row_ids_np.astype(np.int32, copy=False), + offsets_np, + rank_sums, + tie_corr_arr, + group_sums, + group_sq_sums, + group_nnz, + n_full_rows=n_cells, + n_ref=n_ref, + n_all_grp=n_all_grp, + n_cols=n_total_genes, + n_test=n_test, + n_groups_stats=n_groups_stats, + compute_tie_corr=tie_correct, + compute_sq_sums=compute_vars, + compute_nnz=compute_nnz, + sub_batch_cols=STREAMING_SUB_BATCH, + ) + elif host_sparse: + raise TypeError( + "Wilcoxon sparse input must be CSR or CSC; refusing hidden " + f"full-matrix conversion from {X.format!r}." + ) + else: + is_f64 = X.dtype == np.float64 + _dense_host_fn = ( + _wc.ovo_streaming_dense_host_f64 + if is_f64 + else _wc.ovo_streaming_dense_host + ) + block = X if is_f64 else X.astype(np.float32, copy=False) + _dense_host_fn( + np.asfortranarray(block), + ref_row_ids_np.astype(np.int32, copy=False), + all_grp_row_ids_np.astype(np.int32, copy=False), + offsets_np, + stats_codes_np, + rank_sums, + tie_corr_arr, + group_sums, + group_sq_sums, + group_nnz, + n_ref=n_ref, + n_all_grp=n_all_grp, + n_rows=n_cells, + n_cols=n_total_genes, + n_groups=n_test, + n_groups_stats=n_groups_stats, + compute_tie_corr=tie_correct, + compute_sq_sums=compute_vars, + compute_nnz=compute_nnz, + sub_batch_cols=STREAMING_SUB_BATCH, ) - # Within the combined array, True = group cell, False = reference cell - group_mask_gpu = cp.asarray(mask_obs[mask_combined]) + if rg._compute_stats_in_chunks: + _fill_ovo_stats_from_accumulators( + rg, + group_sums, + group_sq_sums, + group_nnz, + group_sizes=group_sizes, + test_group_indices=test_group_indices, + n_ref=n_ref, + compute_vars=compute_vars, + ) - chunk_width = _choose_chunk_size(chunk_size) + else: + # ---- GPU path: transfer once, then dispatch ---- + X_gpu = _to_gpu_native(X, n_cells, n_total_genes) + grp_offsets_gpu = cp.asarray(offsets_np, dtype=cp.int32) - # Pre-allocate output arrays - scores = np.empty(n_total_genes, dtype=np.float64) - pvals = np.empty(n_total_genes, dtype=np.float64) + if rg._compute_stats_in_chunks: + rg.X = X_gpu + rg._compute_stats_in_chunks = False + rg._basic_stats() - for start in range(0, n_total_genes, chunk_width): - stop = min(start + chunk_width, n_total_genes) + ref_row_ids_gpu = cp.asarray(ref_row_ids_np, dtype=cp.int32) + all_grp_row_ids_gpu = cp.asarray(all_grp_row_ids_np, dtype=cp.int32) - # Get block for combined cells only - block = _get_column_block(X_subset, start, stop) + rank_sums = cp.empty((n_test, n_total_genes), dtype=cp.float64) + tie_corr_arr = cp.empty((n_test, n_total_genes), dtype=cp.float64) - # Accumulate stats for this chunk - rg._accumulate_chunk_stats_with_ref( - block, - start, - stop, - group_index=group_index, - group_mask_gpu=group_mask_gpu, - n_group=n_group, + if cpsp.isspmatrix_csc(X_gpu): + ref_row_map = cp.asarray(ref_row_map_np) + grp_row_map = cp.asarray(grp_row_map_np) + csc_data = X_gpu.data.astype(cp.float32, copy=False) + csc_indices = X_gpu.indices.astype(cp.int32, copy=False) + csc_indptr = X_gpu.indptr.astype(cp.int32, copy=False) + cp.cuda.get_current_stream().synchronize() + _wc.ovo_streaming_csc( + csc_data, + csc_indices, + csc_indptr, + ref_row_map, + grp_row_map, + grp_offsets_gpu, + rank_sums, + tie_corr_arr, + n_ref=n_ref, + n_all_grp=n_all_grp, + n_cols=n_total_genes, + n_groups=n_test, + compute_tie_corr=tie_correct, + sub_batch_cols=STREAMING_SUB_BATCH, + ) + elif cpsp.isspmatrix_csr(X_gpu): + # CSR-native: extract ref/grp rows directly + csr_gpu = X_gpu + if not csr_gpu.has_sorted_indices: + csr_gpu = csr_gpu.copy() + csr_gpu.sort_indices() + csr_data = csr_gpu.data.astype(cp.float32, copy=False) + csr_indices = csr_gpu.indices.astype(cp.int32, copy=False) + csr_indptr = csr_gpu.indptr.astype(cp.int32, copy=False) + cp.cuda.get_current_stream().synchronize() + _wc.ovo_streaming_csr( + csr_data, + csr_indices, + csr_indptr, + ref_row_ids_gpu, + all_grp_row_ids_gpu, + grp_offsets_gpu, + rank_sums, + tie_corr_arr, n_ref=n_ref, + n_all_grp=n_all_grp, + n_cols=n_total_genes, + n_groups=n_test, + compute_tie_corr=tie_correct, + sub_batch_cols=STREAMING_SUB_BATCH, ) + elif cpsp.issparse(X_gpu): + raise TypeError( + "Wilcoxon sparse input must be CSR or CSC; refusing hidden " + "full-matrix GPU sparse conversion." + ) + else: + # Dense device data is already resident, but extracting all genes + # for all reference/test rows can still blow up memory. Preserve + # the public chunk_size escape hatch by materializing bounded + # column blocks and stitching the CUDA outputs together. + dense_chunk = _resolve_chunk_size(chunk_size, n_total_genes) + for start in range(0, n_total_genes, dense_chunk): + stop = min(start + dense_chunk, n_total_genes) + sb_cols = stop - start + ref_block = _extract_dense_block(X_gpu, ref_row_ids_gpu, start, stop) + grp_block = _extract_dense_block( + X_gpu, all_grp_row_ids_gpu, start, stop + ) + ref_sorted = _segmented_sort_columns( + ref_block, + np.array([0, n_ref], dtype=np.int32), + n_ref, + sb_cols, + 1, + ) + grp_f32 = cp.asfortranarray(grp_block.astype(cp.float32, copy=False)) + sub_rank_sums = cp.empty((n_test, sb_cols), dtype=cp.float64) + sub_tie_corr = cp.empty((n_test, sb_cols), dtype=cp.float64) + cp.cuda.get_current_stream().synchronize() + _wc.ovo_streaming( + ref_sorted, + grp_f32, + grp_offsets_gpu, + sub_rank_sums, + sub_tie_corr, + n_ref=n_ref, + n_all_grp=n_all_grp, + n_cols=sb_cols, + n_groups=n_test, + compute_tie_corr=tie_correct, + sub_batch_cols=STREAMING_SUB_BATCH, + ) + rank_sums[:, start:stop] = sub_rank_sums + tie_corr_arr[:, start:stop] = sub_tie_corr - # Ranks for combined group+reference cells - if tie_correct: - ranks, sorted_vals = _average_ranks(block, return_sorted=True) - tie_corr = _tie_correction(sorted_vals) - else: - ranks = _average_ranks(block) - tie_corr = cp.ones(ranks.shape[1], dtype=cp.float64) - - # Rank sum for the group - rank_sums = (ranks * group_mask_gpu[:, None]).sum(axis=0) - - # Wilcoxon z-score formula for two groups - expected = n_group * (n_combined + 1) / 2.0 - variance = tie_corr * n_group * n_ref * (n_combined + 1) / 12.0 - std = cp.sqrt(variance) - diff = rank_sums - expected - if use_continuity: - diff = cp.sign(diff) * cp.maximum(cp.abs(diff) - 0.5, 0.0) - z = diff / std - cp.nan_to_num(z, copy=False) - p_values = cupyx_special.erfc(cp.abs(z) * cp.float64(cp.sqrt(0.5))) - - # Fill pre-allocated arrays - scores[start:stop] = z.get() - pvals[start:stop] = p_values.get() - - results.append((group_index, scores, pvals)) - - return results + # ---- z-scores & p-values (vectorised) ---- + n_combined = test_sizes + n_ref + expected = test_sizes * (n_combined + 1) / 2.0 + variance = test_sizes * n_ref * (n_combined + 1) / 12.0 + if tie_correct: + variance = variance[:, None] * tie_corr_arr + else: + variance = variance[:, None] + + diff = rank_sums - expected[:, None] + if use_continuity: + diff = cp.sign(diff) * cp.maximum(cp.abs(diff) - 0.5, 0.0) + z = diff / cp.sqrt(variance) + cp.nan_to_num(z, copy=False) + p_values = cupyx_special.erfc(cp.abs(z) * cp.float64(cp.sqrt(0.5))) + + # Downcast scores to float32 on GPU (the scanpy output dtype); keeps + # p-values in float64 for downstream BH correction precision. Moving + # the cast off CPU saves ~150 ms per stat per call on wide workloads. + all_z = z.astype(cp.float32).get() + all_p = p_values.get() + + return [(gi, all_z[ti], all_p[ti]) for ti, gi in enumerate(test_group_indices)] diff --git a/tests/test_rank_genes_groups_wilcoxon.py b/tests/test_rank_genes_groups_wilcoxon.py index 0c6844da..aba9ac4a 100644 --- a/tests/test_rank_genes_groups_wilcoxon.py +++ b/tests/test_rank_genes_groups_wilcoxon.py @@ -1,16 +1,34 @@ from __future__ import annotations import cupy as cp +import cupyx.scipy.sparse as cpsp import numpy as np import pandas as pd import pytest import scanpy as sc import scipy.sparse as sp -from scipy.stats import mannwhitneyu, rankdata, tiecorrect +from scipy.stats import mannwhitneyu import rapids_singlecell as rsc +def _to_format(X_dense, fmt): + """Convert dense numpy array to the specified format.""" + if fmt == "numpy_dense": + return np.asarray(X_dense) + if fmt == "scipy_csr": + return sp.csr_matrix(X_dense) + if fmt == "scipy_csc": + return sp.csc_matrix(X_dense) + if fmt == "cupy_dense": + return cp.asarray(X_dense) + if fmt == "cupy_csr": + return cpsp.csr_matrix(cp.asarray(X_dense)) + if fmt == "cupy_csc": + return cpsp.csc_matrix(cp.asarray(X_dense)) + raise ValueError(f"Unknown format: {fmt}") + + @pytest.mark.parametrize("reference", ["rest", "1"]) @pytest.mark.parametrize("tie_correct", [True, False]) @pytest.mark.parametrize("sparse", [True, False]) @@ -148,6 +166,118 @@ def test_rank_genes_groups_wilcoxon_subset_and_bonferroni(reference): assert np.all(adjusted <= 1.0) +@pytest.mark.parametrize( + "fmt", + [ + pytest.param("scipy_csr", id="host_csr"), + pytest.param("scipy_csc", id="host_csc"), + pytest.param("cupy_dense", id="device_dense"), + ], +) +def test_wilcoxon_subset_rest_stats_match_scanpy(fmt): + """groups=... with reference='rest' must use all other cells for stats.""" + np.random.seed(42) + adata_gpu = sc.datasets.blobs(n_variables=6, n_centers=4, n_observations=160) + adata_gpu.obs["blobs"] = adata_gpu.obs["blobs"].astype("category") + adata_cpu = adata_gpu.copy() + adata_gpu.X = _to_format(adata_gpu.X, fmt) + + kw = { + "groupby": "blobs", + "method": "wilcoxon", + "use_raw": False, + "groups": ["0", "2"], + "reference": "rest", + "pts": True, + "n_genes": 6, + } + rsc.tl.rank_genes_groups(adata_gpu, **kw) + sc.tl.rank_genes_groups(adata_cpu, **kw) + + gpu_result = adata_gpu.uns["rank_genes_groups"] + cpu_result = adata_cpu.uns["rank_genes_groups"] + for field in ("scores", "logfoldchanges", "pvals", "pvals_adj"): + for group in gpu_result[field].dtype.names: + np.testing.assert_allclose( + np.asarray(gpu_result[field][group], dtype=float), + np.asarray(cpu_result[field][group], dtype=float), + rtol=1e-13, + atol=1e-15, + equal_nan=True, + ) + + for key in ("pts", "pts_rest"): + gpu_pts = gpu_result[key] + cpu_pts = cpu_result[key] + for col in gpu_pts.columns: + np.testing.assert_allclose( + gpu_pts[col].values, cpu_pts[col].values, rtol=1e-13, atol=1e-15 + ) + + +@pytest.mark.parametrize("reference", ["rest", "1"]) +@pytest.mark.parametrize("fmt", ["scipy_csr", "scipy_csc"]) +def test_wilcoxon_zero_nnz_host_sparse_does_not_crash(reference, fmt): + obs = pd.DataFrame( + { + "group": pd.Categorical( + ["0"] * 4 + ["1"] * 4 + ["2"] * 4, + categories=["0", "1", "2"], + ) + } + ) + adata = sc.AnnData( + X=_to_format(np.zeros((12, 5), dtype=np.float32), fmt), + obs=obs, + var=pd.DataFrame(index=[f"g{i}" for i in range(5)]), + ) + + rsc.tl.rank_genes_groups( + adata, + "group", + method="wilcoxon", + use_raw=False, + reference=reference, + pts=True, + ) + + result = adata.uns["rank_genes_groups"] + for field in ("scores", "pvals"): + for group in result[field].dtype.names: + assert np.all(np.isfinite(np.asarray(result[field][group], dtype=float))) + + +def test_wilcoxon_dense_ovo_chunk_size_matches_unchunked(): + np.random.seed(42) + base = sc.datasets.blobs(n_variables=9, n_centers=3, n_observations=120) + base.obs["blobs"] = base.obs["blobs"].astype("category") + unchunked = base.copy() + chunked = base.copy() + unchunked.X = cp.asarray(unchunked.X) + chunked.X = cp.asarray(chunked.X) + + kw = { + "groupby": "blobs", + "method": "wilcoxon", + "use_raw": False, + "reference": "1", + "tie_correct": True, + "n_genes": 9, + } + rsc.tl.rank_genes_groups(unchunked, **kw) + rsc.tl.rank_genes_groups(chunked, **kw, chunk_size=2) + + for field in ("scores", "pvals", "pvals_adj", "logfoldchanges"): + for group in unchunked.uns["rank_genes_groups"][field].dtype.names: + np.testing.assert_allclose( + np.asarray(unchunked.uns["rank_genes_groups"][field][group], float), + np.asarray(chunked.uns["rank_genes_groups"][field][group], float), + rtol=1e-13, + atol=1e-15, + equal_nan=True, + ) + + @pytest.mark.parametrize( "reference_before,reference_after", [("rest", "rest"), ("1", "One")], @@ -444,185 +574,603 @@ def test_sparse_matches_dense(self, perturbation_adata, sparse): # ============================================================================ -# Tests for ranking and tie correction kernels (edge cases from scipy) +# Matrix format coverage: all dispatch paths must agree # ============================================================================ -class TestRankingKernel: - """Tests for _average_ranks based on scipy.stats.rankdata edge cases.""" +@pytest.mark.parametrize("reference", ["rest", "1"]) +@pytest.mark.parametrize( + "fmt", + [ + pytest.param("scipy_csc", id="scipy_csc"), + pytest.param("cupy_dense", id="cupy_dense"), + pytest.param("cupy_csr", id="cupy_csr"), + pytest.param("cupy_csc", id="cupy_csc"), + ], +) +def test_format_matches_scanpy(reference, fmt): + """Every matrix format matches scanpy output.""" + np.random.seed(42) + adata_gpu = sc.datasets.blobs(n_variables=6, n_centers=3, n_observations=200) + adata_gpu.obs["blobs"] = adata_gpu.obs["blobs"].astype("category") + adata_cpu = adata_gpu.copy() + + adata_gpu.X = _to_format(adata_gpu.X, fmt) + + kw = { + "groupby": "blobs", + "method": "wilcoxon", + "use_raw": False, + "reference": reference, + "tie_correct": True, + } + rsc.tl.rank_genes_groups(adata_gpu, **kw) + sc.tl.rank_genes_groups(adata_cpu, **kw) + + gpu_result = adata_gpu.uns["rank_genes_groups"] + cpu_result = adata_cpu.uns["rank_genes_groups"] + + assert gpu_result["names"].dtype.names == cpu_result["names"].dtype.names + for group in gpu_result["names"].dtype.names: + assert list(gpu_result["names"][group]) == list(cpu_result["names"][group]) + + for field in ("scores", "pvals", "logfoldchanges", "pvals_adj"): + for group in gpu_result[field].dtype.names: + np.testing.assert_allclose( + np.asarray(gpu_result[field][group], dtype=float), + np.asarray(cpu_result[field][group], dtype=float), + rtol=1e-13, + atol=1e-15, + ) + + +# ============================================================================ +# Negative values: centered/scaled data must match scanpy across all formats +# ============================================================================ - @pytest.fixture - def average_ranks(self): - """Import the ranking function.""" - from rapids_singlecell.tools._rank_genes_groups._wilcoxon import ( - _average_ranks, - ) - return _average_ranks - - @staticmethod - def _to_gpu(values): - """Convert 1D values to GPU column matrix with F-order.""" - arr = np.asarray(values, dtype=np.float64).reshape(-1, 1) - return cp.asarray(arr, order="F") - - def test_basic_ranking(self, average_ranks): - """Test basic average ranking on simple data.""" - values = [3.0, 1.0, 2.0] - result = average_ranks(self._to_gpu(values)) - expected = rankdata(values, method="average") - np.testing.assert_allclose(result.get().flatten(), expected) - - def test_all_ties(self, average_ranks): - """All identical values should get the average rank.""" - values = [5.0, 5.0, 5.0, 5.0] - result = average_ranks(self._to_gpu(values)) - expected = rankdata(values, method="average") - np.testing.assert_allclose(result.get().flatten(), expected) - - def test_no_ties(self, average_ranks): - """All unique values should get sequential ranks.""" - values = [1.0, 2.0, 3.0, 4.0, 5.0] - result = average_ranks(self._to_gpu(values)) - expected = rankdata(values, method="average") - np.testing.assert_allclose(result.get().flatten(), expected) - - def test_mixed_ties(self, average_ranks): - """Mix of ties and unique values.""" - values = [1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0] - result = average_ranks(self._to_gpu(values)) - expected = rankdata(values, method="average") - np.testing.assert_allclose(result.get().flatten(), expected) - - def test_negative_values(self, average_ranks): - """Test with negative values.""" - values = [-3.0, -1.0, -2.0, 0.0, 1.0] - result = average_ranks(self._to_gpu(values)) - expected = rankdata(values, method="average") - np.testing.assert_allclose(result.get().flatten(), expected) - - def test_single_element(self, average_ranks): - """Single element should have rank 1.""" - values = [42.0] - result = average_ranks(self._to_gpu(values)) - np.testing.assert_allclose(result.get().flatten(), [1.0]) - - def test_two_elements_tied(self, average_ranks): - """Two tied elements should both have rank 1.5.""" - values = [7.0, 7.0] - result = average_ranks(self._to_gpu(values)) - np.testing.assert_allclose(result.get().flatten(), [1.5, 1.5]) - - def test_multiple_columns(self, average_ranks): - """Test ranking across multiple columns independently.""" - col0 = [3.0, 1.0, 2.0] - col1 = [1.0, 1.0, 2.0] - data = np.column_stack([col0, col1]).astype(np.float64) - result = average_ranks(cp.asarray(data, order="F")) - - np.testing.assert_allclose(result.get()[:, 0], rankdata(col0, method="average")) - np.testing.assert_allclose(result.get()[:, 1], rankdata(col1, method="average")) - - -class TestTieCorrectionKernel: - """Tests for _tie_correction based on scipy.stats.tiecorrect edge cases.""" - - @pytest.fixture - def tie_correction(self): - """Import the tie correction function and ranking function.""" - from rapids_singlecell.tools._rank_genes_groups._wilcoxon import ( - _average_ranks, - _tie_correction, +def _make_centered_adata(n_obs=200, n_vars=8, n_centers=3, seed=42): + """Create AnnData with centered (mean-zero) data containing negatives.""" + np.random.seed(seed) + adata = sc.datasets.blobs( + n_variables=n_vars, n_centers=n_centers, n_observations=n_obs + ) + adata.obs["blobs"] = adata.obs["blobs"].astype("category") + # Center each gene to produce negative values + adata.X = adata.X - adata.X.mean(axis=0) + return adata + + +@pytest.mark.parametrize("reference", ["rest", "1"]) +@pytest.mark.parametrize( + "fmt", + [ + pytest.param("scipy_csc", id="scipy_csc"), + pytest.param("cupy_dense", id="cupy_dense"), + pytest.param("cupy_csr", id="cupy_csr"), + pytest.param("cupy_csc", id="cupy_csc"), + ], +) +def test_negative_values_match_scanpy(reference, fmt): + """Centered data (with negatives) matches scanpy across all formats.""" + adata_gpu = _make_centered_adata() + adata_cpu = adata_gpu.copy() + + # Verify data actually has negatives + assert adata_gpu.X.min() < 0 + + adata_gpu.X = _to_format(adata_gpu.X, fmt) + + kw = { + "groupby": "blobs", + "method": "wilcoxon", + "use_raw": False, + "reference": reference, + "tie_correct": True, + } + rsc.tl.rank_genes_groups(adata_gpu, **kw) + sc.tl.rank_genes_groups(adata_cpu, **kw) + + gpu_result = adata_gpu.uns["rank_genes_groups"] + cpu_result = adata_cpu.uns["rank_genes_groups"] + + for group in gpu_result["names"].dtype.names: + assert list(gpu_result["names"][group]) == list(cpu_result["names"][group]) + + for field in ("scores", "pvals", "pvals_adj"): + for group in gpu_result[field].dtype.names: + np.testing.assert_allclose( + np.asarray(gpu_result[field][group], dtype=float), + np.asarray(cpu_result[field][group], dtype=float), + rtol=1e-13, + atol=1e-15, + ) + + +@pytest.mark.parametrize("reference", ["rest", "1"]) +def test_negative_sparse_matches_dense(reference): + """Sparse and dense paths give identical results for centered data.""" + adata_dense = _make_centered_adata() + adata_csr = adata_dense.copy() + adata_csc = adata_dense.copy() + + adata_csr.X = cpsp.csr_matrix(cp.asarray(adata_dense.X)) + adata_csc.X = cpsp.csc_matrix(cp.asarray(adata_dense.X)) + + kw = { + "groupby": "blobs", + "method": "wilcoxon", + "use_raw": False, + "reference": reference, + "tie_correct": True, + } + rsc.tl.rank_genes_groups(adata_dense, **kw) + rsc.tl.rank_genes_groups(adata_csr, **kw) + rsc.tl.rank_genes_groups(adata_csc, **kw) + + dense_result = adata_dense.uns["rank_genes_groups"] + csr_result = adata_csr.uns["rank_genes_groups"] + csc_result = adata_csc.uns["rank_genes_groups"] + + for field in ("scores", "pvals"): + for group in dense_result[field].dtype.names: + dense_vals = np.asarray(dense_result[field][group], dtype=float) + csr_vals = np.asarray(csr_result[field][group], dtype=float) + csc_vals = np.asarray(csc_result[field][group], dtype=float) + np.testing.assert_allclose(csr_vals, dense_vals, rtol=1e-13, atol=1e-15) + np.testing.assert_allclose(csc_vals, dense_vals, rtol=1e-13, atol=1e-15) + + +# ============================================================================ +# pre_load: GPU transfer before wilcoxon must match default (lazy transfer) +# ============================================================================ + + +@pytest.mark.parametrize("reference", ["rest", "1"]) +def test_pre_load_matches_scanpy(reference): + """pre_load=True matches scanpy output.""" + np.random.seed(42) + adata_gpu = sc.datasets.blobs(n_variables=6, n_centers=3, n_observations=200) + adata_gpu.obs["blobs"] = adata_gpu.obs["blobs"].astype("category") + adata_cpu = adata_gpu.copy() + + kw = { + "groupby": "blobs", + "method": "wilcoxon", + "use_raw": False, + "reference": reference, + "tie_correct": True, + } + rsc.tl.rank_genes_groups(adata_gpu, **kw, pre_load=True) + sc.tl.rank_genes_groups(adata_cpu, **kw) + + gpu_result = adata_gpu.uns["rank_genes_groups"] + cpu_result = adata_cpu.uns["rank_genes_groups"] + + assert gpu_result["names"].dtype.names == cpu_result["names"].dtype.names + for group in gpu_result["names"].dtype.names: + assert list(gpu_result["names"][group]) == list(cpu_result["names"][group]) + + for field in ("scores", "pvals", "logfoldchanges", "pvals_adj"): + for group in gpu_result[field].dtype.names: + np.testing.assert_allclose( + np.asarray(gpu_result[field][group], dtype=float), + np.asarray(cpu_result[field][group], dtype=float), + rtol=1e-13, + atol=1e-15, + ) + + +# ============================================================================ +# use_continuity with reference="rest" (OVR mode) +# ============================================================================ + + +def test_use_continuity_vs_rest_changes_scores(): + """use_continuity with reference='rest' adjusts z-scores toward zero.""" + np.random.seed(42) + adata_no = sc.datasets.blobs(n_variables=6, n_centers=3, n_observations=200) + adata_no.obs["blobs"] = adata_no.obs["blobs"].astype("category") + adata_yes = adata_no.copy() + + kw = { + "groupby": "blobs", + "method": "wilcoxon", + "use_raw": False, + "reference": "rest", + "tie_correct": True, + } + rsc.tl.rank_genes_groups(adata_no, **kw, use_continuity=False) + rsc.tl.rank_genes_groups(adata_yes, **kw, use_continuity=True) + + for group in adata_no.uns["rank_genes_groups"]["scores"].dtype.names: + scores_no = np.asarray( + adata_no.uns["rank_genes_groups"]["scores"][group], dtype=float + ) + scores_yes = np.asarray( + adata_yes.uns["rank_genes_groups"]["scores"][group], dtype=float + ) + # Continuity correction moves z-scores toward zero + assert np.all(np.abs(scores_yes) <= np.abs(scores_no) + 1e-15), ( + f"Group {group}: continuity-corrected |z| should be <= uncorrected |z|" + ) + # p-values should be valid + pvals = np.asarray( + adata_yes.uns["rank_genes_groups"]["pvals"][group], dtype=float ) + assert np.all(pvals >= 0) + assert np.all(pvals <= 1) - return _tie_correction, _average_ranks - @staticmethod - def _to_gpu(values): - """Convert 1D values to GPU column matrix with F-order.""" - arr = np.asarray(values, dtype=np.float64).reshape(-1, 1) - return cp.asarray(arr, order="F") +# ============================================================================ +# mask_var: gene subsetting +# ============================================================================ - def test_no_ties(self, tie_correction): - """No ties should give correction factor 1.0.""" - _tie_correction, _average_ranks = tie_correction - values = [1.0, 2.0, 3.0, 4.0, 5.0] - _, sorted_vals = _average_ranks(self._to_gpu(values), return_sorted=True) - result = _tie_correction(sorted_vals) +@pytest.mark.parametrize("reference", ["rest", "1"]) +def test_mask_var_matches_scanpy(reference): + """mask_var restricts genes and matches scanpy output.""" + np.random.seed(42) + adata_gpu = sc.datasets.blobs(n_variables=10, n_centers=3, n_observations=200) + adata_gpu.obs["blobs"] = adata_gpu.obs["blobs"].astype("category") + adata_cpu = adata_gpu.copy() - expected = tiecorrect(rankdata(values)) - np.testing.assert_allclose(result.get()[0], expected, rtol=1e-10) + mask = np.zeros(adata_gpu.n_vars, dtype=bool) + mask[:5] = True - def test_all_ties(self, tie_correction): - """All tied values should give correction factor 0.0.""" - _tie_correction, _average_ranks = tie_correction + kw = { + "groupby": "blobs", + "method": "wilcoxon", + "use_raw": False, + "mask_var": mask, + "reference": reference, + } + rsc.tl.rank_genes_groups(adata_gpu, **kw) + sc.tl.rank_genes_groups(adata_cpu, **kw) - values = [5.0, 5.0, 5.0, 5.0] - _, sorted_vals = _average_ranks(self._to_gpu(values), return_sorted=True) - result = _tie_correction(sorted_vals) + gpu_result = adata_gpu.uns["rank_genes_groups"] + cpu_result = adata_cpu.uns["rank_genes_groups"] - expected = tiecorrect(rankdata(values)) - np.testing.assert_allclose(result.get()[0], expected, rtol=1e-10) + for group in gpu_result["names"].dtype.names: + assert list(gpu_result["names"][group]) == list(cpu_result["names"][group]) + names = set(gpu_result["names"][group]) + expected_genes = set(adata_gpu.var_names[:5]) + assert names <= expected_genes - def test_mixed_ties(self, tie_correction): - """Mix of ties should give intermediate correction factor.""" - _tie_correction, _average_ranks = tie_correction + for field in ("scores", "pvals", "logfoldchanges", "pvals_adj"): + for group in gpu_result[field].dtype.names: + np.testing.assert_allclose( + np.asarray(gpu_result[field][group], dtype=float), + np.asarray(cpu_result[field][group], dtype=float), + rtol=1e-13, + atol=1e-15, + ) - values = [1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0] - _, sorted_vals = _average_ranks(self._to_gpu(values), return_sorted=True) - result = _tie_correction(sorted_vals) - expected = tiecorrect(rankdata(values)) - np.testing.assert_allclose(result.get()[0], expected, rtol=1e-10) +def test_mask_var_string_key(): + """mask_var accepts a string key from adata.var.""" + np.random.seed(42) + adata_gpu = sc.datasets.blobs(n_variables=10, n_centers=3, n_observations=200) + adata_gpu.obs["blobs"] = adata_gpu.obs["blobs"].astype("category") + adata_cpu = adata_gpu.copy() - def test_two_elements_tied(self, tie_correction): - """Two tied elements.""" - _tie_correction, _average_ranks = tie_correction + adata_gpu.var["highly_variable"] = [True] * 5 + [False] * 5 + adata_cpu.var["highly_variable"] = [True] * 5 + [False] * 5 - values = [7.0, 7.0] - _, sorted_vals = _average_ranks(self._to_gpu(values), return_sorted=True) - result = _tie_correction(sorted_vals) + kw = { + "groupby": "blobs", + "method": "wilcoxon", + "use_raw": False, + "mask_var": "highly_variable", + } + rsc.tl.rank_genes_groups(adata_gpu, **kw) + sc.tl.rank_genes_groups(adata_cpu, **kw) - expected = tiecorrect(rankdata(values)) - np.testing.assert_allclose(result.get()[0], expected, rtol=1e-10) + gpu_result = adata_gpu.uns["rank_genes_groups"] + cpu_result = adata_cpu.uns["rank_genes_groups"] - def test_single_element(self, tie_correction): - """Single element should give correction factor 1.0.""" - _tie_correction, _average_ranks = tie_correction + for group in gpu_result["names"].dtype.names: + assert len(gpu_result["names"][group]) == 5 + assert list(gpu_result["names"][group]) == list(cpu_result["names"][group]) - values = [42.0] - _, sorted_vals = _average_ranks(self._to_gpu(values), return_sorted=True) - result = _tie_correction(sorted_vals) + for field in ("scores", "pvals"): + for group in gpu_result[field].dtype.names: + np.testing.assert_allclose( + np.asarray(gpu_result[field][group], dtype=float), + np.asarray(cpu_result[field][group], dtype=float), + rtol=1e-13, + atol=1e-15, + ) - # Single element: n^3 - n = 0, so formula gives 1.0 - np.testing.assert_allclose(result.get()[0], 1.0, rtol=1e-10) - def test_multiple_columns(self, tie_correction): - """Test tie correction across multiple columns independently.""" - _tie_correction, _average_ranks = tie_correction +# ============================================================================ +# key_added: custom output key +# ============================================================================ - col0 = [1.0, 2.0, 3.0] # No ties - col1 = [5.0, 5.0, 5.0] # All ties - data = np.column_stack([col0, col1]).astype(np.float64) - _, sorted_vals = _average_ranks(cp.asarray(data, order="F"), return_sorted=True) - result = _tie_correction(sorted_vals) - np.testing.assert_allclose( - result.get()[0], tiecorrect(rankdata(col0)), rtol=1e-10 +def test_key_added_matches_scanpy(): + """key_added stores results under a custom key, matching scanpy.""" + np.random.seed(42) + adata_gpu = sc.datasets.blobs(n_variables=6, n_centers=3, n_observations=200) + adata_gpu.obs["blobs"] = adata_gpu.obs["blobs"].astype("category") + adata_cpu = adata_gpu.copy() + + rsc.tl.rank_genes_groups( + adata_gpu, "blobs", method="wilcoxon", use_raw=False, key_added="my_de" + ) + sc.tl.rank_genes_groups( + adata_cpu, "blobs", method="wilcoxon", use_raw=False, key_added="my_de" + ) + + assert "my_de" in adata_gpu.uns + assert "rank_genes_groups" not in adata_gpu.uns + + gpu_result = adata_gpu.uns["my_de"] + cpu_result = adata_cpu.uns["my_de"] + + for field in ("scores", "pvals", "logfoldchanges", "pvals_adj"): + for group in gpu_result[field].dtype.names: + np.testing.assert_allclose( + np.asarray(gpu_result[field][group], dtype=float), + np.asarray(cpu_result[field][group], dtype=float), + rtol=1e-13, + atol=1e-15, + ) + + +# ============================================================================ +# rankby_abs: ranking by absolute score +# ============================================================================ + + +def test_rankby_abs_matches_scanpy(): + """rankby_abs ranks genes by |score| and matches scanpy.""" + np.random.seed(42) + adata_gpu = sc.datasets.blobs(n_variables=6, n_centers=3, n_observations=200) + adata_gpu.obs["blobs"] = adata_gpu.obs["blobs"].astype("category") + adata_cpu = adata_gpu.copy() + + kw = { + "groupby": "blobs", + "method": "wilcoxon", + "use_raw": False, + "n_genes": 3, + "rankby_abs": True, + } + rsc.tl.rank_genes_groups(adata_gpu, **kw) + sc.tl.rank_genes_groups(adata_cpu, **kw) + + gpu_result = adata_gpu.uns["rank_genes_groups"] + cpu_result = adata_cpu.uns["rank_genes_groups"] + + for group in gpu_result["names"].dtype.names: + assert list(gpu_result["names"][group]) == list(cpu_result["names"][group]) + + # Scores should be sorted by absolute value (descending) + abs_scores = np.abs(np.asarray(gpu_result["scores"][group], dtype=float)) + assert np.all(abs_scores[:-1] >= abs_scores[1:]), ( + f"Group {group}: abs scores not sorted descending" ) - np.testing.assert_allclose( - result.get()[1], tiecorrect(rankdata(col1)), rtol=1e-10 + + for field in ("scores", "pvals", "logfoldchanges", "pvals_adj"): + for group in gpu_result[field].dtype.names: + np.testing.assert_allclose( + np.asarray(gpu_result[field][group], dtype=float), + np.asarray(cpu_result[field][group], dtype=float), + rtol=1e-13, + atol=1e-15, + ) + + +# ============================================================================ +# Small group warning +# ============================================================================ + + +@pytest.mark.parametrize("reference", ["rest", "1"]) +def test_small_group_warning(reference): + """Groups with <=25 cells trigger a RuntimeWarning.""" + np.random.seed(42) + n_large = 100 + n_small = 20 # below MIN_GROUP_SIZE_WARNING = 25 + n_cells = n_large + n_small + n_large + + adata = sc.AnnData( + X=np.random.randn(n_cells, 5).astype(np.float32), + obs=pd.DataFrame( + { + "group": pd.Categorical( + ["0"] * n_large + ["1"] * n_small + ["2"] * n_large, + categories=["0", "1", "2"], + ), + } + ), + ) + + with pytest.warns(RuntimeWarning, match="normal approximation"): + rsc.tl.rank_genes_groups( + adata, + "group", + method="wilcoxon", + use_raw=False, + reference=reference, + ) + + +# ============================================================================ +# Singlet group rejection +# ============================================================================ + + +def test_singlet_group_raises(): + """A group with only 1 cell raises ValueError.""" + np.random.seed(42) + adata = sc.AnnData( + X=np.random.randn(101, 5).astype(np.float32), + obs=pd.DataFrame( + { + "group": pd.Categorical( + ["big"] * 100 + ["tiny"], + categories=["big", "tiny"], + ), + } + ), + ) + + with pytest.raises(ValueError, match="only contain one sample"): + rsc.tl.rank_genes_groups(adata, "group", method="wilcoxon", use_raw=False) + + +# ============================================================================ +# Invalid reference raises +# ============================================================================ + + +def test_invalid_reference_raises(): + """A reference not in the categories raises ValueError.""" + np.random.seed(42) + adata = sc.datasets.blobs(n_variables=5, n_centers=3, n_observations=100) + adata.obs["blobs"] = adata.obs["blobs"].astype("category") + + with pytest.raises(ValueError, match="reference = nonexistent"): + rsc.tl.rank_genes_groups( + adata, + "blobs", + method="wilcoxon", + use_raw=False, + reference="nonexistent", + ) + + +# ============================================================================ +# String group parameter raises +# ============================================================================ + + +def test_string_groups_raises(): + """Passing a bare string as groups raises ValueError.""" + np.random.seed(42) + adata = sc.datasets.blobs(n_variables=5, n_centers=3, n_observations=100) + adata.obs["blobs"] = adata.obs["blobs"].astype("category") + + with pytest.raises(ValueError, match="Specify a sequence"): + rsc.tl.rank_genes_groups( + adata, + "blobs", + method="wilcoxon", + use_raw=False, + groups="0", ) - def test_large_tie_groups(self, tie_correction): - """Test with large tie groups.""" - _tie_correction, _average_ranks = tie_correction - # 50 values of 1, 50 values of 2 (non-multiple of 32 to test warp handling) - values = [1.0] * 50 + [2.0] * 50 - _, sorted_vals = _average_ranks(self._to_gpu(values), return_sorted=True) - result = _tie_correction(sorted_vals) +# ============================================================================ +# Many groups with reference: 5+ groups, one vs reference +# ============================================================================ + - expected = tiecorrect(rankdata(values)) - np.testing.assert_allclose(result.get()[0], expected, rtol=1e-10) +def test_many_groups_with_reference(): + """Wilcoxon with many test groups vs a reference produces correct output.""" + np.random.seed(42) + adata = sc.datasets.blobs(n_variables=6, n_centers=5, n_observations=300) + adata.obs["blobs"] = adata.obs["blobs"].astype("category") + adata_cpu = adata.copy() + + kw = { + "groupby": "blobs", + "method": "wilcoxon", + "use_raw": False, + "reference": "0", + "tie_correct": True, + } + rsc.tl.rank_genes_groups(adata, **kw) + sc.tl.rank_genes_groups(adata_cpu, **kw) + + result = adata.uns["rank_genes_groups"] + assert "0" not in result["names"].dtype.names + expected = {"1", "2", "3", "4"} + assert set(result["names"].dtype.names) == expected + + for field in ("scores", "pvals"): + for group in result[field].dtype.names: + np.testing.assert_allclose( + np.asarray(result[field][group], dtype=float), + np.asarray( + adata_cpu.uns["rank_genes_groups"][field][group], dtype=float + ), + rtol=1e-13, + atol=1e-15, + ) + + +# ============================================================================ +# Group subsetting with unselected cells (OVR): unselected cells in "rest" +# ============================================================================ + + +def test_group_subset_vs_rest_unselected_cells(): + """With groups subset and reference='rest', unselected cells go into rest.""" + np.random.seed(42) + adata = sc.datasets.blobs(n_variables=6, n_centers=4, n_observations=200) + adata.obs["blobs"] = adata.obs["blobs"].astype("category") + adata_cpu = adata.copy() + + kw = { + "groupby": "blobs", + "method": "wilcoxon", + "use_raw": False, + "groups": ["0", "2"], + "reference": "rest", + "tie_correct": True, + } + rsc.tl.rank_genes_groups(adata, **kw) + sc.tl.rank_genes_groups(adata_cpu, **kw) + + for field in ("scores", "pvals"): + for group in adata.uns["rank_genes_groups"][field].dtype.names: + np.testing.assert_allclose( + np.asarray(adata.uns["rank_genes_groups"][field][group], dtype=float), + np.asarray( + adata_cpu.uns["rank_genes_groups"][field][group], dtype=float + ), + rtol=1e-13, + atol=1e-15, + ) + + +# ============================================================================ +# Group subsetting with reference: unselected cells excluded from pairwise +# ============================================================================ + + +def test_group_subset_with_reference_unselected_cells(): + """With groups subset and a reference, unselected cells are excluded.""" + np.random.seed(42) + adata = sc.datasets.blobs(n_variables=6, n_centers=4, n_observations=200) + adata.obs["blobs"] = adata.obs["blobs"].astype("category") + adata_cpu = adata.copy() + + kw = { + "groupby": "blobs", + "method": "wilcoxon", + "use_raw": False, + "groups": ["0", "1", "2"], + "reference": "1", + "tie_correct": True, + } + rsc.tl.rank_genes_groups(adata, **kw) + sc.tl.rank_genes_groups(adata_cpu, **kw) + + result = adata.uns["rank_genes_groups"] + assert "1" not in result["names"].dtype.names + assert set(result["names"].dtype.names) == {"0", "2"} + + for field in ("scores", "pvals"): + for group in result[field].dtype.names: + np.testing.assert_allclose( + np.asarray(result[field][group], dtype=float), + np.asarray( + adata_cpu.uns["rank_genes_groups"][field][group], dtype=float + ), + rtol=1e-13, + atol=1e-15, + ) diff --git a/tests/test_rank_genes_groups_wilcoxon_binned.py b/tests/test_rank_genes_groups_wilcoxon_binned.py index 96af0711..1d2e469b 100644 --- a/tests/test_rank_genes_groups_wilcoxon_binned.py +++ b/tests/test_rank_genes_groups_wilcoxon_binned.py @@ -524,3 +524,220 @@ def test_top_genes_match_scipy(adata_blobs): scipy_top = set(adata_blobs.var_names[np.argsort(pvals)[:n_top]]) overlap = len(binned_top & scipy_top) assert overlap >= n_top - 1, f"Group {group}: {overlap}/{n_top} overlap" + + +# ============================================================================ +# tie_correct and use_continuity coverage +# ============================================================================ + + +class TestWilcoxonBinnedCorrections: + """Tests for tie_correct and use_continuity branches.""" + + def test_tie_correct_changes_scores(self, adata_blobs): + """tie_correct=True should produce different scores than False on tied data.""" + # Create data with heavy ties (integer counts) + rng = np.random.default_rng(42) + adata = adata_blobs.copy() + adata.X = rng.poisson(lam=3.0, size=adata.X.shape).astype(np.float32) + rsc.get.anndata_to_GPU(adata) + adata_tc = adata.copy() + + rsc.tl.rank_genes_groups( + adata, + "blobs", + method="wilcoxon_binned", + use_raw=False, + tie_correct=False, + ) + rsc.tl.rank_genes_groups( + adata_tc, + "blobs", + method="wilcoxon_binned", + use_raw=False, + tie_correct=True, + ) + + # Scores should differ (tie correction adjusts variance) + for group in adata.uns["rank_genes_groups"]["scores"].dtype.names: + scores_no = np.asarray( + adata.uns["rank_genes_groups"]["scores"][group], dtype=float + ) + scores_tc = np.asarray( + adata_tc.uns["rank_genes_groups"]["scores"][group], dtype=float + ) + assert not np.allclose(scores_no, scores_tc, rtol=1e-10), ( + f"Group {group}: tie_correct had no effect on scores" + ) + + def test_use_continuity_changes_scores(self, adata_blobs): + """use_continuity=True should produce different scores than False.""" + adata = adata_blobs.copy() + rsc.get.anndata_to_GPU(adata) + adata_cont = adata.copy() + + rsc.tl.rank_genes_groups( + adata, + "blobs", + method="wilcoxon_binned", + use_raw=False, + use_continuity=False, + ) + rsc.tl.rank_genes_groups( + adata_cont, + "blobs", + method="wilcoxon_binned", + use_raw=False, + use_continuity=True, + ) + + for group in adata.uns["rank_genes_groups"]["scores"].dtype.names: + scores_no = np.asarray( + adata.uns["rank_genes_groups"]["scores"][group], dtype=float + ) + scores_cont = np.asarray( + adata_cont.uns["rank_genes_groups"]["scores"][group], dtype=float + ) + assert not np.allclose(scores_no, scores_cont, rtol=1e-10), ( + f"Group {group}: use_continuity had no effect on scores" + ) + + @pytest.mark.parametrize("reference", ["rest", "1"]) + def test_tie_correct_with_reference(self, adata_blobs, reference): + """tie_correct works for both OVR and OVO in binned wilcoxon.""" + adata = adata_blobs.copy() + rsc.get.anndata_to_GPU(adata) + + rsc.tl.rank_genes_groups( + adata, + "blobs", + method="wilcoxon_binned", + use_raw=False, + reference=reference, + tie_correct=True, + ) + + result = adata.uns["rank_genes_groups"] + for group in result["pvals"].dtype.names: + pvals = np.asarray(result["pvals"][group], dtype=float) + assert np.all(pvals >= 0) + assert np.all(pvals <= 1) + assert np.all(np.isfinite(pvals)) + + @pytest.mark.parametrize("reference", ["rest", "1"]) + def test_use_continuity_with_reference(self, adata_blobs, reference): + """use_continuity works for both OVR and OVO in binned wilcoxon.""" + adata = adata_blobs.copy() + rsc.get.anndata_to_GPU(adata) + + rsc.tl.rank_genes_groups( + adata, + "blobs", + method="wilcoxon_binned", + use_raw=False, + reference=reference, + use_continuity=True, + ) + + result = adata.uns["rank_genes_groups"] + for group in result["pvals"].dtype.names: + pvals = np.asarray(result["pvals"][group], dtype=float) + assert np.all(pvals >= 0) + assert np.all(pvals <= 1) + assert np.all(np.isfinite(pvals)) + + def test_both_corrections_combined(self, adata_blobs): + """tie_correct=True + use_continuity=True together.""" + adata = adata_blobs.copy() + rsc.get.anndata_to_GPU(adata) + + rsc.tl.rank_genes_groups( + adata, + "blobs", + method="wilcoxon_binned", + use_raw=False, + tie_correct=True, + use_continuity=True, + ) + + result = adata.uns["rank_genes_groups"] + for group in result["pvals"].dtype.names: + pvals = np.asarray(result["pvals"][group], dtype=float) + assert np.all(pvals >= 0) + assert np.all(pvals <= 1) + + def test_both_corrections_with_reference(self, adata_blobs): + """tie_correct + use_continuity with reference mode.""" + adata = adata_blobs.copy() + rsc.get.anndata_to_GPU(adata) + + rsc.tl.rank_genes_groups( + adata, + "blobs", + method="wilcoxon_binned", + use_raw=False, + reference="1", + tie_correct=True, + use_continuity=True, + ) + + result = adata.uns["rank_genes_groups"] + assert "1" not in result["names"].dtype.names + for group in result["pvals"].dtype.names: + pvals = np.asarray(result["pvals"][group], dtype=float) + assert np.all(pvals >= 0) + assert np.all(pvals <= 1) + + +# ============================================================================ +# pts (percent expressing) for binned wilcoxon +# ============================================================================ + + +@pytest.mark.parametrize("reference", ["rest", "1"]) +def test_binned_pts(adata_blobs, reference): + """pts computation works with wilcoxon_binned.""" + adata = adata_blobs.copy() + rsc.get.anndata_to_GPU(adata) + + rsc.tl.rank_genes_groups( + adata, + "blobs", + method="wilcoxon_binned", + use_raw=False, + pts=True, + reference=reference, + ) + + result = adata.uns["rank_genes_groups"] + assert "pts" in result + pts = result["pts"] + assert isinstance(pts, __import__("pandas").DataFrame) + assert all(0 <= v <= 1 for col in pts.columns for v in pts[col]) + + if reference == "rest": + assert "pts_rest" in result + + +# ============================================================================ +# mask_var with string key for binned +# ============================================================================ + + +def test_binned_mask_var_string_key(adata_blobs): + """mask_var accepts a string key from adata.var for binned.""" + adata = adata_blobs.copy() + adata.var["selected"] = [True] * 5 + [False] * 5 + rsc.get.anndata_to_GPU(adata) + + rsc.tl.rank_genes_groups( + adata, + "blobs", + method="wilcoxon_binned", + use_raw=False, + mask_var="selected", + ) + + result = adata.uns["rank_genes_groups"] + for group in result["names"].dtype.names: + assert len(result["names"][group]) == 5