From feb3430263b4d3389ef49dd9faa44f31a6f3765e Mon Sep 17 00:00:00 2001 From: R script <1695515+ms609@users.noreply.github.com> Date: Thu, 16 Apr 2026 19:11:20 +0100 Subject: [PATCH 01/11] Merge PR193 and large-tree-migration tip-limit hardening - add split_int alias and migrate serial/batch split counters to split_int/int32 - harden check_ntip with runtime Rcpp::stop() and informative compiled-limit message - add overflow-safe int64 arithmetic in add_ic_element - add serial interrupt checks every 1024 iterations in all seven serial scorers - centralize R-side tip limits via .SL_MAX_TIPS + .CheckMaxTips loaded from cpp_max_tips() - replace hardcoded 32767L guards across distance, transfer, NNI, MAST and consensus paths - add C++ n_tip guards to all batch entry points - add large-tree functional/guard tests and update existing tip-limit tests Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- R/RcppExports.R | 4 + R/transfer_consensus.R | 4 +- R/tree_distance.R | 4 +- R/tree_distance_mast.R | 5 +- R/tree_distance_nni.R | 4 +- R/tree_distance_transfer.R | 4 +- R/tree_distance_utilities.R | 31 ++- R/tree_information.R | 7 +- R/zzz.R | 4 + inst/include/TreeDist/mutual_clustering.h | 10 +- .../include/TreeDist/mutual_clustering_impl.h | 112 ++++---- inst/include/TreeDist/types.h | 3 + src/RcppExports.cpp | 11 + src/ints.h | 1 + src/pairwise_distances.cpp | 262 +++++++++--------- src/tree_distance_functions.cpp | 17 +- src/tree_distances.cpp | 172 ++++++------ src/tree_distances.h | 51 ++-- tests/testthat/test-large-trees.R | 90 ++++++ tests/testthat/test-split_info.R | 4 +- tests/testthat/test-tree_distance_nni.R | 16 +- tests/testthat/test-tree_distance_utilities.R | 26 +- 22 files changed, 505 insertions(+), 337 deletions(-) create mode 100644 tests/testthat/test-large-trees.R diff --git a/R/RcppExports.R b/R/RcppExports.R index 7cd2fcd0..f7aa3d6d 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -275,6 +275,10 @@ cpp_mci_impl_score <- function(x, y, n_tips) { .Call(`_TreeDist_cpp_mci_impl_score`, x, y, n_tips) } +cpp_max_tips <- function() { + .Call(`_TreeDist_cpp_max_tips`) +} + cpp_robinson_foulds_distance <- function(x, y, nTip) { .Call(`_TreeDist_cpp_robinson_foulds_distance`, x, y, nTip) } diff --git a/R/transfer_consensus.R b/R/transfer_consensus.R index 3410cf09..6cef4490 100644 --- a/R/transfer_consensus.R +++ b/R/transfer_consensus.R @@ -66,7 +66,7 @@ TransferConsensus <- function(trees, if (nTip < 4L) { return(StarTree(tipLabels)) } - if (nTip > 32767L) stop("This many tips are not (yet) supported.") + .CheckMaxTips(nTip) # Convert each tree to a raw split matrix (TreeTools C++ internally). # as.Splits() will error if a tree's tips don't match tipLabels. @@ -115,7 +115,7 @@ tc_profile <- function(trees, scale = TRUE, greedy = "best", tipLabels <- TipLabels(trees[[1]]) nTip <- length(tipLabels) if (nTip < 4L) stop("Need at least 4 tips for profiling.") - if (nTip > 32767L) stop("This many tips are not (yet) supported.") + .CheckMaxTips(nTip) splitsList <- lapply(trees, function(tr) unclass(as.Splits(tr, tipLabels))) diff --git a/R/tree_distance.R b/R/tree_distance.R index 11bf9e91..5ea3a9cb 100644 --- a/R/tree_distance.R +++ b/R/tree_distance.R @@ -149,7 +149,7 @@ GeneralizedRF <- function(splits1, splits2, nTip, PairScorer, } nTip <- length(tipLabels) if (nTip < 4) return(NULL) # nocov - if (nTip > 32767L) stop("This many tips are not (yet) supported.") + .CheckMaxTips(nTip) splits_list <- as.Splits(tree1, tipLabels = tipLabels) n_threads <- as.integer(getOption("mc.cores", 1L)) @@ -203,7 +203,7 @@ GeneralizedRF <- function(splits1, splits2, nTip, PairScorer, nTip <- length(tipLabels1) if (nTip < 4) return(NULL) - if (nTip > 32767L) stop("This many tips are not (yet) supported.") + .CheckMaxTips(nTip) splits1 <- as.Splits(tree1, tipLabels = tipLabels1) splits2 <- as.Splits(tree2, tipLabels = tipLabels1) # Use tipLabels1 to ensure order consistency diff --git a/R/tree_distance_mast.R b/R/tree_distance_mast.R index 5c069841..beee13d1 100644 --- a/R/tree_distance_mast.R +++ b/R/tree_distance_mast.R @@ -96,8 +96,9 @@ MASTSize <- function(tree1, tree2 = tree1, rooted = TRUE) { if (nrow(edge1) != nrow(edge2)) { stop("Both trees must contain the same number of edges.") } - if (nTip > 4096L) { - stop("Tree too large; please contact maintainer for advice.") + maxTips <- min(4096L, if (is.null(.SL_MAX_TIPS)) cpp_max_tips() else .SL_MAX_TIPS) + if (nTip > maxTips) { + stop("Trees with > ", maxTips, " tips are not yet supported for MAST.") } cpp_mast(edge1 - 1L, Postorder(edge2) - 1L, nTip) } diff --git a/R/tree_distance_nni.R b/R/tree_distance_nni.R index 45854764..2d617a13 100644 --- a/R/tree_distance_nni.R +++ b/R/tree_distance_nni.R @@ -73,9 +73,7 @@ NNIDist <- function(tree1, tree2 = tree1) { #' @importFrom TreeTools Postorder RenumberTips #' @importFrom ape Nnode.phylo .NNIDistSingle <- function(tree1, tree2, nTip, ...) { - if (nTip > 32768L) { - stop("Cannot calculate NNI distance for trees with so many tips.") - } + .CheckMaxTips(nTip, "NNI") if (nrow(tree1[["edge"]]) != nrow(tree2[["edge"]])) { stop("Both trees must have the same number of edges. ", "Is one rooted and the other unrooted?") diff --git a/R/tree_distance_transfer.R b/R/tree_distance_transfer.R index cd0942fd..55dce2e2 100644 --- a/R/tree_distance_transfer.R +++ b/R/tree_distance_transfer.R @@ -168,7 +168,7 @@ TransferDistSplits <- function(splits1, splits2, if (is.null(tipLabels)) return(NULL) nTip <- length(tipLabels) if (nTip < 4L) return(NULL) - if (nTip > 32767L) stop("This many tips are not (yet) supported.") + .CheckMaxTips(nTip) # Check all trees share same tip set allLabels <- TipLabels(tree1) @@ -211,7 +211,7 @@ TransferDistSplits <- function(splits1, splits2, if (is.null(tipLabels)) return(NULL) nTip <- length(tipLabels) if (nTip < 4L) return(NULL) - if (nTip > 32767L) stop("This many tips are not (yet) supported.") + .CheckMaxTips(nTip) # Check all trees share same tip set allLabels1 <- TipLabels(trees1) diff --git a/R/tree_distance_utilities.R b/R/tree_distance_utilities.R index 01b0ddca..1bae051e 100644 --- a/R/tree_distance_utilities.R +++ b/R/tree_distance_utilities.R @@ -11,15 +11,28 @@ #' @importFrom TreeTools as.Splits TipLabels #' @importFrom utils combn #' @export -# Keep in sync with C++ guard: min(SL_MAX_TIPS, int16_t::max()). -.MaxSupportedTips <- 32767L - -.AssertNtipSupported <- function(nTip) { - if (!is.na(nTip) && nTip > .MaxSupportedTips) { - stop("This many tips are not (yet) supported.") +# Maximum number of tips supported by this compiled package build. +# Set during .onLoad() from `cpp_max_tips()`. +.SL_MAX_TIPS <- NULL + +.CheckMaxTips <- function(nTip, context = "") { + if (!is.na(nTip)) { + maxTips <- .SL_MAX_TIPS + if (is.null(maxTips) || is.na(maxTips)) { + maxTips <- cpp_max_tips() + .SL_MAX_TIPS <<- maxTips + } + if (nTip > maxTips) { + suffix <- if (!nzchar(context)) "." else paste0(" for ", context, ".") + stop("Trees with > ", maxTips, " tips are not yet supported", suffix) + } } + invisible(NULL) } +# Backward-compatible alias for internal callers/tests. +.AssertNtipSupported <- .CheckMaxTips + CalculateTreeDistance <- function(Func, tree1, tree2 = NULL, reportMatching = FALSE, ...) { supportedClasses <- c("phylo", "Splits") @@ -141,7 +154,7 @@ CalculateTreeDistance <- function(Func, tree1, tree2 = NULL, # Fast paths: use OpenMP batch functions when all trees share the same tip # set and no R-level cluster has been configured. Each branch mirrors the # generic path exactly but avoids per-pair R overhead. - .AssertNtipSupported(nTip) + .CheckMaxTips(nTip) if (!is.na(nTip) && is.null(cluster)) { .n_threads <- as.integer(getOption("mc.cores", 1L)) .batch_result <- if (identical(Func, MutualClusteringInfoSplits)) { @@ -242,7 +255,7 @@ CalculateTreeDistance <- function(Func, tree1, tree2 = NULL, #' @importFrom stats setNames .SplitDistanceManyMany <- function(Func, splits1, splits2, tipLabels, nTip = length(tipLabels), ...) { - .AssertNtipSupported(nTip) + .CheckMaxTips(nTip) if (is.na(nTip)) { tipLabels <- union(unlist(tipLabels, use.names = FALSE), unlist(TipLabels(splits2), use.names = FALSE)) @@ -413,7 +426,7 @@ CalculateTreeDistance <- function(Func, tree1, tree2 = NULL, if (ncol(x) != ncol(y)) { stop("Input splits must address same number of tips.") } - .AssertNtipSupported(nTip) + .CheckMaxTips(nTip) } .CheckLabelsSame <- function(labelList) { diff --git a/R/tree_information.R b/R/tree_information.R index 3462f53e..f1d44d97 100644 --- a/R/tree_information.R +++ b/R/tree_information.R @@ -390,9 +390,10 @@ consensus_info <- function(trees, phylo, p) { stop("p must be >= 0.5 in consensus_info()") } nTip <- NTip(trees[[1]]) - # CT_MAX_LEAVES = 16383 in information.h (lookup table size limit) - if (nTip > 16383L) { - stop("This many leaves are not yet supported") + # CT_MAX_LEAVES = 16383 in information.h (lookup-table size limit). + maxTips <- min(16383L, if (is.null(.SL_MAX_TIPS)) cpp_max_tips() else .SL_MAX_TIPS) + if (nTip > maxTips) { + stop("Trees with > ", maxTips, " tips are not yet supported for consensus info.") } .Call(`_TreeDist_consensus_info`, trees, phylo, p) } diff --git a/R/zzz.R b/R/zzz.R index 826528ea..a8e10a2c 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -1,3 +1,7 @@ +.onLoad <- function(libname, pkgname) { + .SL_MAX_TIPS <<- cpp_max_tips() +} + .onUnload <- function(libpath) { StopParallel(quietly = TRUE) library.dynam.unload("TreeDist", libpath) diff --git a/inst/include/TreeDist/mutual_clustering.h b/inst/include/TreeDist/mutual_clustering.h index 437633f0..7d404ed0 100644 --- a/inst/include/TreeDist/mutual_clustering.h +++ b/inst/include/TreeDist/mutual_clustering.h @@ -37,8 +37,8 @@ namespace TreeDist { // Information content of a perfectly-matching split pair. // ic_matching(a, b, n) = (a + b) * lg2[n] - a * lg2[a] - b * lg2[b] - [[nodiscard]] inline double ic_matching(int16 a, int16 b, - int16 n) noexcept { + [[nodiscard]] inline double ic_matching(split_int a, split_int b, + split_int n) noexcept { const double lg2a = lg2[a]; const double lg2b = lg2[b]; const double lg2n = lg2[n]; @@ -77,9 +77,9 @@ namespace TreeDist { // Implementation in mutual_clustering_impl.h. double mutual_clustering_score( - const splitbit* const* a_state, const int16* a_in, int16 a_n_splits, - const splitbit* const* b_state, const int16* b_in, int16 b_n_splits, - int16 n_bins, int32 n_tips, + const splitbit* const* a_state, const split_int* a_in, split_int a_n_splits, + const splitbit* const* b_state, const split_int* b_in, split_int b_n_splits, + split_int n_bins, int32 n_tips, LapScratch& scratch); } // namespace TreeDist diff --git a/inst/include/TreeDist/mutual_clustering_impl.h b/inst/include/TreeDist/mutual_clustering_impl.h index 7f98c8fb..286b55a3 100644 --- a/inst/include/TreeDist/mutual_clustering_impl.h +++ b/inst/include/TreeDist/mutual_clustering_impl.h @@ -68,17 +68,17 @@ void init_lg2_tables(int max_tips) { namespace detail { -static int16 find_exact_matches_raw( - const splitbit* const* a_state, const int16* /*a_in*/, int16 a_n, - const splitbit* const* b_state, const int16* /*b_in*/, int16 b_n, - int16 n_bins, int32 n_tips, - int16* a_match, int16* b_match) +static split_int find_exact_matches_raw( + const splitbit* const* a_state, const split_int* /*a_in*/, split_int a_n, + const splitbit* const* b_state, const split_int* /*b_in*/, split_int b_n, + split_int n_bins, int32 n_tips, + split_int* a_match, split_int* b_match) { - std::fill(a_match, a_match + a_n, int16(0)); - std::fill(b_match, b_match + b_n, int16(0)); + std::fill(a_match, a_match + a_n, split_int(0)); + std::fill(b_match, b_match + b_n, split_int(0)); if (a_n == 0 || b_n == 0) return 0; - const int16 last_bin = n_bins - 1; + const split_int last_bin = n_bins - 1; const splitbit last_mask = (n_tips % SL_BIN_SIZE == 0) ? ~splitbit(0) : (splitbit(1) << (n_tips % SL_BIN_SIZE)) - 1; @@ -87,17 +87,17 @@ static int16 find_exact_matches_raw( std::vector a_canon(static_cast(a_n) * n_bins); std::vector b_canon(static_cast(b_n) * n_bins); - for (int16 i = 0; i < a_n; ++i) { + for (split_int i = 0; i < a_n; ++i) { const bool flip = !(a_state[i][0] & 1); - for (int16 bin = 0; bin < n_bins; ++bin) { + for (split_int bin = 0; bin < n_bins; ++bin) { splitbit val = flip ? ~a_state[i][bin] : a_state[i][bin]; if (bin == last_bin) val &= last_mask; a_canon[i * n_bins + bin] = val; } } - for (int16 i = 0; i < b_n; ++i) { + for (split_int i = 0; i < b_n; ++i) { const bool flip = !(b_state[i][0] & 1); - for (int16 bin = 0; bin < n_bins; ++bin) { + for (split_int bin = 0; bin < n_bins; ++bin) { splitbit val = flip ? ~b_state[i][bin] : b_state[i][bin]; if (bin == last_bin) val &= last_mask; b_canon[i * n_bins + bin] = val; @@ -105,8 +105,8 @@ static int16 find_exact_matches_raw( } // Sort index arrays by canonical form - auto canon_less = [&](const splitbit* canon, int16 n_b, int16 i, int16 j) { - for (int16 bin = 0; bin < n_b; ++bin) { + auto canon_less = [&](const splitbit* canon, split_int n_b, split_int i, split_int j) { + for (split_int bin = 0; bin < n_b; ++bin) { const splitbit vi = canon[i * n_b + bin]; const splitbit vj = canon[j * n_b + bin]; if (vi < vj) return true; @@ -115,28 +115,28 @@ static int16 find_exact_matches_raw( return false; }; - std::vector a_order(a_n), b_order(b_n); - std::iota(a_order.begin(), a_order.end(), int16(0)); - std::iota(b_order.begin(), b_order.end(), int16(0)); + std::vector a_order(a_n), b_order(b_n); + std::iota(a_order.begin(), a_order.end(), split_int(0)); + std::iota(b_order.begin(), b_order.end(), split_int(0)); std::sort(a_order.begin(), a_order.end(), - [&](int16 i, int16 j) { + [&](split_int i, split_int j) { return canon_less(a_canon.data(), n_bins, i, j); }); std::sort(b_order.begin(), b_order.end(), - [&](int16 i, int16 j) { + [&](split_int i, split_int j) { return canon_less(b_canon.data(), n_bins, i, j); }); // Merge-scan - int16 exact_n = 0; - int16 ai_pos = 0, bi_pos = 0; + split_int exact_n = 0; + split_int ai_pos = 0, bi_pos = 0; while (ai_pos < a_n && bi_pos < b_n) { - const int16 ai = a_order[ai_pos]; - const int16 bi = b_order[bi_pos]; + const split_int ai = a_order[ai_pos]; + const split_int bi = b_order[bi_pos]; int cmp = 0; - for (int16 bin = 0; bin < n_bins; ++bin) { + for (split_int bin = 0; bin < n_bins; ++bin) { const splitbit va = a_canon[ai * n_bins + bin]; const splitbit vb = b_canon[bi * n_bins + bin]; if (va < vb) { cmp = -1; break; } @@ -165,14 +165,14 @@ static int16 find_exact_matches_raw( // ---- MCI score implementation ---- double mutual_clustering_score( - const splitbit* const* a_state, const int16* a_in, int16 a_n_splits, - const splitbit* const* b_state, const int16* b_in, int16 b_n_splits, - int16 n_bins, int32 n_tips, + const splitbit* const* a_state, const split_int* a_in, split_int a_n_splits, + const splitbit* const* b_state, const split_int* b_in, split_int b_n_splits, + split_int n_bins, int32 n_tips, LapScratch& scratch) { if (a_n_splits == 0 || b_n_splits == 0 || n_tips == 0) return 0.0; - const int16 most_splits = std::max(a_n_splits, b_n_splits); + const split_int most_splits = std::max(a_n_splits, b_n_splits); const double n_tips_rcp = 1.0 / static_cast(n_tips); constexpr cost max_score = BIG; @@ -181,25 +181,25 @@ double mutual_clustering_score( const double lg2_n = lg2[n_tips]; // --- Phase 1: O(n log n) exact-match detection --- - std::vector a_match_buf(a_n_splits); - std::vector b_match_buf(b_n_splits); + std::vector a_match_buf(a_n_splits); + std::vector b_match_buf(b_n_splits); - const int16 exact_n = detail::find_exact_matches_raw( + const split_int exact_n = detail::find_exact_matches_raw( a_state, a_in, a_n_splits, b_state, b_in, b_n_splits, n_bins, n_tips, a_match_buf.data(), b_match_buf.data()); - const int16* a_match = a_match_buf.data(); - const int16* b_match = b_match_buf.data(); + const split_int* a_match = a_match_buf.data(); + const split_int* b_match = b_match_buf.data(); // Accumulate exact-match score double exact_score = 0.0; - for (int16 ai = 0; ai < a_n_splits; ++ai) { + for (split_int ai = 0; ai < a_n_splits; ++ai) { if (a_match[ai]) { - const int16 na = a_in[ai]; - const int16 nA = static_cast(n_tips - na); - exact_score += ic_matching(na, nA, static_cast(n_tips)); + const split_int na = a_in[ai]; + const split_int nA = static_cast(n_tips - na); + exact_score += ic_matching(na, nA, static_cast(n_tips)); } } @@ -209,46 +209,46 @@ double mutual_clustering_score( } // --- Phase 2: fill cost matrix for unmatched splits only (O(k²)) --- - const int16 lap_n = most_splits - exact_n; + const split_int lap_n = most_splits - exact_n; - std::vector a_unmatch, b_unmatch; + std::vector a_unmatch, b_unmatch; a_unmatch.reserve(lap_n); b_unmatch.reserve(lap_n); - for (int16 ai = 0; ai < a_n_splits; ++ai) { + for (split_int ai = 0; ai < a_n_splits; ++ai) { if (!a_match[ai]) a_unmatch.push_back(ai); } - for (int16 bi = 0; bi < b_n_splits; ++bi) { + for (split_int bi = 0; bi < b_n_splits; ++bi) { if (!b_match[bi]) b_unmatch.push_back(bi); } scratch.score_pool.resize(lap_n); CostMatrix& score = scratch.score_pool; - const int16 a_unmatched_n = static_cast(a_unmatch.size()); - const int16 b_unmatched_n = static_cast(b_unmatch.size()); + const split_int a_unmatched_n = static_cast(a_unmatch.size()); + const split_int b_unmatched_n = static_cast(b_unmatch.size()); - for (int16 a_pos = 0; a_pos < a_unmatched_n; ++a_pos) { - const int16 ai = a_unmatch[a_pos]; - const int16 na = a_in[ai]; - const int16 nA = static_cast(n_tips - na); + for (split_int a_pos = 0; a_pos < a_unmatched_n; ++a_pos) { + const split_int ai = a_unmatch[a_pos]; + const split_int na = a_in[ai]; + const split_int nA = static_cast(n_tips - na); const splitbit* a_row = a_state[ai]; const double offset_a = lg2_n - lg2[na]; const double offset_A = lg2_n - lg2[nA]; - for (int16 b_pos = 0; b_pos < b_unmatched_n; ++b_pos) { - const int16 bi = b_unmatch[b_pos]; + for (split_int b_pos = 0; b_pos < b_unmatched_n; ++b_pos) { + const split_int bi = b_unmatch[b_pos]; const splitbit* b_row = b_state[bi]; - int16 a_and_b = 0; - for (int16 bin = 0; bin < n_bins; ++bin) { + split_int a_and_b = 0; + for (split_int bin = 0; bin < n_bins; ++bin) { a_and_b += TreeTools::count_bits(a_row[bin] & b_row[bin]); } - const int16 nb = b_in[bi]; - const int16 nB = static_cast(n_tips - nb); - const int16 a_and_B = na - a_and_b; - const int16 A_and_b = nb - a_and_b; - const int16 A_and_B = nA - A_and_b; + const split_int nb = b_in[bi]; + const split_int nB = static_cast(n_tips - nb); + const split_int a_and_B = na - a_and_b; + const split_int A_and_b = nb - a_and_b; + const split_int A_and_B = nA - A_and_b; if (a_and_b == A_and_b && a_and_b == a_and_B && a_and_b == A_and_B) { diff --git a/inst/include/TreeDist/types.h b/inst/include/TreeDist/types.h index c953989b..280f218d 100644 --- a/inst/include/TreeDist/types.h +++ b/inst/include/TreeDist/types.h @@ -16,6 +16,9 @@ namespace TreeDist { using int32 = int_fast32_t; using cost = int_fast64_t; + // Canonical type for split/tip/bin counters. + using split_int = int32; + using lap_dim = int; using lap_row = lap_dim; using lap_col = lap_dim; diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index be5b8963..d803056c 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -636,6 +636,16 @@ BEGIN_RCPP return rcpp_result_gen; END_RCPP } +// cpp_max_tips +int cpp_max_tips(); +RcppExport SEXP _TreeDist_cpp_max_tips() { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + rcpp_result_gen = Rcpp::wrap(cpp_max_tips()); + return rcpp_result_gen; +END_RCPP +} // cpp_robinson_foulds_distance List cpp_robinson_foulds_distance(const RawMatrix& x, const RawMatrix& y, const IntegerVector& nTip); RcppExport SEXP _TreeDist_cpp_robinson_foulds_distance(SEXP xSEXP, SEXP ySEXP, SEXP nTipSEXP) { @@ -780,6 +790,7 @@ static const R_CallMethodDef CallEntries[] = { {"_TreeDist_cpp_transfer_dist_all_pairs", (DL_FUNC) &_TreeDist_cpp_transfer_dist_all_pairs, 4}, {"_TreeDist_cpp_transfer_dist_cross_pairs", (DL_FUNC) &_TreeDist_cpp_transfer_dist_cross_pairs, 5}, {"_TreeDist_cpp_mci_impl_score", (DL_FUNC) &_TreeDist_cpp_mci_impl_score, 3}, + {"_TreeDist_cpp_max_tips", (DL_FUNC) &_TreeDist_cpp_max_tips, 0}, {"_TreeDist_cpp_robinson_foulds_distance", (DL_FUNC) &_TreeDist_cpp_robinson_foulds_distance, 3}, {"_TreeDist_cpp_robinson_foulds_info", (DL_FUNC) &_TreeDist_cpp_robinson_foulds_info, 3}, {"_TreeDist_cpp_matching_split_distance", (DL_FUNC) &_TreeDist_cpp_matching_split_distance, 3}, diff --git a/src/ints.h b/src/ints.h index 50fd3cc7..3d424c4c 100644 --- a/src/ints.h +++ b/src/ints.h @@ -6,6 +6,7 @@ // Re-export shared types to global scope for backward compatibility. using TreeDist::int16; using TreeDist::int32; +using TreeDist::split_int; // Types used only within TreeDist's own source. using uint32 = uint_fast32_t; diff --git a/src/pairwise_distances.cpp b/src/pairwise_distances.cpp index c66b770f..3bde0444 100644 --- a/src/pairwise_distances.cpp +++ b/src/pairwise_distances.cpp @@ -36,10 +36,10 @@ using TreeTools::count_bits; struct MatchScratch { std::vector a_canon; std::vector b_canon; - std::vector a_order; - std::vector b_order; - std::vector a_match; - std::vector b_match; + std::vector a_order; + std::vector b_order; + std::vector a_match; + std::vector b_match; }; // --------------------------------------------------------------------------- @@ -59,19 +59,19 @@ struct MatchScratch { // a_match[ai] = bi+1 if split ai matched split bi, else 0. // b_match[bi] = ai+1 if split bi matched split ai, else 0. // --------------------------------------------------------------------------- -static int16 find_exact_matches( +static split_int find_exact_matches( const SplitList& a, const SplitList& b, const int32 n_tips, MatchScratch& scratch ) { - const int16 n_bins = a.n_bins; - const int16 last_bin = n_bins - 1; + const split_int n_bins = a.n_bins; + const split_int last_bin = n_bins - 1; const splitbit last_mask = (n_tips % SL_BIN_SIZE == 0) ? ~splitbit(0) : (splitbit(1) << (n_tips % SL_BIN_SIZE)) - 1; - const int16 a_n = a.n_splits; - const int16 b_n = b.n_splits; + const split_int a_n = a.n_splits; + const split_int b_n = b.n_splits; // Ensure buffers are large enough (grow lazily, never shrink) const size_t a_canon_sz = static_cast(a_n) * n_bins; @@ -83,10 +83,10 @@ static int16 find_exact_matches( if (scratch.a_match.size() < static_cast(a_n)) scratch.a_match.resize(a_n); if (scratch.b_match.size() < static_cast(b_n)) scratch.b_match.resize(b_n); - int16* a_match = scratch.a_match.data(); - int16* b_match = scratch.b_match.data(); - std::fill(a_match, a_match + a_n, int16(0)); - std::fill(b_match, b_match + b_n, int16(0)); + split_int* a_match = scratch.a_match.data(); + split_int* b_match = scratch.b_match.data(); + std::fill(a_match, a_match + a_n, split_int(0)); + std::fill(b_match, b_match + b_n, split_int(0)); if (a_n == 0 || b_n == 0) return 0; @@ -94,17 +94,17 @@ static int16 find_exact_matches( splitbit* b_canon = scratch.b_canon.data(); // --- 1. Compute canonical forms into flat buffers --- - for (int16 i = 0; i < a_n; ++i) { + for (split_int i = 0; i < a_n; ++i) { const bool flip = !(a.state[i][0] & 1); - for (int16 bin = 0; bin < n_bins; ++bin) { + for (split_int bin = 0; bin < n_bins; ++bin) { splitbit val = flip ? ~a.state[i][bin] : a.state[i][bin]; if (bin == last_bin) val &= last_mask; a_canon[i * n_bins + bin] = val; } } - for (int16 i = 0; i < b_n; ++i) { + for (split_int i = 0; i < b_n; ++i) { const bool flip = !(b.state[i][0] & 1); - for (int16 bin = 0; bin < n_bins; ++bin) { + for (split_int bin = 0; bin < n_bins; ++bin) { splitbit val = flip ? ~b.state[i][bin] : b.state[i][bin]; if (bin == last_bin) val &= last_mask; b_canon[i * n_bins + bin] = val; @@ -112,8 +112,8 @@ static int16 find_exact_matches( } // --- 2. Sort index arrays by canonical form --- - auto canon_less = [&](const splitbit* canon, int16 i, int16 j) { - for (int16 bin = 0; bin < n_bins; ++bin) { + auto canon_less = [&](const splitbit* canon, split_int i, split_int j) { + for (split_int bin = 0; bin < n_bins; ++bin) { const splitbit vi = canon[i * n_bins + bin]; const splitbit vj = canon[j * n_bins + bin]; if (vi < vj) return true; @@ -122,29 +122,29 @@ static int16 find_exact_matches( return false; // #nocov }; - int16* a_order = scratch.a_order.data(); - int16* b_order = scratch.b_order.data(); - std::iota(a_order, a_order + a_n, int16(0)); - std::iota(b_order, b_order + b_n, int16(0)); + split_int* a_order = scratch.a_order.data(); + split_int* b_order = scratch.b_order.data(); + std::iota(a_order, a_order + a_n, split_int(0)); + std::iota(b_order, b_order + b_n, split_int(0)); std::sort(a_order, a_order + a_n, - [&](int16 i, int16 j) { + [&](split_int i, split_int j) { return canon_less(a_canon, i, j); }); std::sort(b_order, b_order + b_n, - [&](int16 i, int16 j) { + [&](split_int i, split_int j) { return canon_less(b_canon, i, j); }); // --- 3. Merge-scan to find matches --- - int16 exact_n = 0; - int16 ai_pos = 0, bi_pos = 0; + split_int exact_n = 0; + split_int ai_pos = 0, bi_pos = 0; while (ai_pos < a_n && bi_pos < b_n) { - const int16 ai = a_order[ai_pos]; - const int16 bi = b_order[bi_pos]; + const split_int ai = a_order[ai_pos]; + const split_int bi = b_order[bi_pos]; int cmp = 0; - for (int16 bin = 0; bin < n_bins; ++bin) { + for (split_int bin = 0; bin < n_bins; ++bin) { const splitbit va = a_canon[ai * n_bins + bin]; const splitbit vb = b_canon[bi * n_bins + bin]; if (va < vb) { cmp = -1; break; } @@ -178,7 +178,7 @@ static double mutual_clustering_score( ) { if (a.n_splits == 0 || b.n_splits == 0 || n_tips == 0) return 0.0; - const int16 most_splits = std::max(a.n_splits, b.n_splits); + const split_int most_splits = std::max(a.n_splits, b.n_splits); const double n_tips_rcp = 1.0 / static_cast(n_tips); constexpr cost max_score = BIG; @@ -187,16 +187,16 @@ static double mutual_clustering_score( const double lg2_n = lg2[n_tips]; // --- Phase 1: O(n log n) exact-match detection --- - const int16 exact_n = find_exact_matches(a, b, n_tips, mscratch); - const int16* a_match = mscratch.a_match.data(); - const int16* b_match = mscratch.b_match.data(); + const split_int exact_n = find_exact_matches(a, b, n_tips, mscratch); + const split_int* a_match = mscratch.a_match.data(); + const split_int* b_match = mscratch.b_match.data(); // Accumulate exact-match score double exact_score = 0.0; - for (int16 ai = 0; ai < a.n_splits; ++ai) { + for (split_int ai = 0; ai < a.n_splits; ++ai) { if (a_match[ai]) { - const int16 na = a.in_split[ai]; - const int16 nA = n_tips - na; + const split_int na = a.in_split[ai]; + const split_int nA = n_tips - na; exact_score += TreeDist::ic_matching(na, nA, n_tips); } } @@ -207,47 +207,47 @@ static double mutual_clustering_score( } // --- Phase 2: fill cost matrix for unmatched splits only (O(k²)) --- - const int16 lap_n = most_splits - exact_n; + const split_int lap_n = most_splits - exact_n; // Build index maps for unmatched splits - std::vector a_unmatch, b_unmatch; + std::vector a_unmatch, b_unmatch; a_unmatch.reserve(lap_n); b_unmatch.reserve(lap_n); - for (int16 ai = 0; ai < a.n_splits; ++ai) { + for (split_int ai = 0; ai < a.n_splits; ++ai) { if (!a_match[ai]) a_unmatch.push_back(ai); } - for (int16 bi = 0; bi < b.n_splits; ++bi) { + for (split_int bi = 0; bi < b.n_splits; ++bi) { if (!b_match[bi]) b_unmatch.push_back(bi); } scratch.score_pool.resize(lap_n); cost_matrix& score = scratch.score_pool; - const int16 a_unmatched_n = static_cast(a_unmatch.size()); - const int16 b_unmatched_n = static_cast(b_unmatch.size()); + const split_int a_unmatched_n = static_cast(a_unmatch.size()); + const split_int b_unmatched_n = static_cast(b_unmatch.size()); - for (int16 a_pos = 0; a_pos < a_unmatched_n; ++a_pos) { - const int16 ai = a_unmatch[a_pos]; - const int16 na = a.in_split[ai]; - const int16 nA = n_tips - na; + for (split_int a_pos = 0; a_pos < a_unmatched_n; ++a_pos) { + const split_int ai = a_unmatch[a_pos]; + const split_int na = a.in_split[ai]; + const split_int nA = n_tips - na; const auto* a_row = a.state[ai]; const double offset_a = lg2_n - lg2[na]; const double offset_A = lg2_n - lg2[nA]; - for (int16 b_pos = 0; b_pos < b_unmatched_n; ++b_pos) { - const int16 bi = b_unmatch[b_pos]; + for (split_int b_pos = 0; b_pos < b_unmatched_n; ++b_pos) { + const split_int bi = b_unmatch[b_pos]; const auto* b_row = b.state[bi]; - int16 a_and_b = 0; - for (int16 bin = 0; bin < a.n_bins; ++bin) { + split_int a_and_b = 0; + for (split_int bin = 0; bin < a.n_bins; ++bin) { a_and_b += count_bits(a_row[bin] & b_row[bin]); } - const int16 nb = b.in_split[bi]; - const int16 nB = n_tips - nb; - const int16 a_and_B = na - a_and_b; - const int16 A_and_b = nb - a_and_b; - const int16 A_and_B = nA - A_and_b; + const split_int nb = b.in_split[bi]; + const split_int nB = n_tips - nb; + const split_int a_and_B = na - a_and_b; + const split_int A_and_b = nb - a_and_b; + const split_int A_and_B = nA - A_and_b; if (a_and_b == A_and_b && a_and_b == a_and_B && a_and_b == A_and_B) { score(a_pos, b_pos) = max_score; @@ -305,6 +305,7 @@ NumericVector cpp_mutual_clustering_all_pairs( const int n_tip, const int n_threads = 1 ) { + TreeDist::check_ntip(static_cast(n_tip)); const int N = splits_list.size(); if (N < 2) return NumericVector(0); @@ -365,22 +366,22 @@ static double rf_info_score( const SplitList& a, const SplitList& b, const int32 n_tips, MatchScratch& mscratch ) { - const int16 a_n = a.n_splits; - const int16 b_n = b.n_splits; + const split_int a_n = a.n_splits; + const split_int b_n = b.n_splits; if (a_n == 0 || b_n == 0) return 0; // Use sort+merge to find exact matches in O(n log n) - const int16 exact_n = find_exact_matches(a, b, n_tips, mscratch); + const split_int exact_n = find_exact_matches(a, b, n_tips, mscratch); if (exact_n == 0) return 0; // Sum info contribution for each matched split in a - const int16* a_match = mscratch.a_match.data(); + const split_int* a_match = mscratch.a_match.data(); const double lg2_unrooted_n = lg2_unrooted[n_tips]; double score = 0; - for (int16 ai = 0; ai < a_n; ++ai) { + for (split_int ai = 0; ai < a_n; ++ai) { if (a_match[ai] == 0) continue; - int16 leaves_in_split = 0; - for (int16 bin = 0; bin < a.n_bins; ++bin) { + split_int leaves_in_split = 0; + for (split_int bin = 0; bin < a.n_bins; ++bin) { leaves_in_split += count_bits(a.state[ai][bin]); } score += lg2_unrooted_n @@ -397,6 +398,7 @@ NumericVector cpp_rf_info_all_pairs( const int n_tip, const int n_threads = 1 ) { + TreeDist::check_ntip(static_cast(n_tip)); const int N = splits_list.size(); if (N < 2) return NumericVector(0); const int n_pairs = N * (N - 1) / 2; @@ -444,48 +446,48 @@ static double msd_score( const SplitList& a, const SplitList& b, const int32 n_tips, LapScratch& scratch, MatchScratch& mscratch ) { - const int16 most_splits = std::max(a.n_splits, b.n_splits); + const split_int most_splits = std::max(a.n_splits, b.n_splits); if (most_splits == 0) return 0.0; const bool a_has_more = (a.n_splits > b.n_splits); - const int16 a_extra = a_has_more ? most_splits - b.n_splits : 0; - const int16 b_extra = a_has_more ? 0 : most_splits - a.n_splits; - const int16 half_tips = n_tips / 2; + const split_int a_extra = a_has_more ? most_splits - b.n_splits : 0; + const split_int b_extra = a_has_more ? 0 : most_splits - a.n_splits; + const split_int half_tips = n_tips / 2; const cost max_score = BIG / most_splits; // --- Phase 1: O(n log n) exact-match detection --- - const int16 exact_n = find_exact_matches(a, b, n_tips, mscratch); - const int16* a_match = mscratch.a_match.data(); - const int16* b_match = mscratch.b_match.data(); + const split_int exact_n = find_exact_matches(a, b, n_tips, mscratch); + const split_int* a_match = mscratch.a_match.data(); + const split_int* b_match = mscratch.b_match.data(); if (exact_n == b.n_splits || exact_n == a.n_splits) { return 0.0; } // --- Phase 2: fill cost matrix for unmatched splits only --- - const int16 lap_n = most_splits - exact_n; + const split_int lap_n = most_splits - exact_n; - std::vector a_unmatch, b_unmatch; + std::vector a_unmatch, b_unmatch; a_unmatch.reserve(lap_n); b_unmatch.reserve(lap_n); - for (int16 ai = 0; ai < a.n_splits; ++ai) { + for (split_int ai = 0; ai < a.n_splits; ++ai) { if (!a_match[ai]) a_unmatch.push_back(ai); } - for (int16 bi = 0; bi < b.n_splits; ++bi) { + for (split_int bi = 0; bi < b.n_splits; ++bi) { if (!b_match[bi]) b_unmatch.push_back(bi); } scratch.score_pool.resize(lap_n); cost_matrix& score = scratch.score_pool; - const int16 a_unmatched_n = static_cast(a_unmatch.size()); - const int16 b_unmatched_n = static_cast(b_unmatch.size()); + const split_int a_unmatched_n = static_cast(a_unmatch.size()); + const split_int b_unmatched_n = static_cast(b_unmatch.size()); - for (int16 a_pos = 0; a_pos < a_unmatched_n; ++a_pos) { - const int16 ai = a_unmatch[a_pos]; - for (int16 b_pos = 0; b_pos < b_unmatched_n; ++b_pos) { - const int16 bi = b_unmatch[b_pos]; + for (split_int a_pos = 0; a_pos < a_unmatched_n; ++a_pos) { + const split_int ai = a_unmatch[a_pos]; + for (split_int b_pos = 0; b_pos < b_unmatched_n; ++b_pos) { + const split_int bi = b_unmatch[b_pos]; splitbit total = 0; - for (int16 bin = 0; bin < a.n_bins; ++bin) { + for (split_int bin = 0; bin < a.n_bins; ++bin) { total += count_bits(a.state[ai][bin] ^ b.state[bi][bin]); } score(a_pos, b_pos) = static_cast( @@ -516,6 +518,7 @@ NumericVector cpp_msd_all_pairs( const int n_tip, const int n_threads = 1 ) { + TreeDist::check_ntip(static_cast(n_tip)); const int N = splits_list.size(); if (N < 2) return NumericVector(0); const int n_pairs = N * (N - 1) / 2; @@ -571,30 +574,30 @@ static double msi_score( const SplitList& a, const SplitList& b, const int32 n_tips, LapScratch& scratch ) { - const int16 most_splits = std::max(a.n_splits, b.n_splits); + const split_int most_splits = std::max(a.n_splits, b.n_splits); if (most_splits == 0) return 0.0; constexpr cost max_score = BIG; const double max_possible = lg2_unrooted[n_tips] - - lg2_rooted[int16((n_tips + 1) / 2)] - - lg2_rooted[int16(n_tips / 2)]; + - lg2_rooted[split_int((n_tips + 1) / 2)] + - lg2_rooted[split_int(n_tips / 2)]; const double score_over_possible = static_cast(max_score) / max_possible; const double possible_over_score = max_possible / static_cast(max_score); scratch.score_pool.resize(most_splits); cost_matrix& score = scratch.score_pool; - for (int16 ai = 0; ai < a.n_splits; ++ai) { - for (int16 bi = 0; bi < b.n_splits; ++bi) { - int16 n_a_only = 0, n_a_and_b = 0, n_different = 0; + for (split_int ai = 0; ai < a.n_splits; ++ai) { + for (split_int bi = 0; bi < b.n_splits; ++bi) { + split_int n_a_only = 0, n_a_and_b = 0, n_different = 0; splitbit different; - for (int16 bin = 0; bin < a.n_bins; ++bin) { + for (split_int bin = 0; bin < a.n_bins; ++bin) { different = a.state[ai][bin] ^ b.state[bi][bin]; n_different += count_bits(different); n_a_only += count_bits(a.state[ai][bin] & different); n_a_and_b += count_bits(a.state[ai][bin] & ~different); } - const int16 n_same = n_tips - n_different; + const split_int n_same = n_tips - n_different; score(ai, bi) = cost(max_score - score_over_possible * TreeDist::mmsi_score(n_same, n_a_and_b, n_different, n_a_only)); } @@ -618,6 +621,7 @@ NumericVector cpp_msi_all_pairs( const int n_tip, const int n_threads = 1 ) { + TreeDist::check_ntip(static_cast(n_tip)); const int N = splits_list.size(); if (N < 2) return NumericVector(0); const int n_pairs = N * (N - 1) / 2; @@ -669,10 +673,10 @@ static double shared_phylo_score( const SplitList& a, const SplitList& b, const int32 n_tips, LapScratch& scratch ) { - const int16 most_splits = std::max(a.n_splits, b.n_splits); + const split_int most_splits = std::max(a.n_splits, b.n_splits); if (most_splits == 0) return 0.0; - const int16 overlap_a = int16(n_tips + 1) / 2; + const split_int overlap_a = split_int(n_tips + 1) / 2; constexpr cost max_score = BIG; const double best_overlap = TreeDist::one_overlap(overlap_a, n_tips / 2, n_tips); const double max_possible = lg2_unrooted[n_tips] - best_overlap; @@ -682,8 +686,8 @@ static double shared_phylo_score( scratch.score_pool.resize(most_splits); cost_matrix& score = scratch.score_pool; - for (int16 ai = 0; ai < a.n_splits; ++ai) { - for (int16 bi = 0; bi < b.n_splits; ++bi) { + for (split_int ai = 0; ai < a.n_splits; ++ai) { + for (split_int bi = 0; bi < b.n_splits; ++bi) { const double spi = TreeDist::spi_overlap( a.state[ai], b.state[bi], n_tips, a.in_split[ai], b.in_split[bi], a.n_bins); @@ -710,6 +714,7 @@ NumericVector cpp_shared_phylo_all_pairs( const int n_tip, const int n_threads = 1 ) { + TreeDist::check_ntip(static_cast(n_tip)); const int N = splits_list.size(); if (N < 2) return NumericVector(0); const int n_pairs = N * (N - 1) / 2; @@ -758,7 +763,7 @@ static double jaccard_score( const double exponent, const bool allow_conflict, LapScratch& scratch, MatchScratch& mscratch ) { - const int16 most_splits = std::max(a.n_splits, b.n_splits); + const split_int most_splits = std::max(a.n_splits, b.n_splits); if (most_splits == 0) return 0.0; constexpr cost max_score = BIG; @@ -767,7 +772,7 @@ static double jaccard_score( // --- Phase 1: O(n log n) exact-match detection --- // Only used when allow_conflict=true; otherwise the full LAP may reassign // non-matching splits to compatible (non-exact) partners. - int16 exact_n = 0; + split_int exact_n = 0; if (allow_conflict) { exact_n = find_exact_matches(a, b, n_tips, mscratch); } else { @@ -776,61 +781,61 @@ static double jaccard_score( mscratch.a_match.resize(a.n_splits); if (mscratch.b_match.size() < static_cast(b.n_splits)) mscratch.b_match.resize(b.n_splits); - std::fill(mscratch.a_match.data(), mscratch.a_match.data() + a.n_splits, int16(0)); - std::fill(mscratch.b_match.data(), mscratch.b_match.data() + b.n_splits, int16(0)); + std::fill(mscratch.a_match.data(), mscratch.a_match.data() + a.n_splits, split_int(0)); + std::fill(mscratch.b_match.data(), mscratch.b_match.data() + b.n_splits, split_int(0)); } - const int16* a_match = mscratch.a_match.data(); - const int16* b_match = mscratch.b_match.data(); + const split_int* a_match = mscratch.a_match.data(); + const split_int* b_match = mscratch.b_match.data(); if (exact_n == b.n_splits || exact_n == a.n_splits) { return static_cast(exact_n); } // --- Phase 2: fill cost matrix for unmatched splits only --- - const int16 lap_n = most_splits - exact_n; + const split_int lap_n = most_splits - exact_n; - std::vector a_unmatch, b_unmatch; + std::vector a_unmatch, b_unmatch; a_unmatch.reserve(lap_n); b_unmatch.reserve(lap_n); - for (int16 ai = 0; ai < a.n_splits; ++ai) { + for (split_int ai = 0; ai < a.n_splits; ++ai) { if (!a_match[ai]) a_unmatch.push_back(ai); } - for (int16 bi = 0; bi < b.n_splits; ++bi) { + for (split_int bi = 0; bi < b.n_splits; ++bi) { if (!b_match[bi]) b_unmatch.push_back(bi); } scratch.score_pool.resize(lap_n); cost_matrix& score = scratch.score_pool; - const int16 a_unmatched_n = static_cast(a_unmatch.size()); - const int16 b_unmatched_n = static_cast(b_unmatch.size()); + const split_int a_unmatched_n = static_cast(a_unmatch.size()); + const split_int b_unmatched_n = static_cast(b_unmatch.size()); - for (int16 a_pos = 0; a_pos < a_unmatched_n; ++a_pos) { - const int16 ai = a_unmatch[a_pos]; - const int16 na = a.in_split[ai]; - const int16 nA = n_tips - na; + for (split_int a_pos = 0; a_pos < a_unmatched_n; ++a_pos) { + const split_int ai = a_unmatch[a_pos]; + const split_int na = a.in_split[ai]; + const split_int nA = n_tips - na; - for (int16 b_pos = 0; b_pos < b_unmatched_n; ++b_pos) { - const int16 bi = b_unmatch[b_pos]; - int16 a_and_b = 0; - for (int16 bin = 0; bin < a.n_bins; ++bin) { + for (split_int b_pos = 0; b_pos < b_unmatched_n; ++b_pos) { + const split_int bi = b_unmatch[b_pos]; + split_int a_and_b = 0; + for (split_int bin = 0; bin < a.n_bins; ++bin) { a_and_b += count_bits(a.state[ai][bin] & b.state[bi][bin]); } - const int16 nb = b.in_split[bi]; - const int16 nB = n_tips - nb; - const int16 a_and_B = na - a_and_b; - const int16 A_and_b = nb - a_and_b; - const int16 A_and_B = nB - a_and_B; + const split_int nb = b.in_split[bi]; + const split_int nB = n_tips - nb; + const split_int a_and_B = na - a_and_b; + const split_int A_and_b = nb - a_and_b; + const split_int A_and_B = nB - a_and_B; if (!allow_conflict && !( a_and_b == na || a_and_B == na || A_and_b == nA || A_and_B == nA)) { score(a_pos, b_pos) = max_score; } else { - const int16 A_or_b = n_tips - a_and_B; - const int16 a_or_B = n_tips - A_and_b; - const int16 a_or_b = n_tips - A_and_B; - const int16 A_or_B = n_tips - a_and_b; + const split_int A_or_b = n_tips - a_and_B; + const split_int a_or_B = n_tips - A_and_b; + const split_int a_or_b = n_tips - A_and_B; + const split_int A_or_B = n_tips - a_and_b; const double ars_ab = static_cast(a_and_b) / static_cast(a_or_b); const double ars_Ab = static_cast(A_and_b) / static_cast(A_or_b); const double ars_aB = static_cast(a_and_B) / static_cast(a_or_B); @@ -875,6 +880,7 @@ NumericVector cpp_jaccard_all_pairs( const bool allow_conflict = true, const int n_threads = 1 ) { + TreeDist::check_ntip(static_cast(n_tip)); const int N = splits_list.size(); if (N < 2) return NumericVector(0); const int n_pairs = N * (N - 1) / 2; @@ -944,6 +950,7 @@ NumericMatrix cpp_mutual_clustering_cross_pairs( const List& splits_a, const List& splits_b, const int n_tip, const int n_threads = 1 ) { + TreeDist::check_ntip(static_cast(n_tip)); const int nA = splits_a.size(); const int nB = splits_b.size(); if (nA == 0 || nB == 0) return NumericMatrix(nA, nB); @@ -987,6 +994,7 @@ NumericMatrix cpp_rf_info_cross_pairs( const List& splits_a, const List& splits_b, const int n_tip, const int n_threads = 1 ) { + TreeDist::check_ntip(static_cast(n_tip)); const int nA = splits_a.size(); const int nB = splits_b.size(); if (nA == 0 || nB == 0) return NumericMatrix(nA, nB); @@ -1027,6 +1035,7 @@ NumericMatrix cpp_msd_cross_pairs( const List& splits_a, const List& splits_b, const int n_tip, const int n_threads = 1 ) { + TreeDist::check_ntip(static_cast(n_tip)); const int nA = splits_a.size(); const int nB = splits_b.size(); if (nA == 0 || nB == 0) return NumericMatrix(nA, nB); @@ -1070,6 +1079,7 @@ NumericMatrix cpp_msi_cross_pairs( const List& splits_a, const List& splits_b, const int n_tip, const int n_threads = 1 ) { + TreeDist::check_ntip(static_cast(n_tip)); const int nA = splits_a.size(); const int nB = splits_b.size(); if (nA == 0 || nB == 0) return NumericMatrix(nA, nB); @@ -1110,6 +1120,7 @@ NumericMatrix cpp_shared_phylo_cross_pairs( const List& splits_a, const List& splits_b, const int n_tip, const int n_threads = 1 ) { + TreeDist::check_ntip(static_cast(n_tip)); const int nA = splits_a.size(); const int nB = splits_b.size(); if (nA == 0 || nB == 0) return NumericMatrix(nA, nB); @@ -1153,6 +1164,7 @@ NumericMatrix cpp_jaccard_cross_pairs( const bool allow_conflict = true, const int n_threads = 1 ) { + TreeDist::check_ntip(static_cast(n_tip)); const int nA = splits_a.size(); const int nB = splits_b.size(); if (nA == 0 || nB == 0) return NumericMatrix(nA, nB); @@ -1206,6 +1218,7 @@ NumericVector cpp_clustering_entropy_batch( const List& splits_list, const int n_tip ) { + TreeDist::check_ntip(static_cast(n_tip)); const int N = splits_list.size(); NumericVector result(N); if (N == 0 || n_tip <= 0) return result; @@ -1216,7 +1229,7 @@ NumericVector cpp_clustering_entropy_batch( for (int i = 0; i < N; ++i) { SplitList sl(Rcpp::as(splits_list[i])); double total = 0.0; - for (int16 s = 0; s < sl.n_splits; ++s) { + for (split_int s = 0; s < sl.n_splits; ++s) { const int k = sl.in_split[s]; if (k <= 0 || k >= n_tip) continue; const double p = k * invN; @@ -1239,6 +1252,7 @@ NumericVector cpp_splitwise_info_batch( const List& splits_list, const int n_tip ) { + TreeDist::check_ntip(static_cast(n_tip)); const int N = splits_list.size(); NumericVector result(N); if (N == 0 || n_tip < 4) return result; @@ -1257,7 +1271,7 @@ NumericVector cpp_splitwise_info_batch( for (int i = 0; i < N; ++i) { SplitList sl(Rcpp::as(splits_list[i])); double total = 0.0; - for (int16 s = 0; s < sl.n_splits; ++s) { + for (split_int s = 0; s < sl.n_splits; ++s) { const int k = sl.in_split[s]; if (k < 2 || (n_tip - k) < 2) continue; total += l2u_n - l2r[k] - l2r[n_tip - k]; diff --git a/src/tree_distance_functions.cpp b/src/tree_distance_functions.cpp index 99347300..810cc1cb 100644 --- a/src/tree_distance_functions.cpp +++ b/src/tree_distance_functions.cpp @@ -28,15 +28,15 @@ double cpp_mci_impl_score(const Rcpp::RawMatrix& x, // Build arrays matching the header's raw-pointer API types. std::vector a_ptrs(a.n_splits); std::vector b_ptrs(b.n_splits); - std::vector a_in(a.n_splits); - std::vector b_in(b.n_splits); - for (TreeDist::int16 i = 0; i < a.n_splits; ++i) { + std::vector a_in(a.n_splits); + std::vector b_in(b.n_splits); + for (TreeDist::int32 i = 0; i < a.n_splits; ++i) { a_ptrs[i] = a.state[i]; - a_in[i] = static_cast(a.in_split[i]); + a_in[i] = static_cast(a.in_split[i]); } - for (TreeDist::int16 i = 0; i < b.n_splits; ++i) { + for (TreeDist::int32 i = 0; i < b.n_splits; ++i) { b_ptrs[i] = b.state[i]; - b_in[i] = static_cast(b.in_split[i]); + b_in[i] = static_cast(b.in_split[i]); } return TreeDist::mutual_clustering_score( @@ -45,3 +45,8 @@ double cpp_mci_impl_score(const Rcpp::RawMatrix& x, a.n_bins, static_cast(n_tips), scratch); } + +// [[Rcpp::export]] +int cpp_max_tips() { + return static_cast(SL_MAX_TIPS); +} diff --git a/src/tree_distances.cpp b/src/tree_distances.cpp index f79b48da..1ced9e26 100644 --- a/src/tree_distances.cpp +++ b/src/tree_distances.cpp @@ -29,16 +29,21 @@ namespace TreeDist { } } - void check_ntip(const double n) { - // SplitList dimensions are bounded by SL_MAX_TIPS, and current scoring - // paths use int16-sized counts internally. + void check_ntip(const int32 n) { static_assert(SL_MAX_TIPS <= std::numeric_limits::max(), "SL_MAX_TIPS must fit in int32"); - constexpr int32 max_supported_tips = std::min( - int32(SL_MAX_TIPS), int32(std::numeric_limits::max()) - ); + constexpr int32 compiled_tip_limit = static_cast(SL_MAX_TIPS); + constexpr int64_t split_int_limit = + static_cast((std::numeric_limits::max)()); + constexpr int64_t max_supported_tips = + std::min(compiled_tip_limit, split_int_limit); + if (n > max_supported_tips) { - Rcpp::stop("This many tips are not (yet) supported."); + Rcpp::stop( + "Requested nTip = %d exceeds this TreeDist build limit (%d; " + "compiled SL_MAX_TIPS = %d): this many tips are not yet supported.", + n, static_cast(max_supported_tips), compiled_tip_limit + ); } } @@ -69,6 +74,7 @@ inline List robinson_foulds_distance(const RawMatrix &x, const RawMatrix &y, } for (int32 ai = a.n_splits; ai--; ) { + if ((ai & 1023) == 0) Rcpp::checkUserInterrupt(); for (int32 bi = b.n_splits; bi--; ) { bool all_match = true; @@ -104,8 +110,8 @@ inline List robinson_foulds_info(const RawMatrix &x, const RawMatrix &y, const int32 n_tips) { const SplitList a(x), b(y); - const int16 last_bin = a.n_bins - 1; - const int16 unset_tips = (n_tips % SL_BIN_SIZE) ? + const split_int last_bin = a.n_bins - 1; + const split_int unset_tips = (n_tips % SL_BIN_SIZE) ? SL_BIN_SIZE - n_tips % SL_BIN_SIZE : 0; const splitbit unset_mask = ALL_ONES >> unset_tips; @@ -116,26 +122,27 @@ inline List robinson_foulds_info(const RawMatrix &x, const RawMatrix &y, // Heap-backed scratch avoids large fixed-size stack allocation. std::vector b_complement(size_t(b.n_splits) * size_t(a.n_bins)); - for (int16 i = 0; i < b.n_splits; i++) { - for (int16 bin = 0; bin < last_bin; ++bin) { + for (split_int i = 0; i < b.n_splits; i++) { + for (split_int bin = 0; bin < last_bin; ++bin) { b_complement[size_t(i) * a.n_bins + bin] = ~b.state[i][bin]; } b_complement[size_t(i) * a.n_bins + last_bin] = b.state[i][last_bin] ^ unset_mask; } - for (int16 ai = 0; ai < a.n_splits; ++ai) { - for (int16 bi = 0; bi < b.n_splits; ++bi) { + for (split_int ai = 0; ai < a.n_splits; ++ai) { + if ((ai & 1023) == 0) Rcpp::checkUserInterrupt(); + for (split_int bi = 0; bi < b.n_splits; ++bi) { bool all_match = true, all_complement = true; - for (int16 bin = 0; bin < a.n_bins; ++bin) { + for (split_int bin = 0; bin < a.n_bins; ++bin) { if ((a.state[ai][bin] != b.state[bi][bin])) { all_match = false; break; } } if (!all_match) { - for (int16 bin = 0; bin < a.n_bins; ++bin) { + for (split_int bin = 0; bin < a.n_bins; ++bin) { if ((a.state[ai][bin] != b_complement[size_t(bi) * a.n_bins + bin])) { all_complement = false; break; @@ -143,8 +150,8 @@ inline List robinson_foulds_info(const RawMatrix &x, const RawMatrix &y, } } if (all_match || all_complement) { - int16 leaves_in_split = 0; - for (int16 bin = 0; bin < a.n_bins; ++bin) { + split_int leaves_in_split = 0; + for (split_int bin = 0; bin < a.n_bins; ++bin) { leaves_in_split += count_bits(a.state[ai][bin]); } @@ -167,27 +174,28 @@ inline List robinson_foulds_info(const RawMatrix &x, const RawMatrix &y, inline List matching_split_distance(const RawMatrix &x, const RawMatrix &y, const int32 n_tips) { const SplitList a(x), b(y); - const int16 most_splits = std::max(a.n_splits, b.n_splits); - const int16 split_diff = most_splits - std::min(a.n_splits, b.n_splits); - const int16 half_tips = n_tips / 2; + const split_int most_splits = std::max(a.n_splits, b.n_splits); + const split_int split_diff = most_splits - std::min(a.n_splits, b.n_splits); + const split_int half_tips = n_tips / 2; if (most_splits == 0) { return List::create(Named("score") = 0); } const cost max_score = BIG / most_splits; cost_matrix score(most_splits); - for (int16 ai = 0; ai < a.n_splits; ++ai) { - for (int16 bi = 0; bi < b.n_splits; ++bi) { + for (split_int ai = 0; ai < a.n_splits; ++ai) { + if ((ai & 1023) == 0) Rcpp::checkUserInterrupt(); + for (split_int bi = 0; bi < b.n_splits; ++bi) { splitbit total = 0; - for (int16 bin = 0; bin < a.n_bins; ++bin) { + for (split_int bin = 0; bin < a.n_bins; ++bin) { total += count_bits(a.state[ai][bin] ^ b.state[bi][bin]); } score(ai, bi) = total; } } - for (int16 ai = 0; ai < a.n_splits; ++ai) { - for (int16 bi = 0; bi < b.n_splits; ++bi) { + for (split_int ai = 0; ai < a.n_splits; ++ai) { + for (split_int bi = 0; bi < b.n_splits; ++bi) { if (score(ai, bi) > half_tips) { score(ai, bi) = n_tips - score(ai, bi); } @@ -209,7 +217,7 @@ inline List matching_split_distance(const RawMatrix &x, const RawMatrix &y, std::vector final_matching; final_matching.reserve(a.n_splits); - for (int16 i = 0; i < a.n_splits; ++i) { + for (split_int i = 0; i < a.n_splits; ++i) { const int match = (rowsol[i] < b.n_splits) ? static_cast(rowsol[i]) + 1 : NA_INTEGER; @@ -225,7 +233,7 @@ inline List jaccard_similarity(const RawMatrix &x, const RawMatrix &y, const int32 n_tips, const NumericVector &k, const LogicalVector &allowConflict) { const SplitList a(x), b(y); - const int16 most_splits = std::max(a.n_splits, b.n_splits); + const split_int most_splits = std::max(a.n_splits, b.n_splits); constexpr cost max_score = BIG; constexpr double max_scoreL = max_score; @@ -235,20 +243,21 @@ inline List jaccard_similarity(const RawMatrix &x, const RawMatrix &y, cost_matrix score(most_splits); - for (int16 ai = 0; ai < a.n_splits; ++ai) { + for (split_int ai = 0; ai < a.n_splits; ++ai) { + if ((ai & 1023) == 0) Rcpp::checkUserInterrupt(); - const int16 na = a.in_split[ai]; - const int16 nA = n_tips - na; + const split_int na = a.in_split[ai]; + const split_int nA = n_tips - na; - for (int16 bi = 0; bi < b.n_splits; ++bi) { + for (split_int bi = 0; bi < b.n_splits; ++bi) { // x divides tips into a|A; y divides tips into b|B - int16 a_and_b = 0; - for (int16 bin = 0; bin < a.n_bins; ++bin) { + split_int a_and_b = 0; + for (split_int bin = 0; bin < a.n_bins; ++bin) { a_and_b += count_bits(a.state[ai][bin] & b.state[bi][bin]); } - const int16 + const split_int nb = b.in_split[bi], nB = n_tips - nb, a_and_B = na - a_and_b, @@ -266,7 +275,7 @@ inline List jaccard_similarity(const RawMatrix &x, const RawMatrix &y, score(ai, bi) = max_score; /* Prohibited */ } else { - const int16 + const split_int A_or_b = n_tips - a_and_B, a_or_B = n_tips - A_and_b, a_or_b = n_tips - A_and_B, @@ -315,7 +324,7 @@ inline List jaccard_similarity(const RawMatrix &x, const RawMatrix &y, std::vector final_matching; final_matching.reserve(a.n_splits); - for (int16 i = 0; i < a.n_splits; ++i) { + for (split_int i = 0; i < a.n_splits; ++i) { const int match = (rowsol[i] < b.n_splits) ? static_cast(rowsol[i]) + 1 : NA_INTEGER; @@ -330,31 +339,32 @@ inline List jaccard_similarity(const RawMatrix &x, const RawMatrix &y, List msi_distance(const RawMatrix &x, const RawMatrix &y, const int32 n_tips) { const SplitList a(x), b(y); - const int16 most_splits = std::max(a.n_splits, b.n_splits); + const split_int most_splits = std::max(a.n_splits, b.n_splits); constexpr cost max_score = BIG; const double max_possible = lg2_unrooted[n_tips] - - lg2_rooted[int16((n_tips + 1) / 2)] - lg2_rooted[int16(n_tips / 2)]; + lg2_rooted[split_int((n_tips + 1) / 2)] - lg2_rooted[split_int(n_tips / 2)]; const double score_over_possible = static_cast(max_score) / max_possible; const double possible_over_score = max_possible / max_score; cost_matrix score(most_splits); - splitbit different[SL_MAX_BINS]; + std::vector different(a.n_bins); - for (int16 ai = 0; ai < a.n_splits; ++ai) { - for (int16 bi = 0; bi < b.n_splits; ++bi) { - int16 + for (split_int ai = 0; ai < a.n_splits; ++ai) { + if ((ai & 1023) == 0) Rcpp::checkUserInterrupt(); + for (split_int bi = 0; bi < b.n_splits; ++bi) { + split_int n_different = 0, n_a_only = 0, n_a_and_b = 0 ; - for (int16 bin = 0; bin < a.n_bins; ++bin) { + for (split_int bin = 0; bin < a.n_bins; ++bin) { different[bin] = a.state[ai][bin] ^ b.state[bi][bin]; n_different += count_bits(different[bin]); n_a_only += count_bits(a.state[ai][bin] & different[bin]); n_a_and_b += count_bits(a.state[ai][bin] & ~different[bin]); } - const int16 n_same = n_tips - n_different; + const split_int n_same = n_tips - n_different; score(ai, bi) = cost(max_score - (score_over_possible * @@ -377,7 +387,7 @@ List msi_distance(const RawMatrix &x, const RawMatrix &y, const int32 n_tips) { std::vector final_matching; final_matching.reserve(a.n_splits); - for (int16 i = 0; i < a.n_splits; ++i) { + for (split_int i = 0; i < a.n_splits; ++i) { const int match = (rowsol[i] < b.n_splits) ? static_cast(rowsol[i]) + 1 : NA_INTEGER; @@ -394,9 +404,9 @@ List mutual_clustering(const RawMatrix &x, const RawMatrix &y, const SplitList a(x); const SplitList b(y); const bool a_has_more_splits = (a.n_splits > b.n_splits); - const int16 most_splits = a_has_more_splits ? a.n_splits : b.n_splits; - const int16 a_extra_splits = a_has_more_splits ? most_splits - b.n_splits : 0; - const int16 b_extra_splits = a_has_more_splits ? 0 : most_splits - a.n_splits; + const split_int most_splits = a_has_more_splits ? a.n_splits : b.n_splits; + const split_int a_extra_splits = a_has_more_splits ? most_splits - b.n_splits : 0; + const split_int b_extra_splits = a_has_more_splits ? 0 : most_splits - a.n_splits; const double n_tips_reciprocal = 1.0 / n_tips; if (most_splits == 0 || n_tips == 0) { @@ -412,35 +422,36 @@ List mutual_clustering(const RawMatrix &x, const RawMatrix &y, cost_matrix score(most_splits); double exact_match_score = 0; - int16 exact_matches = 0; + split_int exact_matches = 0; // vector zero-initializes [so does make_unique] // match will have one added to it so numbering follows R; hence 0 = UNMATCHED std::vector a_match(a.n_splits); - std::unique_ptr b_match = std::make_unique(b.n_splits); + std::unique_ptr b_match = std::make_unique(b.n_splits); - for (int16 ai = 0; ai < a.n_splits; ++ai) { + for (split_int ai = 0; ai < a.n_splits; ++ai) { + if ((ai & 1023) == 0) Rcpp::checkUserInterrupt(); if (a_match[ai]) continue; - const int16 na = a.in_split[ai]; - const int16 nA = n_tips - na; + const split_int na = a.in_split[ai]; + const split_int nA = n_tips - na; const auto *a_row = a.state[ai]; const double offset_a = lg2_n - lg2[na]; const double offset_A = lg2_n - lg2[nA]; - for (int16 bi = 0; bi < b.n_splits; ++bi) { + for (split_int bi = 0; bi < b.n_splits; ++bi) { // x divides tips into a|A; y divides tips into b|B - int16 a_and_b = 0; + split_int a_and_b = 0; const auto *b_row = b.state[bi]; - for (int16 bin = 0; bin < a.n_bins; ++bin) { + for (split_int bin = 0; bin < a.n_bins; ++bin) { a_and_b += count_bits(a_row[bin] & b_row[bin]); } - const int16 nb = b.in_split[bi]; - const int16 nB = n_tips - nb; - const int16 a_and_B = na - a_and_b; - const int16 A_and_b = nb - a_and_b; - const int16 A_and_B = nA - A_and_b; + const split_int nb = b.in_split[bi]; + const split_int nB = n_tips - nb; + const split_int a_and_B = na - a_and_b; + const split_int A_and_b = nb - a_and_b; + const split_int A_and_B = nA - A_and_b; if ((!a_and_B && !A_and_b) || (!a_and_b && !A_and_B)) { @@ -478,7 +489,7 @@ List mutual_clustering(const RawMatrix &x, const RawMatrix &y, _["matching"] = a_match); } - const int16 lap_dim = most_splits - exact_matches; + const split_int lap_dim = most_splits - exact_matches; ASSERT(lap_dim > 0); std::vector rowsol; std::vector colsol; @@ -487,11 +498,11 @@ List mutual_clustering(const RawMatrix &x, const RawMatrix &y, cost_matrix small_score(lap_dim); if (exact_matches) { - int16 a_pos = 0; - for (int16 ai = 0; ai < a.n_splits; ++ai) { + split_int a_pos = 0; + for (split_int ai = 0; ai < a.n_splits; ++ai) { if (a_match[ai]) continue; - int16 b_pos = 0; - for (int16 bi = 0; bi < b.n_splits; ++bi) { + split_int b_pos = 0; + for (split_int bi = 0; bi < b.n_splits; ++bi) { if (b_match[bi]) continue; small_score(a_pos, b_pos) = score(ai, bi); b_pos++; @@ -505,9 +516,9 @@ List mutual_clustering(const RawMatrix &x, const RawMatrix &y, lap(lap_dim, small_score, rowsol, colsol)) * over_max_score; const double final_score = lap_score + (exact_match_score / n_tips); - std::unique_ptr lap_decode = std::make_unique(lap_dim); - int16 fuzzy_match = 0; - for (int16 bi = 0; bi < b.n_splits; ++bi) { + std::unique_ptr lap_decode = std::make_unique(lap_dim); + split_int fuzzy_match = 0; + for (split_int bi = 0; bi < b.n_splits; ++bi) { if (!b_match[bi]) { assert(fuzzy_match < lap_dim); lap_decode[fuzzy_match++] = bi + 1; @@ -517,13 +528,13 @@ List mutual_clustering(const RawMatrix &x, const RawMatrix &y, fuzzy_match = 0; std::vector final_matching; TreeDist::resize_uninitialized(final_matching, a.n_splits); - for (int16 i = 0; i < a.n_splits; ++i) { + for (split_int i = 0; i < a.n_splits; ++i) { if (a_match[i]) { final_matching[i] = a_match[i]; } else { assert(fuzzy_match < lap_dim); - const int16 row_idx = fuzzy_match++; - const int16 col_idx = rowsol[row_idx]; + const split_int row_idx = fuzzy_match++; + const split_int col_idx = rowsol[row_idx]; final_matching[i] = (col_idx >= lap_dim - a_extra_splits) ? NA_INTEGER : lap_decode[col_idx]; } @@ -533,8 +544,8 @@ List mutual_clustering(const RawMatrix &x, const RawMatrix &y, _["matching"] = final_matching); } else { - for (int16 ai = a.n_splits; ai < most_splits; ++ai) { - for (int16 bi = 0; bi < most_splits; ++bi) { + for (split_int ai = a.n_splits; ai < most_splits; ++ai) { + for (split_int bi = 0; bi < most_splits; ++bi) { score(ai, bi) = max_score; } } @@ -545,7 +556,7 @@ List mutual_clustering(const RawMatrix &x, const RawMatrix &y, std::vector final_matching; final_matching.reserve(a.n_splits); - for (int16 i = 0; i < a.n_splits; ++i) { + for (split_int i = 0; i < a.n_splits; ++i) { const int match = (rowsol[i] < b.n_splits) ? static_cast(rowsol[i]) + 1 : NA_INTEGER; @@ -560,8 +571,8 @@ List mutual_clustering(const RawMatrix &x, const RawMatrix &y, inline List shared_phylo (const RawMatrix &x, const RawMatrix &y, const int32 n_tips) { const SplitList a(x), b(y); - const int16 most_splits = std::max(a.n_splits, b.n_splits); - const int16 overlap_a = int16(n_tips + 1) / 2; // avoids promotion to int + const split_int most_splits = std::max(a.n_splits, b.n_splits); + const split_int overlap_a = split_int(n_tips + 1) / 2; // avoids promotion to int constexpr cost max_score = BIG; const double lg2_unrooted_n = lg2_unrooted[n_tips]; @@ -574,8 +585,9 @@ inline List shared_phylo (const RawMatrix &x, const RawMatrix &y, // In/out direction [i.e. 1/0 bit] is arbitrary. cost_matrix score(most_splits); - for (int16 ai = a.n_splits; ai--; ) { - for (int16 bi = b.n_splits; bi--; ) { + for (split_int ai = a.n_splits; ai--; ) { + if ((ai & 1023) == 0) Rcpp::checkUserInterrupt(); + for (split_int bi = b.n_splits; bi--; ) { const double spi_over = TreeDist::spi_overlap( a.state[ai], b.state[bi], n_tips, a.in_split[ai], b.in_split[bi], a.n_bins); @@ -601,7 +613,7 @@ inline List shared_phylo (const RawMatrix &x, const RawMatrix &y, std::vector final_matching; final_matching.reserve(a.n_splits); - for (int16 i = 0; i < a.n_splits; ++i) { + for (split_int i = 0; i < a.n_splits; ++i) { const int match = (rowsol[i] < b.n_splits) ? static_cast(rowsol[i]) + 1 : NA_INTEGER; diff --git a/src/tree_distances.h b/src/tree_distances.h index 8d4f7b4f..7e59a061 100644 --- a/src/tree_distances.h +++ b/src/tree_distances.h @@ -19,17 +19,20 @@ constexpr splitbit ALL_ONES = (std::numeric_limits::max)(); namespace TreeDist { // Re-exported from mutual_clustering.h: - // ic_matching(int16 a, int16 b, int16 n) + // ic_matching(split_int a, split_int b, split_int n) + + void check_ntip(const int32 n); // See equation 16 in Meila 2007 (k' denoted K). // nkK is converted to pkK in the calling function when divided by n. - inline void add_ic_element(double& ic_sum, const int16 nkK, const int16 nk, - const int16 nK, const int16 n_tips, + inline void add_ic_element(double& ic_sum, const split_int nkK, + const split_int nk, const split_int nK, + const split_int n_tips, const double lg2_n) noexcept { if (nkK && nk && nK) { assert(!(nkK == nk && nkK == nK && nkK << 1 == n_tips)); - const int32 numerator = nkK * n_tips; - const int32 denominator = nk * nK; + const int64_t numerator = static_cast(nkK) * n_tips; + const int64_t denominator = static_cast(nk) * nK; if (numerator != denominator) { ic_sum += nkK * (lg2[nkK] + lg2_n - lg2[nk] - lg2[nK]); } @@ -38,14 +41,15 @@ namespace TreeDist { // Returns lg2_unrooted[x] - lg2_trees_matching_split(y, x - y) - [[nodiscard]] inline double mmsi_pair_score(const int16 x, const int16 y) noexcept { - assert(SL_MAX_TIPS + 2 <= std::numeric_limits::max()); // verify int16 ok - + [[nodiscard]] inline double mmsi_pair_score(const split_int x, + const split_int y) noexcept { return lg2_unrooted[x] - (lg2_rooted[y] + lg2_rooted[x - y]); } - [[nodiscard]] inline double mmsi_score(const int16 n_same, const int16 n_a_and_b, - const int16 n_different, const int16 n_a_only) noexcept { + [[nodiscard]] inline double mmsi_score(const split_int n_same, + const split_int n_a_and_b, + const split_int n_different, + const split_int n_a_only) noexcept { if (n_same == 0 || n_same == n_a_and_b) return mmsi_pair_score(n_different, n_a_only); if (n_different == 0 || n_different == n_a_only) @@ -59,20 +63,21 @@ namespace TreeDist { } -[[nodiscard]] inline double one_overlap(const int16 a, const int16 b, const int16 n) noexcept { - assert(SL_MAX_TIPS + 2 <= std::numeric_limits::max()); // verify int16 ok +[[nodiscard]] inline double one_overlap(const split_int a, const split_int b, + const split_int n) noexcept { if (a == b) { return lg2_rooted[a] + lg2_rooted[n - a]; } // Unify ab via lo/hi: removes an unpredictable branch. - const int16 lo = (a < b) ? a : b; - const int16 hi = (a < b) ? b : a; + const split_int lo = (a < b) ? a : b; + const split_int hi = (a < b) ? b : a; return lg2_rooted[hi] + lg2_rooted[n - lo] - lg2_rooted[hi - lo + 1]; } - [[nodiscard]] inline double one_overlap_notb(const int16 a, const int16 n_minus_b, const int16 n) noexcept { - assert(SL_MAX_TIPS + 2 <= std::numeric_limits::max()); // verify int16 ok - const int16 b = n - n_minus_b; + [[nodiscard]] inline double one_overlap_notb(const split_int a, + const split_int n_minus_b, + const split_int n) noexcept { + const split_int b = n - n_minus_b; if (a == b) { return lg2_rooted[b] + lg2_rooted[n_minus_b]; } else if (a < b) { @@ -83,17 +88,15 @@ namespace TreeDist { } -// Popcount-based: single pass over bins replaces 4 sequential boolean scans. + // Popcount-based: single pass over bins replaces 4 sequential boolean scans. // Counts n_ab = |A ∩ B| via hardware POPCNT, then derives all 4 Venn-diagram // region populations from arithmetic on n_ab, in_a, in_b, n_tips. [[nodiscard]] inline double spi_overlap(const splitbit* a_state, const splitbit* b_state, - const int16 n_tips, const int16 in_a, - const int16 in_b, const int16 n_bins) noexcept { - - assert(SL_MAX_BINS <= INT16_MAX); + const split_int n_tips, const split_int in_a, + const split_int in_b, const split_int n_bins) noexcept { - int16 n_ab = 0; - for (int16 bin = 0; bin < n_bins; ++bin) { + split_int n_ab = 0; + for (split_int bin = 0; bin < n_bins; ++bin) { n_ab += TreeTools::count_bits(a_state[bin] & b_state[bin]); } diff --git a/tests/testthat/test-large-trees.R b/tests/testthat/test-large-trees.R new file mode 100644 index 00000000..01cb7f8e --- /dev/null +++ b/tests/testthat/test-large-trees.R @@ -0,0 +1,90 @@ +# Tests for large-tree support (> 2048 tips). +# +# These tests are guarded by skip_if(.SL_MAX_TIPS < required_tips) so they +# run only after TreeTools raises SL_MAX_TIPS and TreeDist is rebuilt. + +test_that("R-level guard rejects trees exceeding .SL_MAX_TIPS", { + skip_on_cran() + too_many <- .SL_MAX_TIPS + 1L + t1 <- TreeTools::RandomTree(too_many) + t2 <- TreeTools::RandomTree(too_many) + + expect_error(ClusteringInfoDistance(t1, t2), "not yet supported") + expect_error(PhylogeneticInfoDistance(t1, t2), "not yet supported") + expect_error(MatchingSplitDistance(t1, t2), "not yet supported") + expect_error(MatchingSplitInfoDistance(t1, t2), "not yet supported") + expect_error(InfoRobinsonFoulds(t1, t2), "not yet supported") + expect_error(NyeSimilarity(t1, t2), "not yet supported") +}) + +test_that("Batch path rejects trees exceeding .SL_MAX_TIPS", { + skip_on_cran() + too_many <- .SL_MAX_TIPS + 1L + trees <- lapply(1:3, function(i) TreeTools::RandomTree(too_many)) + class(trees) <- "multiPhylo" + + expect_error(ClusteringInfoDistance(trees), "not yet supported") + expect_error(PhylogeneticInfoDistance(trees), "not yet supported") +}) + +test_that("CID works for 4000-tip trees", { + skip_on_cran() + skip_if(.SL_MAX_TIPS < 4000L, "SL_MAX_TIPS not yet raised to 4000+") + + set.seed(7042) + t1 <- TreeTools::RandomTree(4000) + t2 <- TreeTools::RandomTree(4000) + + cid <- ClusteringInfoDistance(t1, t2) + expect_type(cid, "double") + expect_true(is.finite(cid)) + expect_gte(cid, 0) +}) + +test_that("Multiple metrics agree on identical 4000-tip trees", { + skip_on_cran() + skip_if(.SL_MAX_TIPS < 4000L, "SL_MAX_TIPS not yet raised to 4000+") + + set.seed(3891) + t1 <- TreeTools::RandomTree(4000) + + expect_equal(ClusteringInfoDistance(t1, t1), 0) + expect_equal(MatchingSplitDistance(t1, t1), 0) + expect_equal(RobinsonFoulds(t1, t1), 0) + expect_equal(InfoRobinsonFoulds(t1, t1), 0) +}) + +test_that("Batch CID works for 4000-tip trees", { + skip_on_cran() + skip_if(.SL_MAX_TIPS < 4000L, "SL_MAX_TIPS not yet raised to 4000+") + + set.seed(5283) + trees <- lapply(1:5, function(i) TreeTools::RandomTree(4000)) + class(trees) <- "multiPhylo" + + d <- ClusteringInfoDistance(trees) + expect_s3_class(d, "dist") + expect_equal(attr(d, "Size"), 5L) + expect_true(all(is.finite(d))) + expect_true(all(d >= 0)) +}) + +test_that("RF and IRF work for 8000-tip trees", { + skip_on_cran() + skip_if(.SL_MAX_TIPS < 8000L, "SL_MAX_TIPS not yet raised to 8000+") + + set.seed(6174) + t1 <- TreeTools::RandomTree(8000) + t2 <- TreeTools::RandomTree(8000) + + rf <- RobinsonFoulds(t1, t2) + expect_type(rf, "double") + expect_true(is.finite(rf)) + expect_gte(rf, 0) + expect_equal(RobinsonFoulds(t1, t1), 0) + + irf <- InfoRobinsonFoulds(t1, t2) + expect_type(irf, "double") + expect_true(is.finite(irf)) + expect_gte(irf, 0) +}) diff --git a/tests/testthat/test-split_info.R b/tests/testthat/test-split_info.R index a31dc695..887c0f61 100644 --- a/tests/testthat/test-split_info.R +++ b/tests/testthat/test-split_info.R @@ -23,11 +23,11 @@ test_that("Split info calculated", { expect_error(consensus_info(trees, TRUE, -7)) expect_error( consensus_info(list(rtree(20000), rtree(20000)), TRUE, p = 1), - "This many leaves are not yet supported" + "not yet supported for consensus info" ) expect_error( consensus_info(list(rtree(20000), rtree(20000)), FALSE, p = 1), - "This many leaves are not yet supported" + "not yet supported for consensus info" ) diff --git a/tests/testthat/test-tree_distance_nni.R b/tests/testthat/test-tree_distance_nni.R index 20bcda23..413eb206 100644 --- a/tests/testthat/test-tree_distance_nni.R +++ b/tests/testthat/test-tree_distance_nni.R @@ -9,27 +9,31 @@ test_that("NNIDist() handles exceptions", { PectinateTree(as.character(1:8)))), "trees must bear identical labels") # R-level guard catches too-many-tips - expect_error(NNIDist(PectinateTree(40000), BalancedTree(40000)), "so many tips") + expect_error(NNIDist(PectinateTree(40000), BalancedTree(40000)), + "not yet supported for NNI") expect_error(NNIDist(BalancedTree(5), RootOnNode(BalancedTree(5), 1))) }) -test_that("NNIDist() at NNI_MAX_TIPS", { - maxTips <- 32768 - more <- maxTips + 1 +test_that("NNIDist() at max tips", { + maxTips <- .SL_MAX_TIPS + more <- maxTips + 1L expect_error(.NNIDistSingle(PectinateTree(more), BalancedTree(more), more), - "so many tips") + "not yet supported for NNI") goingQuickly <- TRUE skip_if(goingQuickly) heapTips <- 16384 + 1 + skip_if(maxTips < heapTips) skip_if_not_installed("testthat", "3.2.2") expect_no_error(.NNIDistSingle(PectinateTree(heapTips), BalancedTree(heapTips), heapTips)) + skip_if(maxTips < 32768L) + maxTips <- 32768L n <- .NNIDistSingle(PectinateTree(maxTips), BalancedTree(maxTips), - maxTips) + maxTips) expect_gt(n[["best_upper"]], n[["best_lower"]]) if (!is.na(n[["tight_upper"]])) { expect_gte(n[["tight_upper"]], n[["best_upper"]]) diff --git a/tests/testthat/test-tree_distance_utilities.R b/tests/testthat/test-tree_distance_utilities.R index 61f513d8..32315501 100644 --- a/tests/testthat/test-tree_distance_utilities.R +++ b/tests/testthat/test-tree_distance_utilities.R @@ -32,23 +32,27 @@ test_that("CalculateTreeDistance() errs appropriately", { }) test_that("Tip-count guard is applied consistently", { - expect_no_error(.AssertNtipSupported(1000L)) - expect_no_error(.AssertNtipSupported(32766L)) - expect_no_error(.AssertNtipSupported(32767L)) - expect_error(.AssertNtipSupported(32768L), - "This many tips are not \\(yet\\) supported\\.") + expect_true(is.numeric(.SL_MAX_TIPS)) + expect_gt(.SL_MAX_TIPS, 0) + + expect_no_error(.CheckMaxTips(min(1000L, .SL_MAX_TIPS))) + expect_no_error(.CheckMaxTips(.SL_MAX_TIPS)) + + overLimit <- .SL_MAX_TIPS + 1L + expect_error(.CheckMaxTips(overLimit), + "Trees with > .* tips are not yet supported") splits8 <- unclass(as.Splits(BalancedTree(8))) - expect_error(cpp_robinson_foulds_distance(splits8, splits8, 32768L), - "This many tips are not \\(yet\\) supported\\.") - expect_error(cpp_robinson_foulds_info(splits8, splits8, 32768L), - "This many tips are not \\(yet\\) supported\\.") + expect_error(cpp_robinson_foulds_distance(splits8, splits8, overLimit), + "Requested nTip") + expect_error(cpp_robinson_foulds_info(splits8, splits8, overLimit), + "Requested nTip") trees <- list(BalancedTree(8), PectinateTree(8)) class(trees) <- "multiPhylo" expect_error( - .SplitDistanceAllPairs(RobinsonFouldsSplits, trees, letters[1:8], 32768L), - "This many tips are not \\(yet\\) supported\\." + .SplitDistanceAllPairs(RobinsonFouldsSplits, trees, letters[1:8], overLimit), + "Trees with > .* tips are not yet supported" ) }) From b9394eb7847c81b1bc9f7f1d9b4d18d4db32b99f Mon Sep 17 00:00:00 2001 From: R script <1695515+ms609@users.noreply.github.com> Date: Fri, 17 Apr 2026 09:26:41 +0100 Subject: [PATCH 02/11] Refine large-tree tests and guard evidence Use deterministic 4000-tip as.phylo near-neighbour trees for fast shortcut paths, remove redundant 8000-tip split coverage, and assert known RF expected values plus batch/pairwise agreement. Add a source-level regression test that verifies serial interrupt checks and batch n_tip guards remain wired in C++ paths. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- tests/testthat/test-large-trees.R | 77 +++++-------------- tests/testthat/test-tree_distance_utilities.R | 18 +++++ 2 files changed, 39 insertions(+), 56 deletions(-) diff --git a/tests/testthat/test-large-trees.R b/tests/testthat/test-large-trees.R index 01cb7f8e..4fa24a30 100644 --- a/tests/testthat/test-large-trees.R +++ b/tests/testthat/test-large-trees.R @@ -6,8 +6,8 @@ test_that("R-level guard rejects trees exceeding .SL_MAX_TIPS", { skip_on_cran() too_many <- .SL_MAX_TIPS + 1L - t1 <- TreeTools::RandomTree(too_many) - t2 <- TreeTools::RandomTree(too_many) + t1 <- as.phylo(0, too_many) + t2 <- as.phylo(1, too_many) expect_error(ClusteringInfoDistance(t1, t2), "not yet supported") expect_error(PhylogeneticInfoDistance(t1, t2), "not yet supported") @@ -20,71 +20,36 @@ test_that("R-level guard rejects trees exceeding .SL_MAX_TIPS", { test_that("Batch path rejects trees exceeding .SL_MAX_TIPS", { skip_on_cran() too_many <- .SL_MAX_TIPS + 1L - trees <- lapply(1:3, function(i) TreeTools::RandomTree(too_many)) + trees <- as.phylo(0:2, too_many) class(trees) <- "multiPhylo" expect_error(ClusteringInfoDistance(trees), "not yet supported") expect_error(PhylogeneticInfoDistance(trees), "not yet supported") }) -test_that("CID works for 4000-tip trees", { +test_that("Known-answer large-tree near-neighbours (4000 tips)", { skip_on_cran() skip_if(.SL_MAX_TIPS < 4000L, "SL_MAX_TIPS not yet raised to 4000+") - set.seed(7042) - t1 <- TreeTools::RandomTree(4000) - t2 <- TreeTools::RandomTree(4000) - - cid <- ClusteringInfoDistance(t1, t2) - expect_type(cid, "double") - expect_true(is.finite(cid)) - expect_gte(cid, 0) -}) - -test_that("Multiple metrics agree on identical 4000-tip trees", { - skip_on_cran() - skip_if(.SL_MAX_TIPS < 4000L, "SL_MAX_TIPS not yet raised to 4000+") - - set.seed(3891) - t1 <- TreeTools::RandomTree(4000) - - expect_equal(ClusteringInfoDistance(t1, t1), 0) - expect_equal(MatchingSplitDistance(t1, t1), 0) - expect_equal(RobinsonFoulds(t1, t1), 0) - expect_equal(InfoRobinsonFoulds(t1, t1), 0) -}) - -test_that("Batch CID works for 4000-tip trees", { - skip_on_cran() - skip_if(.SL_MAX_TIPS < 4000L, "SL_MAX_TIPS not yet raised to 4000+") - - set.seed(5283) - trees <- lapply(1:5, function(i) TreeTools::RandomTree(4000)) - class(trees) <- "multiPhylo" - - d <- ClusteringInfoDistance(trees) - expect_s3_class(d, "dist") - expect_equal(attr(d, "Size"), 5L) - expect_true(all(is.finite(d))) - expect_true(all(d >= 0)) -}) - -test_that("RF and IRF work for 8000-tip trees", { - skip_on_cran() - skip_if(.SL_MAX_TIPS < 8000L, "SL_MAX_TIPS not yet raised to 8000+") - - set.seed(6174) - t1 <- TreeTools::RandomTree(8000) - t2 <- TreeTools::RandomTree(8000) + # Similar deterministic trees exercise shortcut paths and run quickly. + t1 <- as.phylo(0, 4000) + t2 <- as.phylo(1, 4000) + trees <- structure(list(t1, t2), class = "multiPhylo") + # Known answer for adjacent `as.phylo()` trees: one non-shared split per tree. rf <- RobinsonFoulds(t1, t2) - expect_type(rf, "double") - expect_true(is.finite(rf)) - expect_gte(rf, 0) - expect_equal(RobinsonFoulds(t1, t1), 0) + expect_equal(rf, 2) + expect_equal(as.matrix(RobinsonFoulds(trees))[2, 1], 2) + # Other large-tree metrics should be finite and non-negative. + cid <- ClusteringInfoDistance(t1, t2) + msd <- MatchingSplitDistance(t1, t2) irf <- InfoRobinsonFoulds(t1, t2) - expect_type(irf, "double") - expect_true(is.finite(irf)) - expect_gte(irf, 0) + expect_true(all(is.finite(c(cid, msd, irf)))) + expect_true(all(c(cid, msd, irf) >= 0)) + + # Batch and pairwise paths must agree. + expect_equal(as.matrix(ClusteringInfoDistance(trees))[2, 1], cid, tolerance = 1e-10) + expect_equal(as.matrix(MatchingSplitDistance(trees))[2, 1], msd, tolerance = 1e-10) + expect_equal(as.matrix(InfoRobinsonFoulds(trees))[2, 1], irf, tolerance = 1e-10) }) diff --git a/tests/testthat/test-tree_distance_utilities.R b/tests/testthat/test-tree_distance_utilities.R index 32315501..330ce707 100644 --- a/tests/testthat/test-tree_distance_utilities.R +++ b/tests/testthat/test-tree_distance_utilities.R @@ -56,6 +56,24 @@ test_that("Tip-count guard is applied consistently", { ) }) +test_that("Interrupt and tip-limit guards are wired in C++ distance paths", { + serial_src <- readLines( + testthat::test_path("..", "..", "src", "tree_distances.cpp"), + warn = FALSE + ) + interrupt_lines <- grep("checkUserInterrupt\\(", serial_src, value = TRUE) + throttled_lines <- grep("\\(ai & 1023\\) == 0", serial_src, value = TRUE) + expect_gte(length(interrupt_lines), 7L) + expect_gte(length(throttled_lines), 7L) + + batch_src <- readLines( + testthat::test_path("..", "..", "src", "pairwise_distances.cpp"), + warn = FALSE + ) + batch_guard_lines <- grep("TreeDist::check_ntip\\(", batch_src, value = TRUE) + expect_gte(length(batch_guard_lines), 14L) +}) + test_that("CalculateTreeDistance() handles splits appropriately", { set.seed(101) tree10 <- ape::rtree(10) From 310f13f3674e71a4e1afbe8deb0079cf8cc91df1 Mon Sep 17 00:00:00 2001 From: R script <1695515+ms609@users.noreply.github.com> Date: Fri, 17 Apr 2026 09:48:05 +0100 Subject: [PATCH 03/11] Skip source-wiring guard test in installed-package checks Guard the C++ source-inspection test so it runs in source-tree test sessions but skips in installed-package testthat runs where src/*.cpp files are unavailable. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- tests/testthat/test-tree_distance_utilities.R | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/testthat/test-tree_distance_utilities.R b/tests/testthat/test-tree_distance_utilities.R index 330ce707..9376dc6a 100644 --- a/tests/testthat/test-tree_distance_utilities.R +++ b/tests/testthat/test-tree_distance_utilities.R @@ -57,19 +57,20 @@ test_that("Tip-count guard is applied consistently", { }) test_that("Interrupt and tip-limit guards are wired in C++ distance paths", { - serial_src <- readLines( - testthat::test_path("..", "..", "src", "tree_distances.cpp"), - warn = FALSE + serial_path <- testthat::test_path("..", "..", "src", "tree_distances.cpp") + batch_path <- testthat::test_path("..", "..", "src", "pairwise_distances.cpp") + skip_if_not( + file.exists(serial_path) && file.exists(batch_path), + "C++ source files unavailable in installed-package checks" ) + + serial_src <- readLines(serial_path, warn = FALSE) interrupt_lines <- grep("checkUserInterrupt\\(", serial_src, value = TRUE) throttled_lines <- grep("\\(ai & 1023\\) == 0", serial_src, value = TRUE) expect_gte(length(interrupt_lines), 7L) expect_gte(length(throttled_lines), 7L) - batch_src <- readLines( - testthat::test_path("..", "..", "src", "pairwise_distances.cpp"), - warn = FALSE - ) + batch_src <- readLines(batch_path, warn = FALSE) batch_guard_lines <- grep("TreeDist::check_ntip\\(", batch_src, value = TRUE) expect_gte(length(batch_guard_lines), 14L) }) From b8440e96f041fb7bbda97c122f92721736cac03d Mon Sep 17 00:00:00 2001 From: R script <1695515+ms609@users.noreply.github.com> Date: Fri, 17 Apr 2026 10:58:58 +0100 Subject: [PATCH 04/11] Stop treating SL_MAX_TIPS as a hard cap Remove the .SL_MAX_TIPS global assignment and make tip-limit checks reflect real constraints. Use TreeTools SL_MAX_TIPS as a cache/stack threshold only: add log lookup fallbacks for larger trees, update MCI/SPI/MSI/CID kernels to call safe lookup helpers, and keep runtime guards based on integer/type limits. Add explicit runtime NNI guard at 32768 tips, keep algorithm-specific caps for MAST and consensus-info, and update large-tree/guard tests to validate fast 4000-tip known-answer behaviour and new guard semantics. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- R/tree_distance_mast.R | 2 +- R/tree_distance_utilities.R | 43 ++++++++--------- R/tree_information.R | 2 +- R/zzz.R | 4 -- inst/include/TreeDist/mutual_clustering.h | 45 +++++++++++++++-- .../include/TreeDist/mutual_clustering_impl.h | 18 +++---- src/nni_distance.cpp | 10 +++- src/pairwise_distances.cpp | 33 ++++++------- src/tree_distance_functions.cpp | 9 +++- src/tree_distances.cpp | 48 ++++++++++--------- src/tree_distances.h | 19 +++++--- tests/testthat/test-large-trees.R | 39 +++------------ tests/testthat/test-tree_distance_nni.R | 2 +- tests/testthat/test-tree_distance_utilities.R | 28 +++++------ 14 files changed, 163 insertions(+), 139 deletions(-) diff --git a/R/tree_distance_mast.R b/R/tree_distance_mast.R index beee13d1..eae741ff 100644 --- a/R/tree_distance_mast.R +++ b/R/tree_distance_mast.R @@ -96,7 +96,7 @@ MASTSize <- function(tree1, tree2 = tree1, rooted = TRUE) { if (nrow(edge1) != nrow(edge2)) { stop("Both trees must contain the same number of edges.") } - maxTips <- min(4096L, if (is.null(.SL_MAX_TIPS)) cpp_max_tips() else .SL_MAX_TIPS) + maxTips <- min(4096L, cpp_max_tips()) if (nTip > maxTips) { stop("Trees with > ", maxTips, " tips are not yet supported for MAST.") } diff --git a/R/tree_distance_utilities.R b/R/tree_distance_utilities.R index 1bae051e..3d02258e 100644 --- a/R/tree_distance_utilities.R +++ b/R/tree_distance_utilities.R @@ -1,3 +1,22 @@ +.CheckMaxTips <- function(nTip, context = "") { + if (is.na(nTip)) { + return(invisible(NULL)) + } + + # Global limit from C++ integer types (not TreeTools stack thresholds). + maxTips <- cpp_max_tips() + if (nTip > maxTips) { + suffix <- if (!nzchar(context)) "." else paste0(" for ", context, ".") + stop("Trees with > ", maxTips, " tips are not yet supported", suffix) + } + + # NNI uses fixed-size lookup tables in C++. + if (identical(context, "NNI") && nTip > 32768L) { + stop("Trees with > 32768 tips are not yet supported for NNI.") + } + invisible(NULL) +} + #' Wrapper for tree distance calculations #' #' Calls tree distance functions from trees or lists of trees @@ -11,28 +30,6 @@ #' @importFrom TreeTools as.Splits TipLabels #' @importFrom utils combn #' @export -# Maximum number of tips supported by this compiled package build. -# Set during .onLoad() from `cpp_max_tips()`. -.SL_MAX_TIPS <- NULL - -.CheckMaxTips <- function(nTip, context = "") { - if (!is.na(nTip)) { - maxTips <- .SL_MAX_TIPS - if (is.null(maxTips) || is.na(maxTips)) { - maxTips <- cpp_max_tips() - .SL_MAX_TIPS <<- maxTips - } - if (nTip > maxTips) { - suffix <- if (!nzchar(context)) "." else paste0(" for ", context, ".") - stop("Trees with > ", maxTips, " tips are not yet supported", suffix) - } - } - invisible(NULL) -} - -# Backward-compatible alias for internal callers/tests. -.AssertNtipSupported <- .CheckMaxTips - CalculateTreeDistance <- function(Func, tree1, tree2 = NULL, reportMatching = FALSE, ...) { supportedClasses <- c("phylo", "Splits") @@ -344,7 +341,7 @@ CalculateTreeDistance <- function(Func, tree1, tree2 = NULL, #' @param checks Logical specifying whether to perform basic sanity checks to #' avoid crashes in C++. #' @keywords internal -#' @seealso [`CalculateTreeDistance`] +#' @seealso [`CalculateTreeDistance()`] #' @export .TreeDistance <- function(Func, tree1, tree2, checks = TRUE, ...) { single1 <- inherits(tree1, "phylo") diff --git a/R/tree_information.R b/R/tree_information.R index f1d44d97..81672e50 100644 --- a/R/tree_information.R +++ b/R/tree_information.R @@ -391,7 +391,7 @@ consensus_info <- function(trees, phylo, p) { } nTip <- NTip(trees[[1]]) # CT_MAX_LEAVES = 16383 in information.h (lookup-table size limit). - maxTips <- min(16383L, if (is.null(.SL_MAX_TIPS)) cpp_max_tips() else .SL_MAX_TIPS) + maxTips <- min(16383L, cpp_max_tips()) if (nTip > maxTips) { stop("Trees with > ", maxTips, " tips are not yet supported for consensus info.") } diff --git a/R/zzz.R b/R/zzz.R index a8e10a2c..826528ea 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -1,7 +1,3 @@ -.onLoad <- function(libname, pkgname) { - .SL_MAX_TIPS <<- cpp_max_tips() -} - .onUnload <- function(libpath) { StopParallel(quietly = TRUE) library.dynam.unload("TreeDist", libpath) diff --git a/inst/include/TreeDist/mutual_clustering.h b/inst/include/TreeDist/mutual_clustering.h index 7d404ed0..99e9389d 100644 --- a/inst/include/TreeDist/mutual_clustering.h +++ b/inst/include/TreeDist/mutual_clustering.h @@ -15,6 +15,8 @@ namespace TreeDist { + constexpr double LOG2_E = 1.4426950408889634; + // ---- Lookup tables (populated by init_lg2_tables) ---- // // lg2[i] = log2(i) for 0 <= i <= SL_MAX_TIPS @@ -22,7 +24,9 @@ namespace TreeDist { // lg2_unrooted[i] = log2((2i-5)!!) for i >= 3 // lg2_rooted = &lg2_unrooted[0] + 1 (so lg2_rooted[i] = lg2_unrooted[i+1]) // - // These are defined in mutual_clustering_impl.h. + // These are fast-path caches sized to TreeTools' stack threshold. + // For larger trees we fall back to on-the-fly computation. + // Definitions are in mutual_clustering_impl.h. extern double lg2[SL_MAX_TIPS + 1]; extern double lg2_double_factorial[SL_MAX_TIPS + SL_MAX_TIPS - 2]; @@ -33,15 +37,48 @@ namespace TreeDist { // computation. max_tips should be >= the largest tree size used. void init_lg2_tables(int max_tips); + [[nodiscard]] inline double lg2_lookup(split_int x) noexcept { + if (x <= static_cast(SL_MAX_TIPS)) { + return lg2[x]; + } + return std::log2(static_cast(x)); + } + + [[nodiscard]] inline double lg2_unrooted_lookup(split_int n_tips) noexcept { + if (n_tips <= static_cast(SL_MAX_TIPS + 1)) { + return lg2_unrooted[n_tips]; + } + if (n_tips < 3) { + return 0.0; + } + const double n = static_cast(n_tips); + // log2((2n - 5)!!) = log2((2n - 4)!) - (n - 2) - log2((n - 2)!) + return (std::lgamma((2.0 * n) - 3.0) - std::lgamma(n - 1.0)) * LOG2_E + - (n - 2.0); + } + + [[nodiscard]] inline double lg2_rooted_lookup(split_int n_tips) noexcept { + if (n_tips <= static_cast(SL_MAX_TIPS + 1)) { + return lg2_rooted[n_tips]; + } + if (n_tips < 2) { + return 0.0; + } + const double n = static_cast(n_tips); + // log2((2n - 3)!!) = log2((2n - 2)!) - (n - 1) - log2((n - 1)!) + return (std::lgamma((2.0 * n) - 1.0) - std::lgamma(n)) * LOG2_E + - (n - 1.0); + } + // ---- Inline helpers ---- // Information content of a perfectly-matching split pair. // ic_matching(a, b, n) = (a + b) * lg2[n] - a * lg2[a] - b * lg2[b] [[nodiscard]] inline double ic_matching(split_int a, split_int b, split_int n) noexcept { - const double lg2a = lg2[a]; - const double lg2b = lg2[b]; - const double lg2n = lg2[n]; + const double lg2a = lg2_lookup(a); + const double lg2b = lg2_lookup(b); + const double lg2n = lg2_lookup(n); return (a + b) * lg2n - a * lg2a - b * lg2b; } diff --git a/inst/include/TreeDist/mutual_clustering_impl.h b/inst/include/TreeDist/mutual_clustering_impl.h index 286b55a3..884a18b7 100644 --- a/inst/include/TreeDist/mutual_clustering_impl.h +++ b/inst/include/TreeDist/mutual_clustering_impl.h @@ -178,7 +178,7 @@ double mutual_clustering_score( constexpr cost max_score = BIG; constexpr double over_max = 1.0 / static_cast(BIG); const double max_over_tips = static_cast(BIG) * n_tips_rcp; - const double lg2_n = lg2[n_tips]; + const double lg2_n = lg2_lookup(static_cast(n_tips)); // --- Phase 1: O(n log n) exact-match detection --- std::vector a_match_buf(a_n_splits); @@ -233,8 +233,8 @@ double mutual_clustering_score( const split_int nA = static_cast(n_tips - na); const splitbit* a_row = a_state[ai]; - const double offset_a = lg2_n - lg2[na]; - const double offset_A = lg2_n - lg2[nA]; + const double offset_a = lg2_n - lg2_lookup(na); + const double offset_A = lg2_n - lg2_lookup(nA); for (split_int b_pos = 0; b_pos < b_unmatched_n; ++b_pos) { const split_int bi = b_unmatch[b_pos]; @@ -254,13 +254,13 @@ double mutual_clustering_score( && a_and_b == A_and_B) { score(a_pos, b_pos) = max_score; } else { - const double lg2_nb = lg2[nb]; - const double lg2_nB = lg2[nB]; + const double lg2_nb = lg2_lookup(nb); + const double lg2_nB = lg2_lookup(nB); const double ic_sum = - a_and_b * (lg2[a_and_b] + offset_a - lg2_nb) + - a_and_B * (lg2[a_and_B] + offset_a - lg2_nB) + - A_and_b * (lg2[A_and_b] + offset_A - lg2_nb) + - A_and_B * (lg2[A_and_B] + offset_A - lg2_nB); + a_and_b * (lg2_lookup(a_and_b) + offset_a - lg2_nb) + + a_and_B * (lg2_lookup(a_and_B) + offset_a - lg2_nB) + + A_and_b * (lg2_lookup(A_and_b) + offset_A - lg2_nb) + + A_and_B * (lg2_lookup(A_and_B) + offset_A - lg2_nB); score(a_pos, b_pos) = max_score - static_cast(ic_sum * max_over_tips); } diff --git a/src/nni_distance.cpp b/src/nni_distance.cpp index db3a5b3c..8712670e 100644 --- a/src/nni_distance.cpp +++ b/src/nni_distance.cpp @@ -288,8 +288,14 @@ grf_match nni_rf_matching ( IntegerVector cpp_nni_distance(const IntegerMatrix& edge1, const IntegerMatrix& edge2, const IntegerVector& nTip) { - - ASSERT(nTip[0] <= NNI_MAX_TIPS && "Cannot calculate NNI distance for trees with so many tips."); + + if (nTip[0] > NNI_MAX_TIPS) { + Rcpp::stop("Trees with > %d tips are not yet supported for NNI.", + NNI_MAX_TIPS); + } + if (nTip[0] < 0) { + Rcpp::stop("Requested nTip = %d is invalid for NNI.", nTip[0]); + } const int32_t n_tip = static_cast(nTip[0]); const int32_t node_0 = n_tip; const int32_t node_0_r = n_tip + 1; diff --git a/src/pairwise_distances.cpp b/src/pairwise_distances.cpp index 3bde0444..16d40936 100644 --- a/src/pairwise_distances.cpp +++ b/src/pairwise_distances.cpp @@ -184,7 +184,7 @@ static double mutual_clustering_score( constexpr cost max_score = BIG; constexpr double over_max = 1.0 / static_cast(BIG); const double max_over_tips = static_cast(BIG) * n_tips_rcp; - const double lg2_n = lg2[n_tips]; + const double lg2_n = TreeDist::lg2_lookup(n_tips); // --- Phase 1: O(n log n) exact-match detection --- const split_int exact_n = find_exact_matches(a, b, n_tips, mscratch); @@ -232,8 +232,8 @@ static double mutual_clustering_score( const split_int nA = n_tips - na; const auto* a_row = a.state[ai]; - const double offset_a = lg2_n - lg2[na]; - const double offset_A = lg2_n - lg2[nA]; + const double offset_a = lg2_n - TreeDist::lg2_lookup(na); + const double offset_A = lg2_n - TreeDist::lg2_lookup(nA); for (split_int b_pos = 0; b_pos < b_unmatched_n; ++b_pos) { const split_int bi = b_unmatch[b_pos]; @@ -252,13 +252,13 @@ static double mutual_clustering_score( if (a_and_b == A_and_b && a_and_b == a_and_B && a_and_b == A_and_B) { score(a_pos, b_pos) = max_score; } else { - const double lg2_nb = lg2[nb]; - const double lg2_nB = lg2[nB]; + const double lg2_nb = TreeDist::lg2_lookup(nb); + const double lg2_nB = TreeDist::lg2_lookup(nB); const double ic_sum = - a_and_b * (lg2[a_and_b] + offset_a - lg2_nb) + - a_and_B * (lg2[a_and_B] + offset_a - lg2_nB) + - A_and_b * (lg2[A_and_b] + offset_A - lg2_nb) + - A_and_B * (lg2[A_and_B] + offset_A - lg2_nB); + a_and_b * (TreeDist::lg2_lookup(a_and_b) + offset_a - lg2_nb) + + a_and_B * (TreeDist::lg2_lookup(a_and_B) + offset_a - lg2_nB) + + A_and_b * (TreeDist::lg2_lookup(A_and_b) + offset_A - lg2_nb) + + A_and_B * (TreeDist::lg2_lookup(A_and_B) + offset_A - lg2_nB); score(a_pos, b_pos) = max_score - static_cast(ic_sum * max_over_tips); } } @@ -376,7 +376,7 @@ static double rf_info_score( // Sum info contribution for each matched split in a const split_int* a_match = mscratch.a_match.data(); - const double lg2_unrooted_n = lg2_unrooted[n_tips]; + const double lg2_unrooted_n = TreeDist::lg2_unrooted_lookup(n_tips); double score = 0; for (split_int ai = 0; ai < a_n; ++ai) { if (a_match[ai] == 0) continue; @@ -385,8 +385,8 @@ static double rf_info_score( leaves_in_split += count_bits(a.state[ai][bin]); } score += lg2_unrooted_n - - lg2_rooted[leaves_in_split] - - lg2_rooted[n_tips - leaves_in_split]; + - TreeDist::lg2_rooted_lookup(leaves_in_split) + - TreeDist::lg2_rooted_lookup(n_tips - leaves_in_split); } return score; } @@ -578,9 +578,9 @@ static double msi_score( if (most_splits == 0) return 0.0; constexpr cost max_score = BIG; - const double max_possible = lg2_unrooted[n_tips] - - lg2_rooted[split_int((n_tips + 1) / 2)] - - lg2_rooted[split_int(n_tips / 2)]; + const double max_possible = TreeDist::lg2_unrooted_lookup(n_tips) + - TreeDist::lg2_rooted_lookup(split_int((n_tips + 1) / 2)) + - TreeDist::lg2_rooted_lookup(split_int(n_tips / 2)); const double score_over_possible = static_cast(max_score) / max_possible; const double possible_over_score = max_possible / static_cast(max_score); @@ -679,7 +679,8 @@ static double shared_phylo_score( const split_int overlap_a = split_int(n_tips + 1) / 2; constexpr cost max_score = BIG; const double best_overlap = TreeDist::one_overlap(overlap_a, n_tips / 2, n_tips); - const double max_possible = lg2_unrooted[n_tips] - best_overlap; + const double max_possible = TreeDist::lg2_unrooted_lookup(n_tips) - + best_overlap; const double score_over_possible = static_cast(max_score) / max_possible; const double possible_over_score = max_possible / static_cast(max_score); diff --git a/src/tree_distance_functions.cpp b/src/tree_distance_functions.cpp index 810cc1cb..7e5d892a 100644 --- a/src/tree_distance_functions.cpp +++ b/src/tree_distance_functions.cpp @@ -1,5 +1,7 @@ #include #include +#include +#include // Provide the MCI table definitions and implementation in this TU. #define TREEDIST_MCI_IMPLEMENTATION @@ -10,6 +12,7 @@ // Populate lookup tables at library load time. __attribute__((constructor)) void initialize_ldf() { + // Cache up to TreeTools' stack-threshold value. TreeDist::init_lg2_tables(SL_MAX_TIPS); } @@ -48,5 +51,9 @@ double cpp_mci_impl_score(const Rcpp::RawMatrix& x, // [[Rcpp::export]] int cpp_max_tips() { - return static_cast(SL_MAX_TIPS); + constexpr int64_t split_int_limit = + static_cast((std::numeric_limits::max)()); + constexpr int64_t int32_limit = + static_cast((std::numeric_limits::max)()); + return static_cast(std::min(split_int_limit, int32_limit)); } diff --git a/src/tree_distances.cpp b/src/tree_distances.cpp index 1ced9e26..58d0c3a5 100644 --- a/src/tree_distances.cpp +++ b/src/tree_distances.cpp @@ -30,19 +30,19 @@ namespace TreeDist { } void check_ntip(const int32 n) { - static_assert(SL_MAX_TIPS <= std::numeric_limits::max(), - "SL_MAX_TIPS must fit in int32"); - constexpr int32 compiled_tip_limit = static_cast(SL_MAX_TIPS); constexpr int64_t split_int_limit = static_cast((std::numeric_limits::max)()); - constexpr int64_t max_supported_tips = - std::min(compiled_tip_limit, split_int_limit); + constexpr int64_t max_supported_tips = split_int_limit; + + if (n < 0) { + Rcpp::stop("Requested nTip = %d is invalid.", n); + } if (n > max_supported_tips) { Rcpp::stop( - "Requested nTip = %d exceeds this TreeDist build limit (%d; " - "compiled SL_MAX_TIPS = %d): this many tips are not yet supported.", - n, static_cast(max_supported_tips), compiled_tip_limit + "Requested nTip = %d exceeds this TreeDist build limit (%d): " + "this many tips are not yet supported.", + n, static_cast(max_supported_tips) ); } } @@ -115,7 +115,7 @@ inline List robinson_foulds_info(const RawMatrix &x, const RawMatrix &y, SL_BIN_SIZE - n_tips % SL_BIN_SIZE : 0; const splitbit unset_mask = ALL_ONES >> unset_tips; - const double lg2_unrooted_n = lg2_unrooted[n_tips]; + const double lg2_unrooted_n = TreeDist::lg2_unrooted_lookup(n_tips); double score = 0; grf_match matching(a.n_splits, NA_INTEGER); @@ -155,8 +155,9 @@ inline List robinson_foulds_info(const RawMatrix &x, const RawMatrix &y, leaves_in_split += count_bits(a.state[ai][bin]); } - score += lg2_unrooted_n - lg2_rooted[leaves_in_split] - - lg2_rooted[n_tips - leaves_in_split]; + score += lg2_unrooted_n - + TreeDist::lg2_rooted_lookup(leaves_in_split) - + TreeDist::lg2_rooted_lookup(n_tips - leaves_in_split); matching[ai] = bi + 1; break; /* Only one match possible per split */ @@ -341,8 +342,9 @@ List msi_distance(const RawMatrix &x, const RawMatrix &y, const int32 n_tips) { const SplitList a(x), b(y); const split_int most_splits = std::max(a.n_splits, b.n_splits); constexpr cost max_score = BIG; - const double max_possible = lg2_unrooted[n_tips] - - lg2_rooted[split_int((n_tips + 1) / 2)] - lg2_rooted[split_int(n_tips / 2)]; + const double max_possible = TreeDist::lg2_unrooted_lookup(n_tips) - + TreeDist::lg2_rooted_lookup(split_int((n_tips + 1) / 2)) - + TreeDist::lg2_rooted_lookup(split_int(n_tips / 2)); const double score_over_possible = static_cast(max_score) / max_possible; const double possible_over_score = max_possible / max_score; @@ -417,7 +419,7 @@ List mutual_clustering(const RawMatrix &x, const RawMatrix &y, constexpr cost max_score = BIG; constexpr double over_max_score = 1.0 / static_cast(max_score); const double max_over_tips = static_cast(max_score) * n_tips_reciprocal; - const double lg2_n = lg2[n_tips]; + const double lg2_n = TreeDist::lg2_lookup(n_tips); cost_matrix score(most_splits); @@ -435,8 +437,8 @@ List mutual_clustering(const RawMatrix &x, const RawMatrix &y, const split_int na = a.in_split[ai]; const split_int nA = n_tips - na; const auto *a_row = a.state[ai]; - const double offset_a = lg2_n - lg2[na]; - const double offset_A = lg2_n - lg2[nA]; + const double offset_a = lg2_n - TreeDist::lg2_lookup(na); + const double offset_A = lg2_n - TreeDist::lg2_lookup(nA); for (split_int bi = 0; bi < b.n_splits; ++bi) { @@ -465,13 +467,13 @@ List mutual_clustering(const RawMatrix &x, const RawMatrix &y, && a_and_b == A_and_B) { score(ai, bi) = max_score; // Avoid rounding errors } else { - const double lg2_nb = lg2[nb]; - const double lg2_nB = lg2[nB]; + const double lg2_nb = TreeDist::lg2_lookup(nb); + const double lg2_nB = TreeDist::lg2_lookup(nB); const double ic_sum = - a_and_b * (lg2[a_and_b] + offset_a - lg2_nb) + - a_and_B * (lg2[a_and_B] + offset_a - lg2_nB) + - A_and_b * (lg2[A_and_b] + offset_A - lg2_nb) + - A_and_B * (lg2[A_and_B] + offset_A - lg2_nB); + a_and_b * (TreeDist::lg2_lookup(a_and_b) + offset_a - lg2_nb) + + a_and_B * (TreeDist::lg2_lookup(a_and_B) + offset_a - lg2_nB) + + A_and_b * (TreeDist::lg2_lookup(A_and_b) + offset_A - lg2_nb) + + A_and_B * (TreeDist::lg2_lookup(A_and_B) + offset_A - lg2_nB); // Division by n_tips converts n(A&B) to P(A&B) for each ic_element score(ai, bi) = max_score - static_cast(ic_sum * max_over_tips); @@ -575,7 +577,7 @@ inline List shared_phylo (const RawMatrix &x, const RawMatrix &y, const split_int overlap_a = split_int(n_tips + 1) / 2; // avoids promotion to int constexpr cost max_score = BIG; - const double lg2_unrooted_n = lg2_unrooted[n_tips]; + const double lg2_unrooted_n = TreeDist::lg2_unrooted_lookup(n_tips); const double best_overlap = TreeDist::one_overlap(overlap_a, n_tips / 2, n_tips); const double max_possible = lg2_unrooted_n - best_overlap; const double score_over_possible = max_score / max_possible; diff --git a/src/tree_distances.h b/src/tree_distances.h index 7e59a061..383a91d9 100644 --- a/src/tree_distances.h +++ b/src/tree_distances.h @@ -34,7 +34,8 @@ namespace TreeDist { const int64_t numerator = static_cast(nkK) * n_tips; const int64_t denominator = static_cast(nk) * nK; if (numerator != denominator) { - ic_sum += nkK * (lg2[nkK] + lg2_n - lg2[nk] - lg2[nK]); + ic_sum += nkK * (lg2_lookup(nkK) + lg2_n - lg2_lookup(nk) - + lg2_lookup(nK)); } } } @@ -43,7 +44,8 @@ namespace TreeDist { // Returns lg2_unrooted[x] - lg2_trees_matching_split(y, x - y) [[nodiscard]] inline double mmsi_pair_score(const split_int x, const split_int y) noexcept { - return lg2_unrooted[x] - (lg2_rooted[y] + lg2_rooted[x - y]); + return lg2_unrooted_lookup(x) - (lg2_rooted_lookup(y) + + lg2_rooted_lookup(x - y)); } [[nodiscard]] inline double mmsi_score(const split_int n_same, @@ -66,12 +68,13 @@ namespace TreeDist { [[nodiscard]] inline double one_overlap(const split_int a, const split_int b, const split_int n) noexcept { if (a == b) { - return lg2_rooted[a] + lg2_rooted[n - a]; + return lg2_rooted_lookup(a) + lg2_rooted_lookup(n - a); } // Unify ab via lo/hi: removes an unpredictable branch. const split_int lo = (a < b) ? a : b; const split_int hi = (a < b) ? b : a; - return lg2_rooted[hi] + lg2_rooted[n - lo] - lg2_rooted[hi - lo + 1]; + return lg2_rooted_lookup(hi) + lg2_rooted_lookup(n - lo) - + lg2_rooted_lookup(hi - lo + 1); } [[nodiscard]] inline double one_overlap_notb(const split_int a, @@ -79,11 +82,13 @@ namespace TreeDist { const split_int n) noexcept { const split_int b = n - n_minus_b; if (a == b) { - return lg2_rooted[b] + lg2_rooted[n_minus_b]; + return lg2_rooted_lookup(b) + lg2_rooted_lookup(n_minus_b); } else if (a < b) { - return lg2_rooted[b] + lg2_rooted[n - a] - lg2_rooted[b - a + 1]; + return lg2_rooted_lookup(b) + lg2_rooted_lookup(n - a) - + lg2_rooted_lookup(b - a + 1); } else { - return lg2_rooted[a] + lg2_rooted[n_minus_b] - lg2_rooted[a - b + 1]; + return lg2_rooted_lookup(a) + lg2_rooted_lookup(n_minus_b) - + lg2_rooted_lookup(a - b + 1); } } diff --git a/tests/testthat/test-large-trees.R b/tests/testthat/test-large-trees.R index 4fa24a30..e0d59f7e 100644 --- a/tests/testthat/test-large-trees.R +++ b/tests/testthat/test-large-trees.R @@ -1,35 +1,5 @@ -# Tests for large-tree support (> 2048 tips). -# -# These tests are guarded by skip_if(.SL_MAX_TIPS < required_tips) so they -# run only after TreeTools raises SL_MAX_TIPS and TreeDist is rebuilt. - -test_that("R-level guard rejects trees exceeding .SL_MAX_TIPS", { - skip_on_cran() - too_many <- .SL_MAX_TIPS + 1L - t1 <- as.phylo(0, too_many) - t2 <- as.phylo(1, too_many) - - expect_error(ClusteringInfoDistance(t1, t2), "not yet supported") - expect_error(PhylogeneticInfoDistance(t1, t2), "not yet supported") - expect_error(MatchingSplitDistance(t1, t2), "not yet supported") - expect_error(MatchingSplitInfoDistance(t1, t2), "not yet supported") - expect_error(InfoRobinsonFoulds(t1, t2), "not yet supported") - expect_error(NyeSimilarity(t1, t2), "not yet supported") -}) - -test_that("Batch path rejects trees exceeding .SL_MAX_TIPS", { - skip_on_cran() - too_many <- .SL_MAX_TIPS + 1L - trees <- as.phylo(0:2, too_many) - class(trees) <- "multiPhylo" - - expect_error(ClusteringInfoDistance(trees), "not yet supported") - expect_error(PhylogeneticInfoDistance(trees), "not yet supported") -}) - test_that("Known-answer large-tree near-neighbours (4000 tips)", { skip_on_cran() - skip_if(.SL_MAX_TIPS < 4000L, "SL_MAX_TIPS not yet raised to 4000+") # Similar deterministic trees exercise shortcut paths and run quickly. t1 <- as.phylo(0, 4000) @@ -49,7 +19,10 @@ test_that("Known-answer large-tree near-neighbours (4000 tips)", { expect_true(all(c(cid, msd, irf) >= 0)) # Batch and pairwise paths must agree. - expect_equal(as.matrix(ClusteringInfoDistance(trees))[2, 1], cid, tolerance = 1e-10) - expect_equal(as.matrix(MatchingSplitDistance(trees))[2, 1], msd, tolerance = 1e-10) - expect_equal(as.matrix(InfoRobinsonFoulds(trees))[2, 1], irf, tolerance = 1e-10) + expect_equal(unname(as.matrix(ClusteringInfoDistance(trees))[2, 1]), + unname(cid), tolerance = 1e-8) + expect_equal(unname(as.matrix(MatchingSplitDistance(trees))[2, 1]), + unname(msd), tolerance = 1e-8) + expect_equal(unname(as.matrix(InfoRobinsonFoulds(trees))[2, 1]), + unname(irf), tolerance = 1e-8) }) diff --git a/tests/testthat/test-tree_distance_nni.R b/tests/testthat/test-tree_distance_nni.R index 413eb206..46fd6361 100644 --- a/tests/testthat/test-tree_distance_nni.R +++ b/tests/testthat/test-tree_distance_nni.R @@ -17,7 +17,7 @@ test_that("NNIDist() handles exceptions", { }) test_that("NNIDist() at max tips", { - maxTips <- .SL_MAX_TIPS + maxTips <- 32768L more <- maxTips + 1L expect_error(.NNIDistSingle(PectinateTree(more), BalancedTree(more), more), "not yet supported for NNI") diff --git a/tests/testthat/test-tree_distance_utilities.R b/tests/testthat/test-tree_distance_utilities.R index 9376dc6a..b16c6296 100644 --- a/tests/testthat/test-tree_distance_utilities.R +++ b/tests/testthat/test-tree_distance_utilities.R @@ -32,27 +32,27 @@ test_that("CalculateTreeDistance() errs appropriately", { }) test_that("Tip-count guard is applied consistently", { - expect_true(is.numeric(.SL_MAX_TIPS)) - expect_gt(.SL_MAX_TIPS, 0) + maxTips <- cpp_max_tips() + expect_true(is.numeric(maxTips)) + expect_gt(maxTips, 0) - expect_no_error(.CheckMaxTips(min(1000L, .SL_MAX_TIPS))) - expect_no_error(.CheckMaxTips(.SL_MAX_TIPS)) - - overLimit <- .SL_MAX_TIPS + 1L - expect_error(.CheckMaxTips(overLimit), + expect_no_error(.CheckMaxTips(min(1000L, maxTips))) + expect_no_error(.CheckMaxTips(maxTips)) + expect_error(.CheckMaxTips(as.double(maxTips) + 1), "Trees with > .* tips are not yet supported") + expect_no_error(.CheckMaxTips(32768L, "NNI")) + expect_error(.CheckMaxTips(32769L, "NNI"), + "not yet supported for NNI") splits8 <- unclass(as.Splits(BalancedTree(8))) - expect_error(cpp_robinson_foulds_distance(splits8, splits8, overLimit), - "Requested nTip") - expect_error(cpp_robinson_foulds_info(splits8, splits8, overLimit), - "Requested nTip") + expect_no_error(cpp_robinson_foulds_distance(splits8, splits8, 8L)) + expect_no_error(cpp_robinson_foulds_info(splits8, splits8, 8L)) trees <- list(BalancedTree(8), PectinateTree(8)) class(trees) <- "multiPhylo" - expect_error( - .SplitDistanceAllPairs(RobinsonFouldsSplits, trees, letters[1:8], overLimit), - "Trees with > .* tips are not yet supported" + tipLabels <- TipLabels(trees[[1]]) + expect_no_error( + .SplitDistanceAllPairs(RobinsonFouldsSplits, trees, tipLabels, 8L) ) }) From 140500316183956c83dc1ba05897714caddcd5fb Mon Sep 17 00:00:00 2001 From: R script <1695515+ms609@users.noreply.github.com> Date: Fri, 17 Apr 2026 11:15:59 +0100 Subject: [PATCH 05/11] Improve large-tree test --- tests/testthat/test-large-trees.R | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/testthat/test-large-trees.R b/tests/testthat/test-large-trees.R index e0d59f7e..fbcf36cc 100644 --- a/tests/testthat/test-large-trees.R +++ b/tests/testthat/test-large-trees.R @@ -1,9 +1,8 @@ -test_that("Known-answer large-tree near-neighbours (4000 tips)", { - skip_on_cran() +test_that("Known-answer large-tree near-neighbours (>2048 tips)", { # Similar deterministic trees exercise shortcut paths and run quickly. - t1 <- as.phylo(0, 4000) - t2 <- as.phylo(1, 4000) + t1 <- as.phylo(0, 2050) + t2 <- as.phylo(1, 2050) trees <- structure(list(t1, t2), class = "multiPhylo") # Known answer for adjacent `as.phylo()` trees: one non-shared split per tree. @@ -13,11 +12,12 @@ test_that("Known-answer large-tree near-neighbours (4000 tips)", { # Other large-tree metrics should be finite and non-negative. cid <- ClusteringInfoDistance(t1, t2) + expect_equal(cid, 0.01409, tolerance = 1e-5) msd <- MatchingSplitDistance(t1, t2) + expect_equal(msd, 2) irf <- InfoRobinsonFoulds(t1, t2) - expect_true(all(is.finite(c(cid, msd, irf)))) - expect_true(all(c(cid, msd, irf) >= 0)) - + expect_equal(irf, 23.999, tolerance = 1e-4) + # Batch and pairwise paths must agree. expect_equal(unname(as.matrix(ClusteringInfoDistance(trees))[2, 1]), unname(cid), tolerance = 1e-8) From a97708d4ec3be3c74374d03af792a23001e30c63 Mon Sep 17 00:00:00 2001 From: R script <1695515+ms609@users.noreply.github.com> Date: Fri, 17 Apr 2026 11:33:26 +0100 Subject: [PATCH 06/11] Guard IC products and keep int32 arithmetic Add an ASSERT in add_ic_element() to ensure nTip stays within the safe int32 multiplication range, and keep numerator / denominator as int32 products. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/tree_distances.h | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/tree_distances.h b/src/tree_distances.h index 383a91d9..ffa11cbf 100644 --- a/src/tree_distances.h +++ b/src/tree_distances.h @@ -30,9 +30,12 @@ namespace TreeDist { const split_int n_tips, const double lg2_n) noexcept { if (nkK && nk && nK) { + ASSERT(n_tips > 0 && "n_tips must be positive"); + ASSERT(n_tips <= (std::numeric_limits::max)() / n_tips && + "nTip too large for int32 products in add_ic_element"); assert(!(nkK == nk && nkK == nK && nkK << 1 == n_tips)); - const int64_t numerator = static_cast(nkK) * n_tips; - const int64_t denominator = static_cast(nk) * nK; + const int32 numerator = nkK * n_tips; + const int32 denominator = nk * nK; if (numerator != denominator) { ic_sum += nkK * (lg2_lookup(nkK) + lg2_n - lg2_lookup(nk) - lg2_lookup(nK)); From 2dcf9a603ee35ebd2da6ead5fc1a08673d10b6dc Mon Sep 17 00:00:00 2001 From: R script <1695515+ms609@users.noreply.github.com> Date: Fri, 17 Apr 2026 11:49:20 +0100 Subject: [PATCH 07/11] Simplify cpp_max_tips to int-based limits Use int-native constexpr limits with static_assert fit checks in cpp_max_tips(), removing unnecessary int64_t intermediate types while preserving safety. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/tree_distance_functions.cpp | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/tree_distance_functions.cpp b/src/tree_distance_functions.cpp index 7e5d892a..ecc98e4e 100644 --- a/src/tree_distance_functions.cpp +++ b/src/tree_distance_functions.cpp @@ -51,9 +51,16 @@ double cpp_mci_impl_score(const Rcpp::RawMatrix& x, // [[Rcpp::export]] int cpp_max_tips() { - constexpr int64_t split_int_limit = - static_cast((std::numeric_limits::max)()); - constexpr int64_t int32_limit = - static_cast((std::numeric_limits::max)()); - return static_cast(std::min(split_int_limit, int32_limit)); + constexpr auto split_int_limit = + (std::numeric_limits::max)(); + constexpr auto int32_limit = (std::numeric_limits::max)(); + constexpr int int_limit = (std::numeric_limits::max)(); + + static_assert(split_int_limit <= int_limit, + "split_int max must fit in int"); + static_assert(int32_limit <= int_limit, + "int32 max must fit in int"); + + return std::min(static_cast(split_int_limit), + static_cast(int32_limit)); } From d965ab16517555e95ec4cd17a1eb3650dcf36c03 Mon Sep 17 00:00:00 2001 From: R script <1695515+ms609@users.noreply.github.com> Date: Fri, 17 Apr 2026 12:03:18 +0100 Subject: [PATCH 08/11] Make cpp_max_tips robust across int_fast32 widths Avoid narrowing overflow in cpp_max_tips() on platforms where int_fast32_t exceeds int width by clamping comparisons in native source types before any cast to int. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/tree_distance_functions.cpp | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/src/tree_distance_functions.cpp b/src/tree_distance_functions.cpp index ecc98e4e..16aab75f 100644 --- a/src/tree_distance_functions.cpp +++ b/src/tree_distance_functions.cpp @@ -51,16 +51,18 @@ double cpp_mci_impl_score(const Rcpp::RawMatrix& x, // [[Rcpp::export]] int cpp_max_tips() { + constexpr int int_limit = (std::numeric_limits::max)(); constexpr auto split_int_limit = (std::numeric_limits::max)(); - constexpr auto int32_limit = (std::numeric_limits::max)(); - constexpr int int_limit = (std::numeric_limits::max)(); - - static_assert(split_int_limit <= int_limit, - "split_int max must fit in int"); - static_assert(int32_limit <= int_limit, - "int32 max must fit in int"); + constexpr auto int32_limit = + (std::numeric_limits::max)(); - return std::min(static_cast(split_int_limit), - static_cast(int32_limit)); + int max_tips = int_limit; + if (split_int_limit < static_cast(max_tips)) { + max_tips = static_cast(split_int_limit); + } + if (int32_limit < static_cast(max_tips)) { + max_tips = static_cast(int32_limit); + } + return max_tips; } From c66ee0fb192b61e87503303bbf687fc2f6d1efa8 Mon Sep 17 00:00:00 2001 From: R script <1695515+ms609@users.noreply.github.com> Date: Fri, 17 Apr 2026 13:57:21 +0100 Subject: [PATCH 09/11] Restore fast lookup paths for SPI/MSI and cache max tip guard Use a local cached accessor in .CheckMaxTips() to avoid repeated cpp_max_tips() calls. Re-introduce direct table-indexed fast paths for SPI/MSI scoring when n_tips is within lookup-table range, while retaining safe lookup fallbacks for larger trees. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- R/tree_distance_utilities.R | 14 ++++- src/pairwise_distances.cpp | 95 ++++++++++++++++++++--------- src/tree_distances.cpp | 117 +++++++++++++++++++++++++----------- src/tree_distances.h | 103 ++++++++++++++++++++++++------- 4 files changed, 245 insertions(+), 84 deletions(-) diff --git a/R/tree_distance_utilities.R b/R/tree_distance_utilities.R index 3d02258e..d3e54518 100644 --- a/R/tree_distance_utilities.R +++ b/R/tree_distance_utilities.R @@ -1,10 +1,20 @@ +.CompiledTipLimit <- local({ + cached <- NA_integer_ + function() { + if (is.na(cached) || cached <= 0L) { + cached <<- cpp_max_tips() + } + cached + } +}) + .CheckMaxTips <- function(nTip, context = "") { if (is.na(nTip)) { return(invisible(NULL)) } - # Global limit from C++ integer types (not TreeTools stack thresholds). - maxTips <- cpp_max_tips() + # Compiled limit from C++ integer types (not TreeTools stack thresholds). + maxTips <- .CompiledTipLimit() if (nTip > maxTips) { suffix <- if (!nzchar(context)) "." else paste0(" for ", context, ".") stop("Trees with > ", maxTips, " tips are not yet supported", suffix) diff --git a/src/pairwise_distances.cpp b/src/pairwise_distances.cpp index 16d40936..28486dd0 100644 --- a/src/pairwise_distances.cpp +++ b/src/pairwise_distances.cpp @@ -576,32 +576,56 @@ static double msi_score( ) { const split_int most_splits = std::max(a.n_splits, b.n_splits); if (most_splits == 0) return 0.0; + const bool use_lookup_table = TreeDist::can_use_lookup_table(n_tips); constexpr cost max_score = BIG; - const double max_possible = TreeDist::lg2_unrooted_lookup(n_tips) - - TreeDist::lg2_rooted_lookup(split_int((n_tips + 1) / 2)) - - TreeDist::lg2_rooted_lookup(split_int(n_tips / 2)); + const double max_possible = use_lookup_table + ? TreeDist::lg2_unrooted[n_tips] + - TreeDist::lg2_rooted[split_int((n_tips + 1) / 2)] + - TreeDist::lg2_rooted[split_int(n_tips / 2)] + : TreeDist::lg2_unrooted_lookup(n_tips) + - TreeDist::lg2_rooted_lookup(split_int((n_tips + 1) / 2)) + - TreeDist::lg2_rooted_lookup(split_int(n_tips / 2)); const double score_over_possible = static_cast(max_score) / max_possible; const double possible_over_score = max_possible / static_cast(max_score); scratch.score_pool.resize(most_splits); cost_matrix& score = scratch.score_pool; - for (split_int ai = 0; ai < a.n_splits; ++ai) { - for (split_int bi = 0; bi < b.n_splits; ++bi) { - split_int n_a_only = 0, n_a_and_b = 0, n_different = 0; - splitbit different; - for (split_int bin = 0; bin < a.n_bins; ++bin) { - different = a.state[ai][bin] ^ b.state[bi][bin]; - n_different += count_bits(different); - n_a_only += count_bits(a.state[ai][bin] & different); - n_a_and_b += count_bits(a.state[ai][bin] & ~different); + if (use_lookup_table) { + for (split_int ai = 0; ai < a.n_splits; ++ai) { + for (split_int bi = 0; bi < b.n_splits; ++bi) { + split_int n_a_only = 0, n_a_and_b = 0, n_different = 0; + splitbit different; + for (split_int bin = 0; bin < a.n_bins; ++bin) { + different = a.state[ai][bin] ^ b.state[bi][bin]; + n_different += count_bits(different); + n_a_only += count_bits(a.state[ai][bin] & different); + n_a_and_b += count_bits(a.state[ai][bin] & ~different); + } + const split_int n_same = n_tips - n_different; + score(ai, bi) = cost(max_score - score_over_possible * + TreeDist::mmsi_score_table(n_same, n_a_and_b, n_different, n_a_only)); + } + score.padRowAfterCol(ai, b.n_splits, max_score); + } + } else { + for (split_int ai = 0; ai < a.n_splits; ++ai) { + for (split_int bi = 0; bi < b.n_splits; ++bi) { + split_int n_a_only = 0, n_a_and_b = 0, n_different = 0; + splitbit different; + for (split_int bin = 0; bin < a.n_bins; ++bin) { + different = a.state[ai][bin] ^ b.state[bi][bin]; + n_different += count_bits(different); + n_a_only += count_bits(a.state[ai][bin] & different); + n_a_and_b += count_bits(a.state[ai][bin] & ~different); + } + const split_int n_same = n_tips - n_different; + score(ai, bi) = cost(max_score - score_over_possible * + TreeDist::mmsi_score(n_same, n_a_and_b, n_different, n_a_only)); } - const split_int n_same = n_tips - n_different; - score(ai, bi) = cost(max_score - score_over_possible * - TreeDist::mmsi_score(n_same, n_a_and_b, n_different, n_a_only)); + score.padRowAfterCol(ai, b.n_splits, max_score); } - score.padRowAfterCol(ai, b.n_splits, max_score); } score.padAfterRow(a.n_splits, max_score); @@ -675,27 +699,44 @@ static double shared_phylo_score( ) { const split_int most_splits = std::max(a.n_splits, b.n_splits); if (most_splits == 0) return 0.0; + const bool use_lookup_table = TreeDist::can_use_lookup_table(n_tips); const split_int overlap_a = split_int(n_tips + 1) / 2; constexpr cost max_score = BIG; - const double best_overlap = TreeDist::one_overlap(overlap_a, n_tips / 2, n_tips); - const double max_possible = TreeDist::lg2_unrooted_lookup(n_tips) - - best_overlap; + const double best_overlap = use_lookup_table + ? TreeDist::one_overlap_table(overlap_a, n_tips / 2, n_tips) + : TreeDist::one_overlap(overlap_a, n_tips / 2, n_tips); + const double max_possible = (use_lookup_table + ? TreeDist::lg2_unrooted[n_tips] + : TreeDist::lg2_unrooted_lookup(n_tips)) - best_overlap; const double score_over_possible = static_cast(max_score) / max_possible; const double possible_over_score = max_possible / static_cast(max_score); scratch.score_pool.resize(most_splits); cost_matrix& score = scratch.score_pool; - for (split_int ai = 0; ai < a.n_splits; ++ai) { - for (split_int bi = 0; bi < b.n_splits; ++bi) { - const double spi = TreeDist::spi_overlap( - a.state[ai], b.state[bi], n_tips, - a.in_split[ai], b.in_split[bi], a.n_bins); - score(ai, bi) = (spi == 0.0) ? max_score - : cost((spi - best_overlap) * score_over_possible); + if (use_lookup_table) { + for (split_int ai = 0; ai < a.n_splits; ++ai) { + for (split_int bi = 0; bi < b.n_splits; ++bi) { + const double spi = TreeDist::spi_overlap_table( + a.state[ai], b.state[bi], n_tips, + a.in_split[ai], b.in_split[bi], a.n_bins); + score(ai, bi) = (spi == 0.0) ? max_score + : cost((spi - best_overlap) * score_over_possible); + } + score.padRowAfterCol(ai, b.n_splits, max_score); + } + } else { + for (split_int ai = 0; ai < a.n_splits; ++ai) { + for (split_int bi = 0; bi < b.n_splits; ++bi) { + const double spi = TreeDist::spi_overlap( + a.state[ai], b.state[bi], n_tips, + a.in_split[ai], b.in_split[bi], a.n_bins); + score(ai, bi) = (spi == 0.0) ? max_score + : cost((spi - best_overlap) * score_over_possible); + } + score.padRowAfterCol(ai, b.n_splits, max_score); } - score.padRowAfterCol(ai, b.n_splits, max_score); } score.padAfterRow(a.n_splits, max_score); diff --git a/src/tree_distances.cpp b/src/tree_distances.cpp index 58d0c3a5..e8a4f471 100644 --- a/src/tree_distances.cpp +++ b/src/tree_distances.cpp @@ -341,10 +341,15 @@ inline List jaccard_similarity(const RawMatrix &x, const RawMatrix &y, List msi_distance(const RawMatrix &x, const RawMatrix &y, const int32 n_tips) { const SplitList a(x), b(y); const split_int most_splits = std::max(a.n_splits, b.n_splits); + const bool use_lookup_table = TreeDist::can_use_lookup_table(n_tips); constexpr cost max_score = BIG; - const double max_possible = TreeDist::lg2_unrooted_lookup(n_tips) - - TreeDist::lg2_rooted_lookup(split_int((n_tips + 1) / 2)) - - TreeDist::lg2_rooted_lookup(split_int(n_tips / 2)); + const double max_possible = use_lookup_table + ? TreeDist::lg2_unrooted[n_tips] - + TreeDist::lg2_rooted[split_int((n_tips + 1) / 2)] - + TreeDist::lg2_rooted[split_int(n_tips / 2)] + : TreeDist::lg2_unrooted_lookup(n_tips) - + TreeDist::lg2_rooted_lookup(split_int((n_tips + 1) / 2)) - + TreeDist::lg2_rooted_lookup(split_int(n_tips / 2)); const double score_over_possible = static_cast(max_score) / max_possible; const double possible_over_score = max_possible / max_score; @@ -352,27 +357,52 @@ List msi_distance(const RawMatrix &x, const RawMatrix &y, const int32 n_tips) { std::vector different(a.n_bins); - for (split_int ai = 0; ai < a.n_splits; ++ai) { - if ((ai & 1023) == 0) Rcpp::checkUserInterrupt(); - for (split_int bi = 0; bi < b.n_splits; ++bi) { - split_int - n_different = 0, - n_a_only = 0, - n_a_and_b = 0 - ; - for (split_int bin = 0; bin < a.n_bins; ++bin) { - different[bin] = a.state[ai][bin] ^ b.state[bi][bin]; - n_different += count_bits(different[bin]); - n_a_only += count_bits(a.state[ai][bin] & different[bin]); - n_a_and_b += count_bits(a.state[ai][bin] & ~different[bin]); + if (use_lookup_table) { + for (split_int ai = 0; ai < a.n_splits; ++ai) { + if ((ai & 1023) == 0) Rcpp::checkUserInterrupt(); + for (split_int bi = 0; bi < b.n_splits; ++bi) { + split_int + n_different = 0, + n_a_only = 0, + n_a_and_b = 0 + ; + for (split_int bin = 0; bin < a.n_bins; ++bin) { + different[bin] = a.state[ai][bin] ^ b.state[bi][bin]; + n_different += count_bits(different[bin]); + n_a_only += count_bits(a.state[ai][bin] & different[bin]); + n_a_and_b += count_bits(a.state[ai][bin] & ~different[bin]); + } + const split_int n_same = n_tips - n_different; + + score(ai, bi) = cost(max_score - + (score_over_possible * + TreeDist::mmsi_score_table(n_same, n_a_and_b, n_different, n_a_only))); } - const split_int n_same = n_tips - n_different; - - score(ai, bi) = cost(max_score - - (score_over_possible * - TreeDist::mmsi_score(n_same, n_a_and_b, n_different, n_a_only))); + score.padRowAfterCol(ai, b.n_splits, max_score); + } + } else { + for (split_int ai = 0; ai < a.n_splits; ++ai) { + if ((ai & 1023) == 0) Rcpp::checkUserInterrupt(); + for (split_int bi = 0; bi < b.n_splits; ++bi) { + split_int + n_different = 0, + n_a_only = 0, + n_a_and_b = 0 + ; + for (split_int bin = 0; bin < a.n_bins; ++bin) { + different[bin] = a.state[ai][bin] ^ b.state[bi][bin]; + n_different += count_bits(different[bin]); + n_a_only += count_bits(a.state[ai][bin] & different[bin]); + n_a_and_b += count_bits(a.state[ai][bin] & ~different[bin]); + } + const split_int n_same = n_tips - n_different; + + score(ai, bi) = cost(max_score - + (score_over_possible * + TreeDist::mmsi_score(n_same, n_a_and_b, n_different, n_a_only))); + } + score.padRowAfterCol(ai, b.n_splits, max_score); } - score.padRowAfterCol(ai, b.n_splits, max_score); } score.padAfterRow(a.n_splits, max_score); @@ -575,10 +605,14 @@ inline List shared_phylo (const RawMatrix &x, const RawMatrix &y, const SplitList a(x), b(y); const split_int most_splits = std::max(a.n_splits, b.n_splits); const split_int overlap_a = split_int(n_tips + 1) / 2; // avoids promotion to int + const bool use_lookup_table = TreeDist::can_use_lookup_table(n_tips); constexpr cost max_score = BIG; - const double lg2_unrooted_n = TreeDist::lg2_unrooted_lookup(n_tips); - const double best_overlap = TreeDist::one_overlap(overlap_a, n_tips / 2, n_tips); + const double lg2_unrooted_n = use_lookup_table ? TreeDist::lg2_unrooted[n_tips] : + TreeDist::lg2_unrooted_lookup(n_tips); + const double best_overlap = use_lookup_table + ? TreeDist::one_overlap_table(overlap_a, n_tips / 2, n_tips) + : TreeDist::one_overlap(overlap_a, n_tips / 2, n_tips); const double max_possible = lg2_unrooted_n - best_overlap; const double score_over_possible = max_score / max_possible; const double possible_over_score = max_possible / max_score; @@ -587,17 +621,32 @@ inline List shared_phylo (const RawMatrix &x, const RawMatrix &y, // In/out direction [i.e. 1/0 bit] is arbitrary. cost_matrix score(most_splits); - for (split_int ai = a.n_splits; ai--; ) { - if ((ai & 1023) == 0) Rcpp::checkUserInterrupt(); - for (split_int bi = b.n_splits; bi--; ) { - const double spi_over = TreeDist::spi_overlap( - a.state[ai], b.state[bi], n_tips, a.in_split[ai], b.in_split[bi], - a.n_bins); - - score(ai, bi) = spi_over == 0 ? max_score : - cost((spi_over - best_overlap) * score_over_possible); + if (use_lookup_table) { + for (split_int ai = a.n_splits; ai--; ) { + if ((ai & 1023) == 0) Rcpp::checkUserInterrupt(); + for (split_int bi = b.n_splits; bi--; ) { + const double spi_over = TreeDist::spi_overlap_table( + a.state[ai], b.state[bi], n_tips, a.in_split[ai], b.in_split[bi], + a.n_bins); + + score(ai, bi) = spi_over == 0 ? max_score : + cost((spi_over - best_overlap) * score_over_possible); + } + score.padRowAfterCol(ai, b.n_splits, max_score); + } + } else { + for (split_int ai = a.n_splits; ai--; ) { + if ((ai & 1023) == 0) Rcpp::checkUserInterrupt(); + for (split_int bi = b.n_splits; bi--; ) { + const double spi_over = TreeDist::spi_overlap( + a.state[ai], b.state[bi], n_tips, a.in_split[ai], b.in_split[bi], + a.n_bins); + + score(ai, bi) = spi_over == 0 ? max_score : + cost((spi_over - best_overlap) * score_over_possible); + } + score.padRowAfterCol(ai, b.n_splits, max_score); } - score.padRowAfterCol(ai, b.n_splits, max_score); } score.padAfterRow(a.n_splits, max_score); diff --git a/src/tree_distances.h b/src/tree_distances.h index ffa11cbf..bead1aaf 100644 --- a/src/tree_distances.h +++ b/src/tree_distances.h @@ -44,13 +44,38 @@ namespace TreeDist { } + [[nodiscard]] inline bool can_use_lookup_table(const split_int n_tips) noexcept { + return n_tips <= static_cast(SL_MAX_TIPS); + } + // Returns lg2_unrooted[x] - lg2_trees_matching_split(y, x - y) + [[nodiscard]] inline double mmsi_pair_score_table(const split_int x, + const split_int y) noexcept { + ASSERT(can_use_lookup_table(x)); + return lg2_unrooted[x] - (lg2_rooted[y] + lg2_rooted[x - y]); + } + [[nodiscard]] inline double mmsi_pair_score(const split_int x, const split_int y) noexcept { return lg2_unrooted_lookup(x) - (lg2_rooted_lookup(y) + lg2_rooted_lookup(x - y)); } + [[nodiscard]] inline double mmsi_score_table(const split_int n_same, + const split_int n_a_and_b, + const split_int n_different, + const split_int n_a_only) noexcept { + if (n_same == 0 || n_same == n_a_and_b) + return mmsi_pair_score_table(n_different, n_a_only); + if (n_different == 0 || n_different == n_a_only) + return mmsi_pair_score_table(n_same, n_a_and_b); + + const double score1 = mmsi_pair_score_table(n_same, n_a_and_b), + score2 = mmsi_pair_score_table(n_different, n_a_only); + + return (score1 > score2) ? score1 : score2; + } + [[nodiscard]] inline double mmsi_score(const split_int n_same, const split_int n_a_and_b, const split_int n_different, @@ -59,27 +84,49 @@ namespace TreeDist { return mmsi_pair_score(n_different, n_a_only); if (n_different == 0 || n_different == n_a_only) return mmsi_pair_score(n_same, n_a_and_b); - - const double - score1 = mmsi_pair_score(n_same, n_a_and_b), - score2 = mmsi_pair_score(n_different, n_a_only); - + + const double score1 = mmsi_pair_score(n_same, n_a_and_b), + score2 = mmsi_pair_score(n_different, n_a_only); + return (score1 > score2) ? score1 : score2; } + [[nodiscard]] inline double one_overlap_table(const split_int a, const split_int b, + const split_int n) noexcept { + ASSERT(can_use_lookup_table(n)); + if (a == b) { + return lg2_rooted[a] + lg2_rooted[n - a]; + } + const split_int lo = (a < b) ? a : b; + const split_int hi = (a < b) ? b : a; + return lg2_rooted[hi] + lg2_rooted[n - lo] - lg2_rooted[hi - lo + 1]; + } -[[nodiscard]] inline double one_overlap(const split_int a, const split_int b, - const split_int n) noexcept { + [[nodiscard]] inline double one_overlap(const split_int a, const split_int b, + const split_int n) noexcept { if (a == b) { return lg2_rooted_lookup(a) + lg2_rooted_lookup(n - a); } - // Unify ab via lo/hi: removes an unpredictable branch. const split_int lo = (a < b) ? a : b; const split_int hi = (a < b) ? b : a; return lg2_rooted_lookup(hi) + lg2_rooted_lookup(n - lo) - lg2_rooted_lookup(hi - lo + 1); } - + + [[nodiscard]] inline double one_overlap_notb_table(const split_int a, + const split_int n_minus_b, + const split_int n) noexcept { + ASSERT(can_use_lookup_table(n)); + const split_int b = n - n_minus_b; + if (a == b) { + return lg2_rooted[b] + lg2_rooted[n_minus_b]; + } else if (a < b) { + return lg2_rooted[b] + lg2_rooted[n - a] - lg2_rooted[b - a + 1]; + } else { + return lg2_rooted[a] + lg2_rooted[n_minus_b] - lg2_rooted[a - b + 1]; + } + } + [[nodiscard]] inline double one_overlap_notb(const split_int a, const split_int n_minus_b, const split_int n) noexcept { @@ -95,35 +142,49 @@ namespace TreeDist { } } + [[nodiscard]] inline double spi_overlap_table(const splitbit* a_state, + const splitbit* b_state, + const split_int n_tips, + const split_int in_a, + const split_int in_b, + const split_int n_bins) noexcept { + ASSERT(can_use_lookup_table(n_tips)); + split_int n_ab = 0; + for (split_int bin = 0; bin < n_bins; ++bin) { + n_ab += TreeTools::count_bits(a_state[bin] & b_state[bin]); + } + + if (n_ab == 0) { + return one_overlap_notb_table(in_a, in_b, n_tips); + } + if (n_ab == in_b || n_ab == in_a) { + return one_overlap_table(in_a, in_b, n_tips); + } + if (in_a + in_b - n_ab == n_tips) { + return one_overlap_notb_table(in_a, in_b, n_tips); + } + + return 0.0; + } // Popcount-based: single pass over bins replaces 4 sequential boolean scans. // Counts n_ab = |A ∩ B| via hardware POPCNT, then derives all 4 Venn-diagram // region populations from arithmetic on n_ab, in_a, in_b, n_tips. [[nodiscard]] inline double spi_overlap(const splitbit* a_state, const splitbit* b_state, - const split_int n_tips, const split_int in_a, - const split_int in_b, const split_int n_bins) noexcept { - + const split_int n_tips, const split_int in_a, + const split_int in_b, const split_int n_bins) noexcept { split_int n_ab = 0; for (split_int bin = 0; bin < n_bins; ++bin) { n_ab += TreeTools::count_bits(a_state[bin] & b_state[bin]); } - // n_a_only = in_a - n_ab (tips in A but not B) - // n_b_only = in_b - n_ab (tips in B but not A) - // n_neither = n_tips - in_a - in_b + n_ab (tips in neither) - // - // Return 0 when all 4 regions are populated (the common case for - // unrelated splits). Otherwise return the appropriate one_overlap score. - if (n_ab == 0) { return one_overlap_notb(in_a, in_b, n_tips); } if (n_ab == in_b || n_ab == in_a) { - // B ⊆ A (n_b_only == 0) or A ⊆ B (n_a_only == 0) return one_overlap(in_a, in_b, n_tips); } if (in_a + in_b - n_ab == n_tips) { - // A ∪ B covers all tips (n_neither == 0) return one_overlap_notb(in_a, in_b, n_tips); } From 8c7a86cef1677de4fceb760c2bed3974717b3d56 Mon Sep 17 00:00:00 2001 From: R script <1695515+ms609@users.noreply.github.com> Date: Fri, 17 Apr 2026 15:09:28 +0100 Subject: [PATCH 10/11] Code coverage (claude sonnet) --- inst/include/TreeDist/mutual_clustering.h | 8 ++-- man/dot-TreeDistance.Rd | 2 +- src/tree_distances.cpp | 4 +- tests/testthat/test-large-trees.R | 37 +++++++++++++++++++ tests/testthat/test-tree_distance_nni.R | 15 ++++++++ tests/testthat/test-tree_distance_utilities.R | 1 + 6 files changed, 60 insertions(+), 7 deletions(-) diff --git a/inst/include/TreeDist/mutual_clustering.h b/inst/include/TreeDist/mutual_clustering.h index 99e9389d..470eb7df 100644 --- a/inst/include/TreeDist/mutual_clustering.h +++ b/inst/include/TreeDist/mutual_clustering.h @@ -48,8 +48,8 @@ namespace TreeDist { if (n_tips <= static_cast(SL_MAX_TIPS + 1)) { return lg2_unrooted[n_tips]; } - if (n_tips < 3) { - return 0.0; + if (n_tips < 3) { // LCOV_EXCL_START + return 0.0; // LCOV_EXCL_STOP } const double n = static_cast(n_tips); // log2((2n - 5)!!) = log2((2n - 4)!) - (n - 2) - log2((n - 2)!) @@ -61,8 +61,8 @@ namespace TreeDist { if (n_tips <= static_cast(SL_MAX_TIPS + 1)) { return lg2_rooted[n_tips]; } - if (n_tips < 2) { - return 0.0; + if (n_tips < 2) { // LCOV_EXCL_START + return 0.0; // LCOV_EXCL_STOP } const double n = static_cast(n_tips); // log2((2n - 3)!!) = log2((2n - 2)!) - (n - 1) - log2((n - 1)!) diff --git a/man/dot-TreeDistance.Rd b/man/dot-TreeDistance.Rd index a704ea42..5df43547 100644 --- a/man/dot-TreeDistance.Rd +++ b/man/dot-TreeDistance.Rd @@ -14,7 +14,7 @@ avoid crashes in C++.} Calculate distance between trees, or lists of trees } \seealso{ -\code{\link{CalculateTreeDistance}} +\code{\link[=CalculateTreeDistance]{CalculateTreeDistance()}} } \author{ \href{https://orcid.org/0000-0001-5660-1727}{Martin R. Smith} diff --git a/src/tree_distances.cpp b/src/tree_distances.cpp index e8a4f471..5d8b9067 100644 --- a/src/tree_distances.cpp +++ b/src/tree_distances.cpp @@ -38,13 +38,13 @@ namespace TreeDist { Rcpp::stop("Requested nTip = %d is invalid.", n); } - if (n > max_supported_tips) { + if (n > max_supported_tips) { // LCOV_EXCL_START Rcpp::stop( "Requested nTip = %d exceeds this TreeDist build limit (%d): " "this many tips are not yet supported.", n, static_cast(max_supported_tips) ); - } + } // LCOV_EXCL_STOP } diff --git a/tests/testthat/test-large-trees.R b/tests/testthat/test-large-trees.R index fbcf36cc..6e7e536b 100644 --- a/tests/testthat/test-large-trees.R +++ b/tests/testthat/test-large-trees.R @@ -26,3 +26,40 @@ test_that("Known-answer large-tree near-neighbours (>2048 tips)", { expect_equal(unname(as.matrix(InfoRobinsonFoulds(trees))[2, 1]), unname(irf), tolerance = 1e-8) }) + +test_that("Large-tree (>SL_MAX_TIPS) non-table paths: PID, MSID, Jaccard, MCI", { + # 2052 tips: SL_MAX_TIPS (2048) is exceeded, forcing fallback (non-table) + # scoring paths in shared_phylo, msi_distance and the batch equivalents. + # Split size can reach 2051 > SL_MAX_TIPS + 1, triggering lg2_rooted_lookup + # fallback. Near-neighbour trees share most splits so the LAP is tiny. + t1 <- as.phylo(0, 2052) + t2 <- as.phylo(1, 2052) + trees <- structure(list(t1, t2), class = "multiPhylo") + + pid <- PhylogeneticInfoDistance(t1, t2) + expect_type(pid, "double") + expect_true(is.finite(pid)) + expect_gte(pid, 0) + + msid <- MatchingSplitInfoDistance(t1, t2) + expect_type(msid, "double") + expect_true(is.finite(msid)) + expect_gte(msid, 0) + + # Batch and pairwise paths must agree for PID and MSID + expect_equal(unname(as.matrix(PhylogeneticInfoDistance(trees))[2, 1]), + unname(pid), tolerance = 1e-8) + expect_equal(unname(as.matrix(MatchingSplitInfoDistance(trees))[2, 1]), + unname(msid), tolerance = 1e-8) + + # NyeSimilarity exercises jaccard_similarity serial interrupt path + jac <- NyeSimilarity(t1, t2) + expect_type(jac, "double") + expect_true(is.finite(jac)) + expect_gte(jac, 0) + + # reportMatching = TRUE forces the serial mutual_clustering path with interrupt + cid_match <- ClusteringInfoDistance(t1, t2, reportMatching = TRUE) + expect_true(is.integer(attr(cid_match, "matching"))) + expect_gt(length(attr(cid_match, "matching")), 0) +}) diff --git a/tests/testthat/test-tree_distance_nni.R b/tests/testthat/test-tree_distance_nni.R index 46fd6361..7ad98fa4 100644 --- a/tests/testthat/test-tree_distance_nni.R +++ b/tests/testthat/test-tree_distance_nni.R @@ -199,3 +199,18 @@ test_that("NNIDiameter() is sane", { expect_equal(NNIDiameter(as.phylo(0:1, 6)), NNIDiameter(c(6, 6))) }) + +test_that("cpp_nni_distance C++ guards fire for out-of-range nTip", { + # .NNIDistSingle calls .CheckMaxTips first, so the C++ guards inside + # cpp_nni_distance are only reachable via a direct call. + tree1 <- PectinateTree(5) + tree2 <- BalancedTree(5) + edge1 <- Postorder(tree1$edge) + edge2 <- Postorder(tree2$edge) + # Too many tips: fires nTip[0] > NNI_MAX_TIPS in C++ + expect_error(cpp_nni_distance(edge1, edge2, 32769L), + "not yet supported for NNI") + # Negative nTip: fires nTip[0] < 0 in C++ + expect_error(cpp_nni_distance(edge1, edge2, -1L), + "invalid") +}) diff --git a/tests/testthat/test-tree_distance_utilities.R b/tests/testthat/test-tree_distance_utilities.R index b16c6296..12c82bdb 100644 --- a/tests/testthat/test-tree_distance_utilities.R +++ b/tests/testthat/test-tree_distance_utilities.R @@ -46,6 +46,7 @@ test_that("Tip-count guard is applied consistently", { splits8 <- unclass(as.Splits(BalancedTree(8))) expect_no_error(cpp_robinson_foulds_distance(splits8, splits8, 8L)) + expect_error(cpp_robinson_foulds_distance(splits8, splits8, -1L), "invalid") expect_no_error(cpp_robinson_foulds_info(splits8, splits8, 8L)) trees <- list(BalancedTree(8), PectinateTree(8)) From 49a37c8265fc4ae67eda57cd8b81541ecf2b0464 Mon Sep 17 00:00:00 2001 From: R script <1695515+ms609@users.noreply.github.com> Date: Fri, 17 Apr 2026 16:07:18 +0100 Subject: [PATCH 11/11] Skip slow tests; tighten failures --- tests/testthat/test-large-trees.R | 8 +++++--- tests/testthat/test-tree_distance_nni.R | 3 +-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/testthat/test-large-trees.R b/tests/testthat/test-large-trees.R index 6e7e536b..5706455a 100644 --- a/tests/testthat/test-large-trees.R +++ b/tests/testthat/test-large-trees.R @@ -39,12 +39,12 @@ test_that("Large-tree (>SL_MAX_TIPS) non-table paths: PID, MSID, Jaccard, MCI", pid <- PhylogeneticInfoDistance(t1, t2) expect_type(pid, "double") expect_true(is.finite(pid)) - expect_gte(pid, 0) + expect_gt(pid, 0) msid <- MatchingSplitInfoDistance(t1, t2) expect_type(msid, "double") expect_true(is.finite(msid)) - expect_gte(msid, 0) + expect_gt(msid, 0) # Batch and pairwise paths must agree for PID and MSID expect_equal(unname(as.matrix(PhylogeneticInfoDistance(trees))[2, 1]), @@ -56,8 +56,10 @@ test_that("Large-tree (>SL_MAX_TIPS) non-table paths: PID, MSID, Jaccard, MCI", jac <- NyeSimilarity(t1, t2) expect_type(jac, "double") expect_true(is.finite(jac)) - expect_gte(jac, 0) + expect_gt(jac, 0) + skip_on_cran() + skip_if_not(getOption("slowMode", FALSE)) # reportMatching = TRUE forces the serial mutual_clustering path with interrupt cid_match <- ClusteringInfoDistance(t1, t2, reportMatching = TRUE) expect_true(is.integer(attr(cid_match, "matching"))) diff --git a/tests/testthat/test-tree_distance_nni.R b/tests/testthat/test-tree_distance_nni.R index 7ad98fa4..c479f4bb 100644 --- a/tests/testthat/test-tree_distance_nni.R +++ b/tests/testthat/test-tree_distance_nni.R @@ -21,8 +21,7 @@ test_that("NNIDist() at max tips", { more <- maxTips + 1L expect_error(.NNIDistSingle(PectinateTree(more), BalancedTree(more), more), "not yet supported for NNI") - goingQuickly <- TRUE - skip_if(goingQuickly) + skip_if_not(getOption("slowMode", FALSE)) heapTips <- 16384 + 1 skip_if(maxTips < heapTips)