diff --git a/R/RcppExports.R b/R/RcppExports.R index 7cd2fcd0..f7aa3d6d 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -275,6 +275,10 @@ cpp_mci_impl_score <- function(x, y, n_tips) { .Call(`_TreeDist_cpp_mci_impl_score`, x, y, n_tips) } +cpp_max_tips <- function() { + .Call(`_TreeDist_cpp_max_tips`) +} + cpp_robinson_foulds_distance <- function(x, y, nTip) { .Call(`_TreeDist_cpp_robinson_foulds_distance`, x, y, nTip) } diff --git a/R/transfer_consensus.R b/R/transfer_consensus.R index 3410cf09..6cef4490 100644 --- a/R/transfer_consensus.R +++ b/R/transfer_consensus.R @@ -66,7 +66,7 @@ TransferConsensus <- function(trees, if (nTip < 4L) { return(StarTree(tipLabels)) } - if (nTip > 32767L) stop("This many tips are not (yet) supported.") + .CheckMaxTips(nTip) # Convert each tree to a raw split matrix (TreeTools C++ internally). # as.Splits() will error if a tree's tips don't match tipLabels. @@ -115,7 +115,7 @@ tc_profile <- function(trees, scale = TRUE, greedy = "best", tipLabels <- TipLabels(trees[[1]]) nTip <- length(tipLabels) if (nTip < 4L) stop("Need at least 4 tips for profiling.") - if (nTip > 32767L) stop("This many tips are not (yet) supported.") + .CheckMaxTips(nTip) splitsList <- lapply(trees, function(tr) unclass(as.Splits(tr, tipLabels))) diff --git a/R/tree_distance.R b/R/tree_distance.R index 11bf9e91..5ea3a9cb 100644 --- a/R/tree_distance.R +++ b/R/tree_distance.R @@ -149,7 +149,7 @@ GeneralizedRF <- function(splits1, splits2, nTip, PairScorer, } nTip <- length(tipLabels) if (nTip < 4) return(NULL) # nocov - if (nTip > 32767L) stop("This many tips are not (yet) supported.") + .CheckMaxTips(nTip) splits_list <- as.Splits(tree1, tipLabels = tipLabels) n_threads <- as.integer(getOption("mc.cores", 1L)) @@ -203,7 +203,7 @@ GeneralizedRF <- function(splits1, splits2, nTip, PairScorer, nTip <- length(tipLabels1) if (nTip < 4) return(NULL) - if (nTip > 32767L) stop("This many tips are not (yet) supported.") + .CheckMaxTips(nTip) splits1 <- as.Splits(tree1, tipLabels = tipLabels1) splits2 <- as.Splits(tree2, tipLabels = tipLabels1) # Use tipLabels1 to ensure order consistency diff --git a/R/tree_distance_mast.R b/R/tree_distance_mast.R index 5c069841..eae741ff 100644 --- a/R/tree_distance_mast.R +++ b/R/tree_distance_mast.R @@ -96,8 +96,9 @@ MASTSize <- function(tree1, tree2 = tree1, rooted = TRUE) { if (nrow(edge1) != nrow(edge2)) { stop("Both trees must contain the same number of edges.") } - if (nTip > 4096L) { - stop("Tree too large; please contact maintainer for advice.") + maxTips <- min(4096L, cpp_max_tips()) + if (nTip > maxTips) { + stop("Trees with > ", maxTips, " tips are not yet supported for MAST.") } cpp_mast(edge1 - 1L, Postorder(edge2) - 1L, nTip) } diff --git a/R/tree_distance_nni.R b/R/tree_distance_nni.R index 45854764..2d617a13 100644 --- a/R/tree_distance_nni.R +++ b/R/tree_distance_nni.R @@ -73,9 +73,7 @@ NNIDist <- function(tree1, tree2 = tree1) { #' @importFrom TreeTools Postorder RenumberTips #' @importFrom ape Nnode.phylo .NNIDistSingle <- function(tree1, tree2, nTip, ...) { - if (nTip > 32768L) { - stop("Cannot calculate NNI distance for trees with so many tips.") - } + .CheckMaxTips(nTip, "NNI") if (nrow(tree1[["edge"]]) != nrow(tree2[["edge"]])) { stop("Both trees must have the same number of edges. ", "Is one rooted and the other unrooted?") diff --git a/R/tree_distance_transfer.R b/R/tree_distance_transfer.R index cd0942fd..55dce2e2 100644 --- a/R/tree_distance_transfer.R +++ b/R/tree_distance_transfer.R @@ -168,7 +168,7 @@ TransferDistSplits <- function(splits1, splits2, if (is.null(tipLabels)) return(NULL) nTip <- length(tipLabels) if (nTip < 4L) return(NULL) - if (nTip > 32767L) stop("This many tips are not (yet) supported.") + .CheckMaxTips(nTip) # Check all trees share same tip set allLabels <- TipLabels(tree1) @@ -211,7 +211,7 @@ TransferDistSplits <- function(splits1, splits2, if (is.null(tipLabels)) return(NULL) nTip <- length(tipLabels) if (nTip < 4L) return(NULL) - if (nTip > 32767L) stop("This many tips are not (yet) supported.") + .CheckMaxTips(nTip) # Check all trees share same tip set allLabels1 <- TipLabels(trees1) diff --git a/R/tree_distance_utilities.R b/R/tree_distance_utilities.R index 01b0ddca..d3e54518 100644 --- a/R/tree_distance_utilities.R +++ b/R/tree_distance_utilities.R @@ -1,3 +1,32 @@ +.CompiledTipLimit <- local({ + cached <- NA_integer_ + function() { + if (is.na(cached) || cached <= 0L) { + cached <<- cpp_max_tips() + } + cached + } +}) + +.CheckMaxTips <- function(nTip, context = "") { + if (is.na(nTip)) { + return(invisible(NULL)) + } + + # Compiled limit from C++ integer types (not TreeTools stack thresholds). + maxTips <- .CompiledTipLimit() + if (nTip > maxTips) { + suffix <- if (!nzchar(context)) "." else paste0(" for ", context, ".") + stop("Trees with > ", maxTips, " tips are not yet supported", suffix) + } + + # NNI uses fixed-size lookup tables in C++. + if (identical(context, "NNI") && nTip > 32768L) { + stop("Trees with > 32768 tips are not yet supported for NNI.") + } + invisible(NULL) +} + #' Wrapper for tree distance calculations #' #' Calls tree distance functions from trees or lists of trees @@ -11,15 +40,6 @@ #' @importFrom TreeTools as.Splits TipLabels #' @importFrom utils combn #' @export -# Keep in sync with C++ guard: min(SL_MAX_TIPS, int16_t::max()). -.MaxSupportedTips <- 32767L - -.AssertNtipSupported <- function(nTip) { - if (!is.na(nTip) && nTip > .MaxSupportedTips) { - stop("This many tips are not (yet) supported.") - } -} - CalculateTreeDistance <- function(Func, tree1, tree2 = NULL, reportMatching = FALSE, ...) { supportedClasses <- c("phylo", "Splits") @@ -141,7 +161,7 @@ CalculateTreeDistance <- function(Func, tree1, tree2 = NULL, # Fast paths: use OpenMP batch functions when all trees share the same tip # set and no R-level cluster has been configured. Each branch mirrors the # generic path exactly but avoids per-pair R overhead. - .AssertNtipSupported(nTip) + .CheckMaxTips(nTip) if (!is.na(nTip) && is.null(cluster)) { .n_threads <- as.integer(getOption("mc.cores", 1L)) .batch_result <- if (identical(Func, MutualClusteringInfoSplits)) { @@ -242,7 +262,7 @@ CalculateTreeDistance <- function(Func, tree1, tree2 = NULL, #' @importFrom stats setNames .SplitDistanceManyMany <- function(Func, splits1, splits2, tipLabels, nTip = length(tipLabels), ...) { - .AssertNtipSupported(nTip) + .CheckMaxTips(nTip) if (is.na(nTip)) { tipLabels <- union(unlist(tipLabels, use.names = FALSE), unlist(TipLabels(splits2), use.names = FALSE)) @@ -331,7 +351,7 @@ CalculateTreeDistance <- function(Func, tree1, tree2 = NULL, #' @param checks Logical specifying whether to perform basic sanity checks to #' avoid crashes in C++. #' @keywords internal -#' @seealso [`CalculateTreeDistance`] +#' @seealso [`CalculateTreeDistance()`] #' @export .TreeDistance <- function(Func, tree1, tree2, checks = TRUE, ...) { single1 <- inherits(tree1, "phylo") @@ -413,7 +433,7 @@ CalculateTreeDistance <- function(Func, tree1, tree2 = NULL, if (ncol(x) != ncol(y)) { stop("Input splits must address same number of tips.") } - .AssertNtipSupported(nTip) + .CheckMaxTips(nTip) } .CheckLabelsSame <- function(labelList) { diff --git a/R/tree_information.R b/R/tree_information.R index 3462f53e..81672e50 100644 --- a/R/tree_information.R +++ b/R/tree_information.R @@ -390,9 +390,10 @@ consensus_info <- function(trees, phylo, p) { stop("p must be >= 0.5 in consensus_info()") } nTip <- NTip(trees[[1]]) - # CT_MAX_LEAVES = 16383 in information.h (lookup table size limit) - if (nTip > 16383L) { - stop("This many leaves are not yet supported") + # CT_MAX_LEAVES = 16383 in information.h (lookup-table size limit). + maxTips <- min(16383L, cpp_max_tips()) + if (nTip > maxTips) { + stop("Trees with > ", maxTips, " tips are not yet supported for consensus info.") } .Call(`_TreeDist_consensus_info`, trees, phylo, p) } diff --git a/inst/include/TreeDist/mutual_clustering.h b/inst/include/TreeDist/mutual_clustering.h index 437633f0..470eb7df 100644 --- a/inst/include/TreeDist/mutual_clustering.h +++ b/inst/include/TreeDist/mutual_clustering.h @@ -15,6 +15,8 @@ namespace TreeDist { + constexpr double LOG2_E = 1.4426950408889634; + // ---- Lookup tables (populated by init_lg2_tables) ---- // // lg2[i] = log2(i) for 0 <= i <= SL_MAX_TIPS @@ -22,7 +24,9 @@ namespace TreeDist { // lg2_unrooted[i] = log2((2i-5)!!) for i >= 3 // lg2_rooted = &lg2_unrooted[0] + 1 (so lg2_rooted[i] = lg2_unrooted[i+1]) // - // These are defined in mutual_clustering_impl.h. + // These are fast-path caches sized to TreeTools' stack threshold. + // For larger trees we fall back to on-the-fly computation. + // Definitions are in mutual_clustering_impl.h. extern double lg2[SL_MAX_TIPS + 1]; extern double lg2_double_factorial[SL_MAX_TIPS + SL_MAX_TIPS - 2]; @@ -33,15 +37,48 @@ namespace TreeDist { // computation. max_tips should be >= the largest tree size used. void init_lg2_tables(int max_tips); + [[nodiscard]] inline double lg2_lookup(split_int x) noexcept { + if (x <= static_cast(SL_MAX_TIPS)) { + return lg2[x]; + } + return std::log2(static_cast(x)); + } + + [[nodiscard]] inline double lg2_unrooted_lookup(split_int n_tips) noexcept { + if (n_tips <= static_cast(SL_MAX_TIPS + 1)) { + return lg2_unrooted[n_tips]; + } + if (n_tips < 3) { // LCOV_EXCL_START + return 0.0; // LCOV_EXCL_STOP + } + const double n = static_cast(n_tips); + // log2((2n - 5)!!) = log2((2n - 4)!) - (n - 2) - log2((n - 2)!) + return (std::lgamma((2.0 * n) - 3.0) - std::lgamma(n - 1.0)) * LOG2_E + - (n - 2.0); + } + + [[nodiscard]] inline double lg2_rooted_lookup(split_int n_tips) noexcept { + if (n_tips <= static_cast(SL_MAX_TIPS + 1)) { + return lg2_rooted[n_tips]; + } + if (n_tips < 2) { // LCOV_EXCL_START + return 0.0; // LCOV_EXCL_STOP + } + const double n = static_cast(n_tips); + // log2((2n - 3)!!) = log2((2n - 2)!) - (n - 1) - log2((n - 1)!) + return (std::lgamma((2.0 * n) - 1.0) - std::lgamma(n)) * LOG2_E + - (n - 1.0); + } + // ---- Inline helpers ---- // Information content of a perfectly-matching split pair. // ic_matching(a, b, n) = (a + b) * lg2[n] - a * lg2[a] - b * lg2[b] - [[nodiscard]] inline double ic_matching(int16 a, int16 b, - int16 n) noexcept { - const double lg2a = lg2[a]; - const double lg2b = lg2[b]; - const double lg2n = lg2[n]; + [[nodiscard]] inline double ic_matching(split_int a, split_int b, + split_int n) noexcept { + const double lg2a = lg2_lookup(a); + const double lg2b = lg2_lookup(b); + const double lg2n = lg2_lookup(n); return (a + b) * lg2n - a * lg2a - b * lg2b; } @@ -77,9 +114,9 @@ namespace TreeDist { // Implementation in mutual_clustering_impl.h. double mutual_clustering_score( - const splitbit* const* a_state, const int16* a_in, int16 a_n_splits, - const splitbit* const* b_state, const int16* b_in, int16 b_n_splits, - int16 n_bins, int32 n_tips, + const splitbit* const* a_state, const split_int* a_in, split_int a_n_splits, + const splitbit* const* b_state, const split_int* b_in, split_int b_n_splits, + split_int n_bins, int32 n_tips, LapScratch& scratch); } // namespace TreeDist diff --git a/inst/include/TreeDist/mutual_clustering_impl.h b/inst/include/TreeDist/mutual_clustering_impl.h index 7f98c8fb..884a18b7 100644 --- a/inst/include/TreeDist/mutual_clustering_impl.h +++ b/inst/include/TreeDist/mutual_clustering_impl.h @@ -68,17 +68,17 @@ void init_lg2_tables(int max_tips) { namespace detail { -static int16 find_exact_matches_raw( - const splitbit* const* a_state, const int16* /*a_in*/, int16 a_n, - const splitbit* const* b_state, const int16* /*b_in*/, int16 b_n, - int16 n_bins, int32 n_tips, - int16* a_match, int16* b_match) +static split_int find_exact_matches_raw( + const splitbit* const* a_state, const split_int* /*a_in*/, split_int a_n, + const splitbit* const* b_state, const split_int* /*b_in*/, split_int b_n, + split_int n_bins, int32 n_tips, + split_int* a_match, split_int* b_match) { - std::fill(a_match, a_match + a_n, int16(0)); - std::fill(b_match, b_match + b_n, int16(0)); + std::fill(a_match, a_match + a_n, split_int(0)); + std::fill(b_match, b_match + b_n, split_int(0)); if (a_n == 0 || b_n == 0) return 0; - const int16 last_bin = n_bins - 1; + const split_int last_bin = n_bins - 1; const splitbit last_mask = (n_tips % SL_BIN_SIZE == 0) ? ~splitbit(0) : (splitbit(1) << (n_tips % SL_BIN_SIZE)) - 1; @@ -87,17 +87,17 @@ static int16 find_exact_matches_raw( std::vector a_canon(static_cast(a_n) * n_bins); std::vector b_canon(static_cast(b_n) * n_bins); - for (int16 i = 0; i < a_n; ++i) { + for (split_int i = 0; i < a_n; ++i) { const bool flip = !(a_state[i][0] & 1); - for (int16 bin = 0; bin < n_bins; ++bin) { + for (split_int bin = 0; bin < n_bins; ++bin) { splitbit val = flip ? ~a_state[i][bin] : a_state[i][bin]; if (bin == last_bin) val &= last_mask; a_canon[i * n_bins + bin] = val; } } - for (int16 i = 0; i < b_n; ++i) { + for (split_int i = 0; i < b_n; ++i) { const bool flip = !(b_state[i][0] & 1); - for (int16 bin = 0; bin < n_bins; ++bin) { + for (split_int bin = 0; bin < n_bins; ++bin) { splitbit val = flip ? ~b_state[i][bin] : b_state[i][bin]; if (bin == last_bin) val &= last_mask; b_canon[i * n_bins + bin] = val; @@ -105,8 +105,8 @@ static int16 find_exact_matches_raw( } // Sort index arrays by canonical form - auto canon_less = [&](const splitbit* canon, int16 n_b, int16 i, int16 j) { - for (int16 bin = 0; bin < n_b; ++bin) { + auto canon_less = [&](const splitbit* canon, split_int n_b, split_int i, split_int j) { + for (split_int bin = 0; bin < n_b; ++bin) { const splitbit vi = canon[i * n_b + bin]; const splitbit vj = canon[j * n_b + bin]; if (vi < vj) return true; @@ -115,28 +115,28 @@ static int16 find_exact_matches_raw( return false; }; - std::vector a_order(a_n), b_order(b_n); - std::iota(a_order.begin(), a_order.end(), int16(0)); - std::iota(b_order.begin(), b_order.end(), int16(0)); + std::vector a_order(a_n), b_order(b_n); + std::iota(a_order.begin(), a_order.end(), split_int(0)); + std::iota(b_order.begin(), b_order.end(), split_int(0)); std::sort(a_order.begin(), a_order.end(), - [&](int16 i, int16 j) { + [&](split_int i, split_int j) { return canon_less(a_canon.data(), n_bins, i, j); }); std::sort(b_order.begin(), b_order.end(), - [&](int16 i, int16 j) { + [&](split_int i, split_int j) { return canon_less(b_canon.data(), n_bins, i, j); }); // Merge-scan - int16 exact_n = 0; - int16 ai_pos = 0, bi_pos = 0; + split_int exact_n = 0; + split_int ai_pos = 0, bi_pos = 0; while (ai_pos < a_n && bi_pos < b_n) { - const int16 ai = a_order[ai_pos]; - const int16 bi = b_order[bi_pos]; + const split_int ai = a_order[ai_pos]; + const split_int bi = b_order[bi_pos]; int cmp = 0; - for (int16 bin = 0; bin < n_bins; ++bin) { + for (split_int bin = 0; bin < n_bins; ++bin) { const splitbit va = a_canon[ai * n_bins + bin]; const splitbit vb = b_canon[bi * n_bins + bin]; if (va < vb) { cmp = -1; break; } @@ -165,41 +165,41 @@ static int16 find_exact_matches_raw( // ---- MCI score implementation ---- double mutual_clustering_score( - const splitbit* const* a_state, const int16* a_in, int16 a_n_splits, - const splitbit* const* b_state, const int16* b_in, int16 b_n_splits, - int16 n_bins, int32 n_tips, + const splitbit* const* a_state, const split_int* a_in, split_int a_n_splits, + const splitbit* const* b_state, const split_int* b_in, split_int b_n_splits, + split_int n_bins, int32 n_tips, LapScratch& scratch) { if (a_n_splits == 0 || b_n_splits == 0 || n_tips == 0) return 0.0; - const int16 most_splits = std::max(a_n_splits, b_n_splits); + const split_int most_splits = std::max(a_n_splits, b_n_splits); const double n_tips_rcp = 1.0 / static_cast(n_tips); constexpr cost max_score = BIG; constexpr double over_max = 1.0 / static_cast(BIG); const double max_over_tips = static_cast(BIG) * n_tips_rcp; - const double lg2_n = lg2[n_tips]; + const double lg2_n = lg2_lookup(static_cast(n_tips)); // --- Phase 1: O(n log n) exact-match detection --- - std::vector a_match_buf(a_n_splits); - std::vector b_match_buf(b_n_splits); + std::vector a_match_buf(a_n_splits); + std::vector b_match_buf(b_n_splits); - const int16 exact_n = detail::find_exact_matches_raw( + const split_int exact_n = detail::find_exact_matches_raw( a_state, a_in, a_n_splits, b_state, b_in, b_n_splits, n_bins, n_tips, a_match_buf.data(), b_match_buf.data()); - const int16* a_match = a_match_buf.data(); - const int16* b_match = b_match_buf.data(); + const split_int* a_match = a_match_buf.data(); + const split_int* b_match = b_match_buf.data(); // Accumulate exact-match score double exact_score = 0.0; - for (int16 ai = 0; ai < a_n_splits; ++ai) { + for (split_int ai = 0; ai < a_n_splits; ++ai) { if (a_match[ai]) { - const int16 na = a_in[ai]; - const int16 nA = static_cast(n_tips - na); - exact_score += ic_matching(na, nA, static_cast(n_tips)); + const split_int na = a_in[ai]; + const split_int nA = static_cast(n_tips - na); + exact_score += ic_matching(na, nA, static_cast(n_tips)); } } @@ -209,58 +209,58 @@ double mutual_clustering_score( } // --- Phase 2: fill cost matrix for unmatched splits only (O(k²)) --- - const int16 lap_n = most_splits - exact_n; + const split_int lap_n = most_splits - exact_n; - std::vector a_unmatch, b_unmatch; + std::vector a_unmatch, b_unmatch; a_unmatch.reserve(lap_n); b_unmatch.reserve(lap_n); - for (int16 ai = 0; ai < a_n_splits; ++ai) { + for (split_int ai = 0; ai < a_n_splits; ++ai) { if (!a_match[ai]) a_unmatch.push_back(ai); } - for (int16 bi = 0; bi < b_n_splits; ++bi) { + for (split_int bi = 0; bi < b_n_splits; ++bi) { if (!b_match[bi]) b_unmatch.push_back(bi); } scratch.score_pool.resize(lap_n); CostMatrix& score = scratch.score_pool; - const int16 a_unmatched_n = static_cast(a_unmatch.size()); - const int16 b_unmatched_n = static_cast(b_unmatch.size()); + const split_int a_unmatched_n = static_cast(a_unmatch.size()); + const split_int b_unmatched_n = static_cast(b_unmatch.size()); - for (int16 a_pos = 0; a_pos < a_unmatched_n; ++a_pos) { - const int16 ai = a_unmatch[a_pos]; - const int16 na = a_in[ai]; - const int16 nA = static_cast(n_tips - na); + for (split_int a_pos = 0; a_pos < a_unmatched_n; ++a_pos) { + const split_int ai = a_unmatch[a_pos]; + const split_int na = a_in[ai]; + const split_int nA = static_cast(n_tips - na); const splitbit* a_row = a_state[ai]; - const double offset_a = lg2_n - lg2[na]; - const double offset_A = lg2_n - lg2[nA]; + const double offset_a = lg2_n - lg2_lookup(na); + const double offset_A = lg2_n - lg2_lookup(nA); - for (int16 b_pos = 0; b_pos < b_unmatched_n; ++b_pos) { - const int16 bi = b_unmatch[b_pos]; + for (split_int b_pos = 0; b_pos < b_unmatched_n; ++b_pos) { + const split_int bi = b_unmatch[b_pos]; const splitbit* b_row = b_state[bi]; - int16 a_and_b = 0; - for (int16 bin = 0; bin < n_bins; ++bin) { + split_int a_and_b = 0; + for (split_int bin = 0; bin < n_bins; ++bin) { a_and_b += TreeTools::count_bits(a_row[bin] & b_row[bin]); } - const int16 nb = b_in[bi]; - const int16 nB = static_cast(n_tips - nb); - const int16 a_and_B = na - a_and_b; - const int16 A_and_b = nb - a_and_b; - const int16 A_and_B = nA - A_and_b; + const split_int nb = b_in[bi]; + const split_int nB = static_cast(n_tips - nb); + const split_int a_and_B = na - a_and_b; + const split_int A_and_b = nb - a_and_b; + const split_int A_and_B = nA - A_and_b; if (a_and_b == A_and_b && a_and_b == a_and_B && a_and_b == A_and_B) { score(a_pos, b_pos) = max_score; } else { - const double lg2_nb = lg2[nb]; - const double lg2_nB = lg2[nB]; + const double lg2_nb = lg2_lookup(nb); + const double lg2_nB = lg2_lookup(nB); const double ic_sum = - a_and_b * (lg2[a_and_b] + offset_a - lg2_nb) + - a_and_B * (lg2[a_and_B] + offset_a - lg2_nB) + - A_and_b * (lg2[A_and_b] + offset_A - lg2_nb) + - A_and_B * (lg2[A_and_B] + offset_A - lg2_nB); + a_and_b * (lg2_lookup(a_and_b) + offset_a - lg2_nb) + + a_and_B * (lg2_lookup(a_and_B) + offset_a - lg2_nB) + + A_and_b * (lg2_lookup(A_and_b) + offset_A - lg2_nb) + + A_and_B * (lg2_lookup(A_and_B) + offset_A - lg2_nB); score(a_pos, b_pos) = max_score - static_cast(ic_sum * max_over_tips); } diff --git a/inst/include/TreeDist/types.h b/inst/include/TreeDist/types.h index c953989b..280f218d 100644 --- a/inst/include/TreeDist/types.h +++ b/inst/include/TreeDist/types.h @@ -16,6 +16,9 @@ namespace TreeDist { using int32 = int_fast32_t; using cost = int_fast64_t; + // Canonical type for split/tip/bin counters. + using split_int = int32; + using lap_dim = int; using lap_row = lap_dim; using lap_col = lap_dim; diff --git a/man/dot-TreeDistance.Rd b/man/dot-TreeDistance.Rd index a704ea42..5df43547 100644 --- a/man/dot-TreeDistance.Rd +++ b/man/dot-TreeDistance.Rd @@ -14,7 +14,7 @@ avoid crashes in C++.} Calculate distance between trees, or lists of trees } \seealso{ -\code{\link{CalculateTreeDistance}} +\code{\link[=CalculateTreeDistance]{CalculateTreeDistance()}} } \author{ \href{https://orcid.org/0000-0001-5660-1727}{Martin R. Smith} diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index be5b8963..d803056c 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -636,6 +636,16 @@ BEGIN_RCPP return rcpp_result_gen; END_RCPP } +// cpp_max_tips +int cpp_max_tips(); +RcppExport SEXP _TreeDist_cpp_max_tips() { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + rcpp_result_gen = Rcpp::wrap(cpp_max_tips()); + return rcpp_result_gen; +END_RCPP +} // cpp_robinson_foulds_distance List cpp_robinson_foulds_distance(const RawMatrix& x, const RawMatrix& y, const IntegerVector& nTip); RcppExport SEXP _TreeDist_cpp_robinson_foulds_distance(SEXP xSEXP, SEXP ySEXP, SEXP nTipSEXP) { @@ -780,6 +790,7 @@ static const R_CallMethodDef CallEntries[] = { {"_TreeDist_cpp_transfer_dist_all_pairs", (DL_FUNC) &_TreeDist_cpp_transfer_dist_all_pairs, 4}, {"_TreeDist_cpp_transfer_dist_cross_pairs", (DL_FUNC) &_TreeDist_cpp_transfer_dist_cross_pairs, 5}, {"_TreeDist_cpp_mci_impl_score", (DL_FUNC) &_TreeDist_cpp_mci_impl_score, 3}, + {"_TreeDist_cpp_max_tips", (DL_FUNC) &_TreeDist_cpp_max_tips, 0}, {"_TreeDist_cpp_robinson_foulds_distance", (DL_FUNC) &_TreeDist_cpp_robinson_foulds_distance, 3}, {"_TreeDist_cpp_robinson_foulds_info", (DL_FUNC) &_TreeDist_cpp_robinson_foulds_info, 3}, {"_TreeDist_cpp_matching_split_distance", (DL_FUNC) &_TreeDist_cpp_matching_split_distance, 3}, diff --git a/src/ints.h b/src/ints.h index 50fd3cc7..3d424c4c 100644 --- a/src/ints.h +++ b/src/ints.h @@ -6,6 +6,7 @@ // Re-export shared types to global scope for backward compatibility. using TreeDist::int16; using TreeDist::int32; +using TreeDist::split_int; // Types used only within TreeDist's own source. using uint32 = uint_fast32_t; diff --git a/src/nni_distance.cpp b/src/nni_distance.cpp index db3a5b3c..8712670e 100644 --- a/src/nni_distance.cpp +++ b/src/nni_distance.cpp @@ -288,8 +288,14 @@ grf_match nni_rf_matching ( IntegerVector cpp_nni_distance(const IntegerMatrix& edge1, const IntegerMatrix& edge2, const IntegerVector& nTip) { - - ASSERT(nTip[0] <= NNI_MAX_TIPS && "Cannot calculate NNI distance for trees with so many tips."); + + if (nTip[0] > NNI_MAX_TIPS) { + Rcpp::stop("Trees with > %d tips are not yet supported for NNI.", + NNI_MAX_TIPS); + } + if (nTip[0] < 0) { + Rcpp::stop("Requested nTip = %d is invalid for NNI.", nTip[0]); + } const int32_t n_tip = static_cast(nTip[0]); const int32_t node_0 = n_tip; const int32_t node_0_r = n_tip + 1; diff --git a/src/pairwise_distances.cpp b/src/pairwise_distances.cpp index c66b770f..28486dd0 100644 --- a/src/pairwise_distances.cpp +++ b/src/pairwise_distances.cpp @@ -36,10 +36,10 @@ using TreeTools::count_bits; struct MatchScratch { std::vector a_canon; std::vector b_canon; - std::vector a_order; - std::vector b_order; - std::vector a_match; - std::vector b_match; + std::vector a_order; + std::vector b_order; + std::vector a_match; + std::vector b_match; }; // --------------------------------------------------------------------------- @@ -59,19 +59,19 @@ struct MatchScratch { // a_match[ai] = bi+1 if split ai matched split bi, else 0. // b_match[bi] = ai+1 if split bi matched split ai, else 0. // --------------------------------------------------------------------------- -static int16 find_exact_matches( +static split_int find_exact_matches( const SplitList& a, const SplitList& b, const int32 n_tips, MatchScratch& scratch ) { - const int16 n_bins = a.n_bins; - const int16 last_bin = n_bins - 1; + const split_int n_bins = a.n_bins; + const split_int last_bin = n_bins - 1; const splitbit last_mask = (n_tips % SL_BIN_SIZE == 0) ? ~splitbit(0) : (splitbit(1) << (n_tips % SL_BIN_SIZE)) - 1; - const int16 a_n = a.n_splits; - const int16 b_n = b.n_splits; + const split_int a_n = a.n_splits; + const split_int b_n = b.n_splits; // Ensure buffers are large enough (grow lazily, never shrink) const size_t a_canon_sz = static_cast(a_n) * n_bins; @@ -83,10 +83,10 @@ static int16 find_exact_matches( if (scratch.a_match.size() < static_cast(a_n)) scratch.a_match.resize(a_n); if (scratch.b_match.size() < static_cast(b_n)) scratch.b_match.resize(b_n); - int16* a_match = scratch.a_match.data(); - int16* b_match = scratch.b_match.data(); - std::fill(a_match, a_match + a_n, int16(0)); - std::fill(b_match, b_match + b_n, int16(0)); + split_int* a_match = scratch.a_match.data(); + split_int* b_match = scratch.b_match.data(); + std::fill(a_match, a_match + a_n, split_int(0)); + std::fill(b_match, b_match + b_n, split_int(0)); if (a_n == 0 || b_n == 0) return 0; @@ -94,17 +94,17 @@ static int16 find_exact_matches( splitbit* b_canon = scratch.b_canon.data(); // --- 1. Compute canonical forms into flat buffers --- - for (int16 i = 0; i < a_n; ++i) { + for (split_int i = 0; i < a_n; ++i) { const bool flip = !(a.state[i][0] & 1); - for (int16 bin = 0; bin < n_bins; ++bin) { + for (split_int bin = 0; bin < n_bins; ++bin) { splitbit val = flip ? ~a.state[i][bin] : a.state[i][bin]; if (bin == last_bin) val &= last_mask; a_canon[i * n_bins + bin] = val; } } - for (int16 i = 0; i < b_n; ++i) { + for (split_int i = 0; i < b_n; ++i) { const bool flip = !(b.state[i][0] & 1); - for (int16 bin = 0; bin < n_bins; ++bin) { + for (split_int bin = 0; bin < n_bins; ++bin) { splitbit val = flip ? ~b.state[i][bin] : b.state[i][bin]; if (bin == last_bin) val &= last_mask; b_canon[i * n_bins + bin] = val; @@ -112,8 +112,8 @@ static int16 find_exact_matches( } // --- 2. Sort index arrays by canonical form --- - auto canon_less = [&](const splitbit* canon, int16 i, int16 j) { - for (int16 bin = 0; bin < n_bins; ++bin) { + auto canon_less = [&](const splitbit* canon, split_int i, split_int j) { + for (split_int bin = 0; bin < n_bins; ++bin) { const splitbit vi = canon[i * n_bins + bin]; const splitbit vj = canon[j * n_bins + bin]; if (vi < vj) return true; @@ -122,29 +122,29 @@ static int16 find_exact_matches( return false; // #nocov }; - int16* a_order = scratch.a_order.data(); - int16* b_order = scratch.b_order.data(); - std::iota(a_order, a_order + a_n, int16(0)); - std::iota(b_order, b_order + b_n, int16(0)); + split_int* a_order = scratch.a_order.data(); + split_int* b_order = scratch.b_order.data(); + std::iota(a_order, a_order + a_n, split_int(0)); + std::iota(b_order, b_order + b_n, split_int(0)); std::sort(a_order, a_order + a_n, - [&](int16 i, int16 j) { + [&](split_int i, split_int j) { return canon_less(a_canon, i, j); }); std::sort(b_order, b_order + b_n, - [&](int16 i, int16 j) { + [&](split_int i, split_int j) { return canon_less(b_canon, i, j); }); // --- 3. Merge-scan to find matches --- - int16 exact_n = 0; - int16 ai_pos = 0, bi_pos = 0; + split_int exact_n = 0; + split_int ai_pos = 0, bi_pos = 0; while (ai_pos < a_n && bi_pos < b_n) { - const int16 ai = a_order[ai_pos]; - const int16 bi = b_order[bi_pos]; + const split_int ai = a_order[ai_pos]; + const split_int bi = b_order[bi_pos]; int cmp = 0; - for (int16 bin = 0; bin < n_bins; ++bin) { + for (split_int bin = 0; bin < n_bins; ++bin) { const splitbit va = a_canon[ai * n_bins + bin]; const splitbit vb = b_canon[bi * n_bins + bin]; if (va < vb) { cmp = -1; break; } @@ -178,25 +178,25 @@ static double mutual_clustering_score( ) { if (a.n_splits == 0 || b.n_splits == 0 || n_tips == 0) return 0.0; - const int16 most_splits = std::max(a.n_splits, b.n_splits); + const split_int most_splits = std::max(a.n_splits, b.n_splits); const double n_tips_rcp = 1.0 / static_cast(n_tips); constexpr cost max_score = BIG; constexpr double over_max = 1.0 / static_cast(BIG); const double max_over_tips = static_cast(BIG) * n_tips_rcp; - const double lg2_n = lg2[n_tips]; + const double lg2_n = TreeDist::lg2_lookup(n_tips); // --- Phase 1: O(n log n) exact-match detection --- - const int16 exact_n = find_exact_matches(a, b, n_tips, mscratch); - const int16* a_match = mscratch.a_match.data(); - const int16* b_match = mscratch.b_match.data(); + const split_int exact_n = find_exact_matches(a, b, n_tips, mscratch); + const split_int* a_match = mscratch.a_match.data(); + const split_int* b_match = mscratch.b_match.data(); // Accumulate exact-match score double exact_score = 0.0; - for (int16 ai = 0; ai < a.n_splits; ++ai) { + for (split_int ai = 0; ai < a.n_splits; ++ai) { if (a_match[ai]) { - const int16 na = a.in_split[ai]; - const int16 nA = n_tips - na; + const split_int na = a.in_split[ai]; + const split_int nA = n_tips - na; exact_score += TreeDist::ic_matching(na, nA, n_tips); } } @@ -207,58 +207,58 @@ static double mutual_clustering_score( } // --- Phase 2: fill cost matrix for unmatched splits only (O(k²)) --- - const int16 lap_n = most_splits - exact_n; + const split_int lap_n = most_splits - exact_n; // Build index maps for unmatched splits - std::vector a_unmatch, b_unmatch; + std::vector a_unmatch, b_unmatch; a_unmatch.reserve(lap_n); b_unmatch.reserve(lap_n); - for (int16 ai = 0; ai < a.n_splits; ++ai) { + for (split_int ai = 0; ai < a.n_splits; ++ai) { if (!a_match[ai]) a_unmatch.push_back(ai); } - for (int16 bi = 0; bi < b.n_splits; ++bi) { + for (split_int bi = 0; bi < b.n_splits; ++bi) { if (!b_match[bi]) b_unmatch.push_back(bi); } scratch.score_pool.resize(lap_n); cost_matrix& score = scratch.score_pool; - const int16 a_unmatched_n = static_cast(a_unmatch.size()); - const int16 b_unmatched_n = static_cast(b_unmatch.size()); + const split_int a_unmatched_n = static_cast(a_unmatch.size()); + const split_int b_unmatched_n = static_cast(b_unmatch.size()); - for (int16 a_pos = 0; a_pos < a_unmatched_n; ++a_pos) { - const int16 ai = a_unmatch[a_pos]; - const int16 na = a.in_split[ai]; - const int16 nA = n_tips - na; + for (split_int a_pos = 0; a_pos < a_unmatched_n; ++a_pos) { + const split_int ai = a_unmatch[a_pos]; + const split_int na = a.in_split[ai]; + const split_int nA = n_tips - na; const auto* a_row = a.state[ai]; - const double offset_a = lg2_n - lg2[na]; - const double offset_A = lg2_n - lg2[nA]; + const double offset_a = lg2_n - TreeDist::lg2_lookup(na); + const double offset_A = lg2_n - TreeDist::lg2_lookup(nA); - for (int16 b_pos = 0; b_pos < b_unmatched_n; ++b_pos) { - const int16 bi = b_unmatch[b_pos]; + for (split_int b_pos = 0; b_pos < b_unmatched_n; ++b_pos) { + const split_int bi = b_unmatch[b_pos]; const auto* b_row = b.state[bi]; - int16 a_and_b = 0; - for (int16 bin = 0; bin < a.n_bins; ++bin) { + split_int a_and_b = 0; + for (split_int bin = 0; bin < a.n_bins; ++bin) { a_and_b += count_bits(a_row[bin] & b_row[bin]); } - const int16 nb = b.in_split[bi]; - const int16 nB = n_tips - nb; - const int16 a_and_B = na - a_and_b; - const int16 A_and_b = nb - a_and_b; - const int16 A_and_B = nA - A_and_b; + const split_int nb = b.in_split[bi]; + const split_int nB = n_tips - nb; + const split_int a_and_B = na - a_and_b; + const split_int A_and_b = nb - a_and_b; + const split_int A_and_B = nA - A_and_b; if (a_and_b == A_and_b && a_and_b == a_and_B && a_and_b == A_and_B) { score(a_pos, b_pos) = max_score; } else { - const double lg2_nb = lg2[nb]; - const double lg2_nB = lg2[nB]; + const double lg2_nb = TreeDist::lg2_lookup(nb); + const double lg2_nB = TreeDist::lg2_lookup(nB); const double ic_sum = - a_and_b * (lg2[a_and_b] + offset_a - lg2_nb) + - a_and_B * (lg2[a_and_B] + offset_a - lg2_nB) + - A_and_b * (lg2[A_and_b] + offset_A - lg2_nb) + - A_and_B * (lg2[A_and_B] + offset_A - lg2_nB); + a_and_b * (TreeDist::lg2_lookup(a_and_b) + offset_a - lg2_nb) + + a_and_B * (TreeDist::lg2_lookup(a_and_B) + offset_a - lg2_nB) + + A_and_b * (TreeDist::lg2_lookup(A_and_b) + offset_A - lg2_nb) + + A_and_B * (TreeDist::lg2_lookup(A_and_B) + offset_A - lg2_nB); score(a_pos, b_pos) = max_score - static_cast(ic_sum * max_over_tips); } } @@ -305,6 +305,7 @@ NumericVector cpp_mutual_clustering_all_pairs( const int n_tip, const int n_threads = 1 ) { + TreeDist::check_ntip(static_cast(n_tip)); const int N = splits_list.size(); if (N < 2) return NumericVector(0); @@ -365,27 +366,27 @@ static double rf_info_score( const SplitList& a, const SplitList& b, const int32 n_tips, MatchScratch& mscratch ) { - const int16 a_n = a.n_splits; - const int16 b_n = b.n_splits; + const split_int a_n = a.n_splits; + const split_int b_n = b.n_splits; if (a_n == 0 || b_n == 0) return 0; // Use sort+merge to find exact matches in O(n log n) - const int16 exact_n = find_exact_matches(a, b, n_tips, mscratch); + const split_int exact_n = find_exact_matches(a, b, n_tips, mscratch); if (exact_n == 0) return 0; // Sum info contribution for each matched split in a - const int16* a_match = mscratch.a_match.data(); - const double lg2_unrooted_n = lg2_unrooted[n_tips]; + const split_int* a_match = mscratch.a_match.data(); + const double lg2_unrooted_n = TreeDist::lg2_unrooted_lookup(n_tips); double score = 0; - for (int16 ai = 0; ai < a_n; ++ai) { + for (split_int ai = 0; ai < a_n; ++ai) { if (a_match[ai] == 0) continue; - int16 leaves_in_split = 0; - for (int16 bin = 0; bin < a.n_bins; ++bin) { + split_int leaves_in_split = 0; + for (split_int bin = 0; bin < a.n_bins; ++bin) { leaves_in_split += count_bits(a.state[ai][bin]); } score += lg2_unrooted_n - - lg2_rooted[leaves_in_split] - - lg2_rooted[n_tips - leaves_in_split]; + - TreeDist::lg2_rooted_lookup(leaves_in_split) + - TreeDist::lg2_rooted_lookup(n_tips - leaves_in_split); } return score; } @@ -397,6 +398,7 @@ NumericVector cpp_rf_info_all_pairs( const int n_tip, const int n_threads = 1 ) { + TreeDist::check_ntip(static_cast(n_tip)); const int N = splits_list.size(); if (N < 2) return NumericVector(0); const int n_pairs = N * (N - 1) / 2; @@ -444,48 +446,48 @@ static double msd_score( const SplitList& a, const SplitList& b, const int32 n_tips, LapScratch& scratch, MatchScratch& mscratch ) { - const int16 most_splits = std::max(a.n_splits, b.n_splits); + const split_int most_splits = std::max(a.n_splits, b.n_splits); if (most_splits == 0) return 0.0; const bool a_has_more = (a.n_splits > b.n_splits); - const int16 a_extra = a_has_more ? most_splits - b.n_splits : 0; - const int16 b_extra = a_has_more ? 0 : most_splits - a.n_splits; - const int16 half_tips = n_tips / 2; + const split_int a_extra = a_has_more ? most_splits - b.n_splits : 0; + const split_int b_extra = a_has_more ? 0 : most_splits - a.n_splits; + const split_int half_tips = n_tips / 2; const cost max_score = BIG / most_splits; // --- Phase 1: O(n log n) exact-match detection --- - const int16 exact_n = find_exact_matches(a, b, n_tips, mscratch); - const int16* a_match = mscratch.a_match.data(); - const int16* b_match = mscratch.b_match.data(); + const split_int exact_n = find_exact_matches(a, b, n_tips, mscratch); + const split_int* a_match = mscratch.a_match.data(); + const split_int* b_match = mscratch.b_match.data(); if (exact_n == b.n_splits || exact_n == a.n_splits) { return 0.0; } // --- Phase 2: fill cost matrix for unmatched splits only --- - const int16 lap_n = most_splits - exact_n; + const split_int lap_n = most_splits - exact_n; - std::vector a_unmatch, b_unmatch; + std::vector a_unmatch, b_unmatch; a_unmatch.reserve(lap_n); b_unmatch.reserve(lap_n); - for (int16 ai = 0; ai < a.n_splits; ++ai) { + for (split_int ai = 0; ai < a.n_splits; ++ai) { if (!a_match[ai]) a_unmatch.push_back(ai); } - for (int16 bi = 0; bi < b.n_splits; ++bi) { + for (split_int bi = 0; bi < b.n_splits; ++bi) { if (!b_match[bi]) b_unmatch.push_back(bi); } scratch.score_pool.resize(lap_n); cost_matrix& score = scratch.score_pool; - const int16 a_unmatched_n = static_cast(a_unmatch.size()); - const int16 b_unmatched_n = static_cast(b_unmatch.size()); + const split_int a_unmatched_n = static_cast(a_unmatch.size()); + const split_int b_unmatched_n = static_cast(b_unmatch.size()); - for (int16 a_pos = 0; a_pos < a_unmatched_n; ++a_pos) { - const int16 ai = a_unmatch[a_pos]; - for (int16 b_pos = 0; b_pos < b_unmatched_n; ++b_pos) { - const int16 bi = b_unmatch[b_pos]; + for (split_int a_pos = 0; a_pos < a_unmatched_n; ++a_pos) { + const split_int ai = a_unmatch[a_pos]; + for (split_int b_pos = 0; b_pos < b_unmatched_n; ++b_pos) { + const split_int bi = b_unmatch[b_pos]; splitbit total = 0; - for (int16 bin = 0; bin < a.n_bins; ++bin) { + for (split_int bin = 0; bin < a.n_bins; ++bin) { total += count_bits(a.state[ai][bin] ^ b.state[bi][bin]); } score(a_pos, b_pos) = static_cast( @@ -516,6 +518,7 @@ NumericVector cpp_msd_all_pairs( const int n_tip, const int n_threads = 1 ) { + TreeDist::check_ntip(static_cast(n_tip)); const int N = splits_list.size(); if (N < 2) return NumericVector(0); const int n_pairs = N * (N - 1) / 2; @@ -571,34 +574,58 @@ static double msi_score( const SplitList& a, const SplitList& b, const int32 n_tips, LapScratch& scratch ) { - const int16 most_splits = std::max(a.n_splits, b.n_splits); + const split_int most_splits = std::max(a.n_splits, b.n_splits); if (most_splits == 0) return 0.0; + const bool use_lookup_table = TreeDist::can_use_lookup_table(n_tips); constexpr cost max_score = BIG; - const double max_possible = lg2_unrooted[n_tips] - - lg2_rooted[int16((n_tips + 1) / 2)] - - lg2_rooted[int16(n_tips / 2)]; + const double max_possible = use_lookup_table + ? TreeDist::lg2_unrooted[n_tips] + - TreeDist::lg2_rooted[split_int((n_tips + 1) / 2)] + - TreeDist::lg2_rooted[split_int(n_tips / 2)] + : TreeDist::lg2_unrooted_lookup(n_tips) + - TreeDist::lg2_rooted_lookup(split_int((n_tips + 1) / 2)) + - TreeDist::lg2_rooted_lookup(split_int(n_tips / 2)); const double score_over_possible = static_cast(max_score) / max_possible; const double possible_over_score = max_possible / static_cast(max_score); scratch.score_pool.resize(most_splits); cost_matrix& score = scratch.score_pool; - for (int16 ai = 0; ai < a.n_splits; ++ai) { - for (int16 bi = 0; bi < b.n_splits; ++bi) { - int16 n_a_only = 0, n_a_and_b = 0, n_different = 0; - splitbit different; - for (int16 bin = 0; bin < a.n_bins; ++bin) { - different = a.state[ai][bin] ^ b.state[bi][bin]; - n_different += count_bits(different); - n_a_only += count_bits(a.state[ai][bin] & different); - n_a_and_b += count_bits(a.state[ai][bin] & ~different); + if (use_lookup_table) { + for (split_int ai = 0; ai < a.n_splits; ++ai) { + for (split_int bi = 0; bi < b.n_splits; ++bi) { + split_int n_a_only = 0, n_a_and_b = 0, n_different = 0; + splitbit different; + for (split_int bin = 0; bin < a.n_bins; ++bin) { + different = a.state[ai][bin] ^ b.state[bi][bin]; + n_different += count_bits(different); + n_a_only += count_bits(a.state[ai][bin] & different); + n_a_and_b += count_bits(a.state[ai][bin] & ~different); + } + const split_int n_same = n_tips - n_different; + score(ai, bi) = cost(max_score - score_over_possible * + TreeDist::mmsi_score_table(n_same, n_a_and_b, n_different, n_a_only)); + } + score.padRowAfterCol(ai, b.n_splits, max_score); + } + } else { + for (split_int ai = 0; ai < a.n_splits; ++ai) { + for (split_int bi = 0; bi < b.n_splits; ++bi) { + split_int n_a_only = 0, n_a_and_b = 0, n_different = 0; + splitbit different; + for (split_int bin = 0; bin < a.n_bins; ++bin) { + different = a.state[ai][bin] ^ b.state[bi][bin]; + n_different += count_bits(different); + n_a_only += count_bits(a.state[ai][bin] & different); + n_a_and_b += count_bits(a.state[ai][bin] & ~different); + } + const split_int n_same = n_tips - n_different; + score(ai, bi) = cost(max_score - score_over_possible * + TreeDist::mmsi_score(n_same, n_a_and_b, n_different, n_a_only)); } - const int16 n_same = n_tips - n_different; - score(ai, bi) = cost(max_score - score_over_possible * - TreeDist::mmsi_score(n_same, n_a_and_b, n_different, n_a_only)); + score.padRowAfterCol(ai, b.n_splits, max_score); } - score.padRowAfterCol(ai, b.n_splits, max_score); } score.padAfterRow(a.n_splits, max_score); @@ -618,6 +645,7 @@ NumericVector cpp_msi_all_pairs( const int n_tip, const int n_threads = 1 ) { + TreeDist::check_ntip(static_cast(n_tip)); const int N = splits_list.size(); if (N < 2) return NumericVector(0); const int n_pairs = N * (N - 1) / 2; @@ -669,28 +697,46 @@ static double shared_phylo_score( const SplitList& a, const SplitList& b, const int32 n_tips, LapScratch& scratch ) { - const int16 most_splits = std::max(a.n_splits, b.n_splits); + const split_int most_splits = std::max(a.n_splits, b.n_splits); if (most_splits == 0) return 0.0; + const bool use_lookup_table = TreeDist::can_use_lookup_table(n_tips); - const int16 overlap_a = int16(n_tips + 1) / 2; + const split_int overlap_a = split_int(n_tips + 1) / 2; constexpr cost max_score = BIG; - const double best_overlap = TreeDist::one_overlap(overlap_a, n_tips / 2, n_tips); - const double max_possible = lg2_unrooted[n_tips] - best_overlap; + const double best_overlap = use_lookup_table + ? TreeDist::one_overlap_table(overlap_a, n_tips / 2, n_tips) + : TreeDist::one_overlap(overlap_a, n_tips / 2, n_tips); + const double max_possible = (use_lookup_table + ? TreeDist::lg2_unrooted[n_tips] + : TreeDist::lg2_unrooted_lookup(n_tips)) - best_overlap; const double score_over_possible = static_cast(max_score) / max_possible; const double possible_over_score = max_possible / static_cast(max_score); scratch.score_pool.resize(most_splits); cost_matrix& score = scratch.score_pool; - for (int16 ai = 0; ai < a.n_splits; ++ai) { - for (int16 bi = 0; bi < b.n_splits; ++bi) { - const double spi = TreeDist::spi_overlap( - a.state[ai], b.state[bi], n_tips, - a.in_split[ai], b.in_split[bi], a.n_bins); - score(ai, bi) = (spi == 0.0) ? max_score - : cost((spi - best_overlap) * score_over_possible); + if (use_lookup_table) { + for (split_int ai = 0; ai < a.n_splits; ++ai) { + for (split_int bi = 0; bi < b.n_splits; ++bi) { + const double spi = TreeDist::spi_overlap_table( + a.state[ai], b.state[bi], n_tips, + a.in_split[ai], b.in_split[bi], a.n_bins); + score(ai, bi) = (spi == 0.0) ? max_score + : cost((spi - best_overlap) * score_over_possible); + } + score.padRowAfterCol(ai, b.n_splits, max_score); + } + } else { + for (split_int ai = 0; ai < a.n_splits; ++ai) { + for (split_int bi = 0; bi < b.n_splits; ++bi) { + const double spi = TreeDist::spi_overlap( + a.state[ai], b.state[bi], n_tips, + a.in_split[ai], b.in_split[bi], a.n_bins); + score(ai, bi) = (spi == 0.0) ? max_score + : cost((spi - best_overlap) * score_over_possible); + } + score.padRowAfterCol(ai, b.n_splits, max_score); } - score.padRowAfterCol(ai, b.n_splits, max_score); } score.padAfterRow(a.n_splits, max_score); @@ -710,6 +756,7 @@ NumericVector cpp_shared_phylo_all_pairs( const int n_tip, const int n_threads = 1 ) { + TreeDist::check_ntip(static_cast(n_tip)); const int N = splits_list.size(); if (N < 2) return NumericVector(0); const int n_pairs = N * (N - 1) / 2; @@ -758,7 +805,7 @@ static double jaccard_score( const double exponent, const bool allow_conflict, LapScratch& scratch, MatchScratch& mscratch ) { - const int16 most_splits = std::max(a.n_splits, b.n_splits); + const split_int most_splits = std::max(a.n_splits, b.n_splits); if (most_splits == 0) return 0.0; constexpr cost max_score = BIG; @@ -767,7 +814,7 @@ static double jaccard_score( // --- Phase 1: O(n log n) exact-match detection --- // Only used when allow_conflict=true; otherwise the full LAP may reassign // non-matching splits to compatible (non-exact) partners. - int16 exact_n = 0; + split_int exact_n = 0; if (allow_conflict) { exact_n = find_exact_matches(a, b, n_tips, mscratch); } else { @@ -776,61 +823,61 @@ static double jaccard_score( mscratch.a_match.resize(a.n_splits); if (mscratch.b_match.size() < static_cast(b.n_splits)) mscratch.b_match.resize(b.n_splits); - std::fill(mscratch.a_match.data(), mscratch.a_match.data() + a.n_splits, int16(0)); - std::fill(mscratch.b_match.data(), mscratch.b_match.data() + b.n_splits, int16(0)); + std::fill(mscratch.a_match.data(), mscratch.a_match.data() + a.n_splits, split_int(0)); + std::fill(mscratch.b_match.data(), mscratch.b_match.data() + b.n_splits, split_int(0)); } - const int16* a_match = mscratch.a_match.data(); - const int16* b_match = mscratch.b_match.data(); + const split_int* a_match = mscratch.a_match.data(); + const split_int* b_match = mscratch.b_match.data(); if (exact_n == b.n_splits || exact_n == a.n_splits) { return static_cast(exact_n); } // --- Phase 2: fill cost matrix for unmatched splits only --- - const int16 lap_n = most_splits - exact_n; + const split_int lap_n = most_splits - exact_n; - std::vector a_unmatch, b_unmatch; + std::vector a_unmatch, b_unmatch; a_unmatch.reserve(lap_n); b_unmatch.reserve(lap_n); - for (int16 ai = 0; ai < a.n_splits; ++ai) { + for (split_int ai = 0; ai < a.n_splits; ++ai) { if (!a_match[ai]) a_unmatch.push_back(ai); } - for (int16 bi = 0; bi < b.n_splits; ++bi) { + for (split_int bi = 0; bi < b.n_splits; ++bi) { if (!b_match[bi]) b_unmatch.push_back(bi); } scratch.score_pool.resize(lap_n); cost_matrix& score = scratch.score_pool; - const int16 a_unmatched_n = static_cast(a_unmatch.size()); - const int16 b_unmatched_n = static_cast(b_unmatch.size()); + const split_int a_unmatched_n = static_cast(a_unmatch.size()); + const split_int b_unmatched_n = static_cast(b_unmatch.size()); - for (int16 a_pos = 0; a_pos < a_unmatched_n; ++a_pos) { - const int16 ai = a_unmatch[a_pos]; - const int16 na = a.in_split[ai]; - const int16 nA = n_tips - na; + for (split_int a_pos = 0; a_pos < a_unmatched_n; ++a_pos) { + const split_int ai = a_unmatch[a_pos]; + const split_int na = a.in_split[ai]; + const split_int nA = n_tips - na; - for (int16 b_pos = 0; b_pos < b_unmatched_n; ++b_pos) { - const int16 bi = b_unmatch[b_pos]; - int16 a_and_b = 0; - for (int16 bin = 0; bin < a.n_bins; ++bin) { + for (split_int b_pos = 0; b_pos < b_unmatched_n; ++b_pos) { + const split_int bi = b_unmatch[b_pos]; + split_int a_and_b = 0; + for (split_int bin = 0; bin < a.n_bins; ++bin) { a_and_b += count_bits(a.state[ai][bin] & b.state[bi][bin]); } - const int16 nb = b.in_split[bi]; - const int16 nB = n_tips - nb; - const int16 a_and_B = na - a_and_b; - const int16 A_and_b = nb - a_and_b; - const int16 A_and_B = nB - a_and_B; + const split_int nb = b.in_split[bi]; + const split_int nB = n_tips - nb; + const split_int a_and_B = na - a_and_b; + const split_int A_and_b = nb - a_and_b; + const split_int A_and_B = nB - a_and_B; if (!allow_conflict && !( a_and_b == na || a_and_B == na || A_and_b == nA || A_and_B == nA)) { score(a_pos, b_pos) = max_score; } else { - const int16 A_or_b = n_tips - a_and_B; - const int16 a_or_B = n_tips - A_and_b; - const int16 a_or_b = n_tips - A_and_B; - const int16 A_or_B = n_tips - a_and_b; + const split_int A_or_b = n_tips - a_and_B; + const split_int a_or_B = n_tips - A_and_b; + const split_int a_or_b = n_tips - A_and_B; + const split_int A_or_B = n_tips - a_and_b; const double ars_ab = static_cast(a_and_b) / static_cast(a_or_b); const double ars_Ab = static_cast(A_and_b) / static_cast(A_or_b); const double ars_aB = static_cast(a_and_B) / static_cast(a_or_B); @@ -875,6 +922,7 @@ NumericVector cpp_jaccard_all_pairs( const bool allow_conflict = true, const int n_threads = 1 ) { + TreeDist::check_ntip(static_cast(n_tip)); const int N = splits_list.size(); if (N < 2) return NumericVector(0); const int n_pairs = N * (N - 1) / 2; @@ -944,6 +992,7 @@ NumericMatrix cpp_mutual_clustering_cross_pairs( const List& splits_a, const List& splits_b, const int n_tip, const int n_threads = 1 ) { + TreeDist::check_ntip(static_cast(n_tip)); const int nA = splits_a.size(); const int nB = splits_b.size(); if (nA == 0 || nB == 0) return NumericMatrix(nA, nB); @@ -987,6 +1036,7 @@ NumericMatrix cpp_rf_info_cross_pairs( const List& splits_a, const List& splits_b, const int n_tip, const int n_threads = 1 ) { + TreeDist::check_ntip(static_cast(n_tip)); const int nA = splits_a.size(); const int nB = splits_b.size(); if (nA == 0 || nB == 0) return NumericMatrix(nA, nB); @@ -1027,6 +1077,7 @@ NumericMatrix cpp_msd_cross_pairs( const List& splits_a, const List& splits_b, const int n_tip, const int n_threads = 1 ) { + TreeDist::check_ntip(static_cast(n_tip)); const int nA = splits_a.size(); const int nB = splits_b.size(); if (nA == 0 || nB == 0) return NumericMatrix(nA, nB); @@ -1070,6 +1121,7 @@ NumericMatrix cpp_msi_cross_pairs( const List& splits_a, const List& splits_b, const int n_tip, const int n_threads = 1 ) { + TreeDist::check_ntip(static_cast(n_tip)); const int nA = splits_a.size(); const int nB = splits_b.size(); if (nA == 0 || nB == 0) return NumericMatrix(nA, nB); @@ -1110,6 +1162,7 @@ NumericMatrix cpp_shared_phylo_cross_pairs( const List& splits_a, const List& splits_b, const int n_tip, const int n_threads = 1 ) { + TreeDist::check_ntip(static_cast(n_tip)); const int nA = splits_a.size(); const int nB = splits_b.size(); if (nA == 0 || nB == 0) return NumericMatrix(nA, nB); @@ -1153,6 +1206,7 @@ NumericMatrix cpp_jaccard_cross_pairs( const bool allow_conflict = true, const int n_threads = 1 ) { + TreeDist::check_ntip(static_cast(n_tip)); const int nA = splits_a.size(); const int nB = splits_b.size(); if (nA == 0 || nB == 0) return NumericMatrix(nA, nB); @@ -1206,6 +1260,7 @@ NumericVector cpp_clustering_entropy_batch( const List& splits_list, const int n_tip ) { + TreeDist::check_ntip(static_cast(n_tip)); const int N = splits_list.size(); NumericVector result(N); if (N == 0 || n_tip <= 0) return result; @@ -1216,7 +1271,7 @@ NumericVector cpp_clustering_entropy_batch( for (int i = 0; i < N; ++i) { SplitList sl(Rcpp::as(splits_list[i])); double total = 0.0; - for (int16 s = 0; s < sl.n_splits; ++s) { + for (split_int s = 0; s < sl.n_splits; ++s) { const int k = sl.in_split[s]; if (k <= 0 || k >= n_tip) continue; const double p = k * invN; @@ -1239,6 +1294,7 @@ NumericVector cpp_splitwise_info_batch( const List& splits_list, const int n_tip ) { + TreeDist::check_ntip(static_cast(n_tip)); const int N = splits_list.size(); NumericVector result(N); if (N == 0 || n_tip < 4) return result; @@ -1257,7 +1313,7 @@ NumericVector cpp_splitwise_info_batch( for (int i = 0; i < N; ++i) { SplitList sl(Rcpp::as(splits_list[i])); double total = 0.0; - for (int16 s = 0; s < sl.n_splits; ++s) { + for (split_int s = 0; s < sl.n_splits; ++s) { const int k = sl.in_split[s]; if (k < 2 || (n_tip - k) < 2) continue; total += l2u_n - l2r[k] - l2r[n_tip - k]; diff --git a/src/tree_distance_functions.cpp b/src/tree_distance_functions.cpp index 99347300..16aab75f 100644 --- a/src/tree_distance_functions.cpp +++ b/src/tree_distance_functions.cpp @@ -1,5 +1,7 @@ #include #include +#include +#include // Provide the MCI table definitions and implementation in this TU. #define TREEDIST_MCI_IMPLEMENTATION @@ -10,6 +12,7 @@ // Populate lookup tables at library load time. __attribute__((constructor)) void initialize_ldf() { + // Cache up to TreeTools' stack-threshold value. TreeDist::init_lg2_tables(SL_MAX_TIPS); } @@ -28,15 +31,15 @@ double cpp_mci_impl_score(const Rcpp::RawMatrix& x, // Build arrays matching the header's raw-pointer API types. std::vector a_ptrs(a.n_splits); std::vector b_ptrs(b.n_splits); - std::vector a_in(a.n_splits); - std::vector b_in(b.n_splits); - for (TreeDist::int16 i = 0; i < a.n_splits; ++i) { + std::vector a_in(a.n_splits); + std::vector b_in(b.n_splits); + for (TreeDist::int32 i = 0; i < a.n_splits; ++i) { a_ptrs[i] = a.state[i]; - a_in[i] = static_cast(a.in_split[i]); + a_in[i] = static_cast(a.in_split[i]); } - for (TreeDist::int16 i = 0; i < b.n_splits; ++i) { + for (TreeDist::int32 i = 0; i < b.n_splits; ++i) { b_ptrs[i] = b.state[i]; - b_in[i] = static_cast(b.in_split[i]); + b_in[i] = static_cast(b.in_split[i]); } return TreeDist::mutual_clustering_score( @@ -45,3 +48,21 @@ double cpp_mci_impl_score(const Rcpp::RawMatrix& x, a.n_bins, static_cast(n_tips), scratch); } + +// [[Rcpp::export]] +int cpp_max_tips() { + constexpr int int_limit = (std::numeric_limits::max)(); + constexpr auto split_int_limit = + (std::numeric_limits::max)(); + constexpr auto int32_limit = + (std::numeric_limits::max)(); + + int max_tips = int_limit; + if (split_int_limit < static_cast(max_tips)) { + max_tips = static_cast(split_int_limit); + } + if (int32_limit < static_cast(max_tips)) { + max_tips = static_cast(int32_limit); + } + return max_tips; +} diff --git a/src/tree_distances.cpp b/src/tree_distances.cpp index f79b48da..5d8b9067 100644 --- a/src/tree_distances.cpp +++ b/src/tree_distances.cpp @@ -29,17 +29,22 @@ namespace TreeDist { } } - void check_ntip(const double n) { - // SplitList dimensions are bounded by SL_MAX_TIPS, and current scoring - // paths use int16-sized counts internally. - static_assert(SL_MAX_TIPS <= std::numeric_limits::max(), - "SL_MAX_TIPS must fit in int32"); - constexpr int32 max_supported_tips = std::min( - int32(SL_MAX_TIPS), int32(std::numeric_limits::max()) - ); - if (n > max_supported_tips) { - Rcpp::stop("This many tips are not (yet) supported."); + void check_ntip(const int32 n) { + constexpr int64_t split_int_limit = + static_cast((std::numeric_limits::max)()); + constexpr int64_t max_supported_tips = split_int_limit; + + if (n < 0) { + Rcpp::stop("Requested nTip = %d is invalid.", n); } + + if (n > max_supported_tips) { // LCOV_EXCL_START + Rcpp::stop( + "Requested nTip = %d exceeds this TreeDist build limit (%d): " + "this many tips are not yet supported.", + n, static_cast(max_supported_tips) + ); + } // LCOV_EXCL_STOP } @@ -69,6 +74,7 @@ inline List robinson_foulds_distance(const RawMatrix &x, const RawMatrix &y, } for (int32 ai = a.n_splits; ai--; ) { + if ((ai & 1023) == 0) Rcpp::checkUserInterrupt(); for (int32 bi = b.n_splits; bi--; ) { bool all_match = true; @@ -104,38 +110,39 @@ inline List robinson_foulds_info(const RawMatrix &x, const RawMatrix &y, const int32 n_tips) { const SplitList a(x), b(y); - const int16 last_bin = a.n_bins - 1; - const int16 unset_tips = (n_tips % SL_BIN_SIZE) ? + const split_int last_bin = a.n_bins - 1; + const split_int unset_tips = (n_tips % SL_BIN_SIZE) ? SL_BIN_SIZE - n_tips % SL_BIN_SIZE : 0; const splitbit unset_mask = ALL_ONES >> unset_tips; - const double lg2_unrooted_n = lg2_unrooted[n_tips]; + const double lg2_unrooted_n = TreeDist::lg2_unrooted_lookup(n_tips); double score = 0; grf_match matching(a.n_splits, NA_INTEGER); // Heap-backed scratch avoids large fixed-size stack allocation. std::vector b_complement(size_t(b.n_splits) * size_t(a.n_bins)); - for (int16 i = 0; i < b.n_splits; i++) { - for (int16 bin = 0; bin < last_bin; ++bin) { + for (split_int i = 0; i < b.n_splits; i++) { + for (split_int bin = 0; bin < last_bin; ++bin) { b_complement[size_t(i) * a.n_bins + bin] = ~b.state[i][bin]; } b_complement[size_t(i) * a.n_bins + last_bin] = b.state[i][last_bin] ^ unset_mask; } - for (int16 ai = 0; ai < a.n_splits; ++ai) { - for (int16 bi = 0; bi < b.n_splits; ++bi) { + for (split_int ai = 0; ai < a.n_splits; ++ai) { + if ((ai & 1023) == 0) Rcpp::checkUserInterrupt(); + for (split_int bi = 0; bi < b.n_splits; ++bi) { bool all_match = true, all_complement = true; - for (int16 bin = 0; bin < a.n_bins; ++bin) { + for (split_int bin = 0; bin < a.n_bins; ++bin) { if ((a.state[ai][bin] != b.state[bi][bin])) { all_match = false; break; } } if (!all_match) { - for (int16 bin = 0; bin < a.n_bins; ++bin) { + for (split_int bin = 0; bin < a.n_bins; ++bin) { if ((a.state[ai][bin] != b_complement[size_t(bi) * a.n_bins + bin])) { all_complement = false; break; @@ -143,13 +150,14 @@ inline List robinson_foulds_info(const RawMatrix &x, const RawMatrix &y, } } if (all_match || all_complement) { - int16 leaves_in_split = 0; - for (int16 bin = 0; bin < a.n_bins; ++bin) { + split_int leaves_in_split = 0; + for (split_int bin = 0; bin < a.n_bins; ++bin) { leaves_in_split += count_bits(a.state[ai][bin]); } - score += lg2_unrooted_n - lg2_rooted[leaves_in_split] - - lg2_rooted[n_tips - leaves_in_split]; + score += lg2_unrooted_n - + TreeDist::lg2_rooted_lookup(leaves_in_split) - + TreeDist::lg2_rooted_lookup(n_tips - leaves_in_split); matching[ai] = bi + 1; break; /* Only one match possible per split */ @@ -167,27 +175,28 @@ inline List robinson_foulds_info(const RawMatrix &x, const RawMatrix &y, inline List matching_split_distance(const RawMatrix &x, const RawMatrix &y, const int32 n_tips) { const SplitList a(x), b(y); - const int16 most_splits = std::max(a.n_splits, b.n_splits); - const int16 split_diff = most_splits - std::min(a.n_splits, b.n_splits); - const int16 half_tips = n_tips / 2; + const split_int most_splits = std::max(a.n_splits, b.n_splits); + const split_int split_diff = most_splits - std::min(a.n_splits, b.n_splits); + const split_int half_tips = n_tips / 2; if (most_splits == 0) { return List::create(Named("score") = 0); } const cost max_score = BIG / most_splits; cost_matrix score(most_splits); - for (int16 ai = 0; ai < a.n_splits; ++ai) { - for (int16 bi = 0; bi < b.n_splits; ++bi) { + for (split_int ai = 0; ai < a.n_splits; ++ai) { + if ((ai & 1023) == 0) Rcpp::checkUserInterrupt(); + for (split_int bi = 0; bi < b.n_splits; ++bi) { splitbit total = 0; - for (int16 bin = 0; bin < a.n_bins; ++bin) { + for (split_int bin = 0; bin < a.n_bins; ++bin) { total += count_bits(a.state[ai][bin] ^ b.state[bi][bin]); } score(ai, bi) = total; } } - for (int16 ai = 0; ai < a.n_splits; ++ai) { - for (int16 bi = 0; bi < b.n_splits; ++bi) { + for (split_int ai = 0; ai < a.n_splits; ++ai) { + for (split_int bi = 0; bi < b.n_splits; ++bi) { if (score(ai, bi) > half_tips) { score(ai, bi) = n_tips - score(ai, bi); } @@ -209,7 +218,7 @@ inline List matching_split_distance(const RawMatrix &x, const RawMatrix &y, std::vector final_matching; final_matching.reserve(a.n_splits); - for (int16 i = 0; i < a.n_splits; ++i) { + for (split_int i = 0; i < a.n_splits; ++i) { const int match = (rowsol[i] < b.n_splits) ? static_cast(rowsol[i]) + 1 : NA_INTEGER; @@ -225,7 +234,7 @@ inline List jaccard_similarity(const RawMatrix &x, const RawMatrix &y, const int32 n_tips, const NumericVector &k, const LogicalVector &allowConflict) { const SplitList a(x), b(y); - const int16 most_splits = std::max(a.n_splits, b.n_splits); + const split_int most_splits = std::max(a.n_splits, b.n_splits); constexpr cost max_score = BIG; constexpr double max_scoreL = max_score; @@ -235,20 +244,21 @@ inline List jaccard_similarity(const RawMatrix &x, const RawMatrix &y, cost_matrix score(most_splits); - for (int16 ai = 0; ai < a.n_splits; ++ai) { + for (split_int ai = 0; ai < a.n_splits; ++ai) { + if ((ai & 1023) == 0) Rcpp::checkUserInterrupt(); - const int16 na = a.in_split[ai]; - const int16 nA = n_tips - na; + const split_int na = a.in_split[ai]; + const split_int nA = n_tips - na; - for (int16 bi = 0; bi < b.n_splits; ++bi) { + for (split_int bi = 0; bi < b.n_splits; ++bi) { // x divides tips into a|A; y divides tips into b|B - int16 a_and_b = 0; - for (int16 bin = 0; bin < a.n_bins; ++bin) { + split_int a_and_b = 0; + for (split_int bin = 0; bin < a.n_bins; ++bin) { a_and_b += count_bits(a.state[ai][bin] & b.state[bi][bin]); } - const int16 + const split_int nb = b.in_split[bi], nB = n_tips - nb, a_and_B = na - a_and_b, @@ -266,7 +276,7 @@ inline List jaccard_similarity(const RawMatrix &x, const RawMatrix &y, score(ai, bi) = max_score; /* Prohibited */ } else { - const int16 + const split_int A_or_b = n_tips - a_and_B, a_or_B = n_tips - A_and_b, a_or_b = n_tips - A_and_B, @@ -315,7 +325,7 @@ inline List jaccard_similarity(const RawMatrix &x, const RawMatrix &y, std::vector final_matching; final_matching.reserve(a.n_splits); - for (int16 i = 0; i < a.n_splits; ++i) { + for (split_int i = 0; i < a.n_splits; ++i) { const int match = (rowsol[i] < b.n_splits) ? static_cast(rowsol[i]) + 1 : NA_INTEGER; @@ -330,37 +340,69 @@ inline List jaccard_similarity(const RawMatrix &x, const RawMatrix &y, List msi_distance(const RawMatrix &x, const RawMatrix &y, const int32 n_tips) { const SplitList a(x), b(y); - const int16 most_splits = std::max(a.n_splits, b.n_splits); + const split_int most_splits = std::max(a.n_splits, b.n_splits); + const bool use_lookup_table = TreeDist::can_use_lookup_table(n_tips); constexpr cost max_score = BIG; - const double max_possible = lg2_unrooted[n_tips] - - lg2_rooted[int16((n_tips + 1) / 2)] - lg2_rooted[int16(n_tips / 2)]; + const double max_possible = use_lookup_table + ? TreeDist::lg2_unrooted[n_tips] - + TreeDist::lg2_rooted[split_int((n_tips + 1) / 2)] - + TreeDist::lg2_rooted[split_int(n_tips / 2)] + : TreeDist::lg2_unrooted_lookup(n_tips) - + TreeDist::lg2_rooted_lookup(split_int((n_tips + 1) / 2)) - + TreeDist::lg2_rooted_lookup(split_int(n_tips / 2)); const double score_over_possible = static_cast(max_score) / max_possible; const double possible_over_score = max_possible / max_score; cost_matrix score(most_splits); - splitbit different[SL_MAX_BINS]; - - for (int16 ai = 0; ai < a.n_splits; ++ai) { - for (int16 bi = 0; bi < b.n_splits; ++bi) { - int16 - n_different = 0, - n_a_only = 0, - n_a_and_b = 0 - ; - for (int16 bin = 0; bin < a.n_bins; ++bin) { - different[bin] = a.state[ai][bin] ^ b.state[bi][bin]; - n_different += count_bits(different[bin]); - n_a_only += count_bits(a.state[ai][bin] & different[bin]); - n_a_and_b += count_bits(a.state[ai][bin] & ~different[bin]); + std::vector different(a.n_bins); + + if (use_lookup_table) { + for (split_int ai = 0; ai < a.n_splits; ++ai) { + if ((ai & 1023) == 0) Rcpp::checkUserInterrupt(); + for (split_int bi = 0; bi < b.n_splits; ++bi) { + split_int + n_different = 0, + n_a_only = 0, + n_a_and_b = 0 + ; + for (split_int bin = 0; bin < a.n_bins; ++bin) { + different[bin] = a.state[ai][bin] ^ b.state[bi][bin]; + n_different += count_bits(different[bin]); + n_a_only += count_bits(a.state[ai][bin] & different[bin]); + n_a_and_b += count_bits(a.state[ai][bin] & ~different[bin]); + } + const split_int n_same = n_tips - n_different; + + score(ai, bi) = cost(max_score - + (score_over_possible * + TreeDist::mmsi_score_table(n_same, n_a_and_b, n_different, n_a_only))); } - const int16 n_same = n_tips - n_different; - - score(ai, bi) = cost(max_score - - (score_over_possible * - TreeDist::mmsi_score(n_same, n_a_and_b, n_different, n_a_only))); + score.padRowAfterCol(ai, b.n_splits, max_score); + } + } else { + for (split_int ai = 0; ai < a.n_splits; ++ai) { + if ((ai & 1023) == 0) Rcpp::checkUserInterrupt(); + for (split_int bi = 0; bi < b.n_splits; ++bi) { + split_int + n_different = 0, + n_a_only = 0, + n_a_and_b = 0 + ; + for (split_int bin = 0; bin < a.n_bins; ++bin) { + different[bin] = a.state[ai][bin] ^ b.state[bi][bin]; + n_different += count_bits(different[bin]); + n_a_only += count_bits(a.state[ai][bin] & different[bin]); + n_a_and_b += count_bits(a.state[ai][bin] & ~different[bin]); + } + const split_int n_same = n_tips - n_different; + + score(ai, bi) = cost(max_score - + (score_over_possible * + TreeDist::mmsi_score(n_same, n_a_and_b, n_different, n_a_only))); + } + score.padRowAfterCol(ai, b.n_splits, max_score); } - score.padRowAfterCol(ai, b.n_splits, max_score); } score.padAfterRow(a.n_splits, max_score); @@ -377,7 +419,7 @@ List msi_distance(const RawMatrix &x, const RawMatrix &y, const int32 n_tips) { std::vector final_matching; final_matching.reserve(a.n_splits); - for (int16 i = 0; i < a.n_splits; ++i) { + for (split_int i = 0; i < a.n_splits; ++i) { const int match = (rowsol[i] < b.n_splits) ? static_cast(rowsol[i]) + 1 : NA_INTEGER; @@ -394,9 +436,9 @@ List mutual_clustering(const RawMatrix &x, const RawMatrix &y, const SplitList a(x); const SplitList b(y); const bool a_has_more_splits = (a.n_splits > b.n_splits); - const int16 most_splits = a_has_more_splits ? a.n_splits : b.n_splits; - const int16 a_extra_splits = a_has_more_splits ? most_splits - b.n_splits : 0; - const int16 b_extra_splits = a_has_more_splits ? 0 : most_splits - a.n_splits; + const split_int most_splits = a_has_more_splits ? a.n_splits : b.n_splits; + const split_int a_extra_splits = a_has_more_splits ? most_splits - b.n_splits : 0; + const split_int b_extra_splits = a_has_more_splits ? 0 : most_splits - a.n_splits; const double n_tips_reciprocal = 1.0 / n_tips; if (most_splits == 0 || n_tips == 0) { @@ -407,40 +449,41 @@ List mutual_clustering(const RawMatrix &x, const RawMatrix &y, constexpr cost max_score = BIG; constexpr double over_max_score = 1.0 / static_cast(max_score); const double max_over_tips = static_cast(max_score) * n_tips_reciprocal; - const double lg2_n = lg2[n_tips]; + const double lg2_n = TreeDist::lg2_lookup(n_tips); cost_matrix score(most_splits); double exact_match_score = 0; - int16 exact_matches = 0; + split_int exact_matches = 0; // vector zero-initializes [so does make_unique] // match will have one added to it so numbering follows R; hence 0 = UNMATCHED std::vector a_match(a.n_splits); - std::unique_ptr b_match = std::make_unique(b.n_splits); + std::unique_ptr b_match = std::make_unique(b.n_splits); - for (int16 ai = 0; ai < a.n_splits; ++ai) { + for (split_int ai = 0; ai < a.n_splits; ++ai) { + if ((ai & 1023) == 0) Rcpp::checkUserInterrupt(); if (a_match[ai]) continue; - const int16 na = a.in_split[ai]; - const int16 nA = n_tips - na; + const split_int na = a.in_split[ai]; + const split_int nA = n_tips - na; const auto *a_row = a.state[ai]; - const double offset_a = lg2_n - lg2[na]; - const double offset_A = lg2_n - lg2[nA]; + const double offset_a = lg2_n - TreeDist::lg2_lookup(na); + const double offset_A = lg2_n - TreeDist::lg2_lookup(nA); - for (int16 bi = 0; bi < b.n_splits; ++bi) { + for (split_int bi = 0; bi < b.n_splits; ++bi) { // x divides tips into a|A; y divides tips into b|B - int16 a_and_b = 0; + split_int a_and_b = 0; const auto *b_row = b.state[bi]; - for (int16 bin = 0; bin < a.n_bins; ++bin) { + for (split_int bin = 0; bin < a.n_bins; ++bin) { a_and_b += count_bits(a_row[bin] & b_row[bin]); } - const int16 nb = b.in_split[bi]; - const int16 nB = n_tips - nb; - const int16 a_and_B = na - a_and_b; - const int16 A_and_b = nb - a_and_b; - const int16 A_and_B = nA - A_and_b; + const split_int nb = b.in_split[bi]; + const split_int nB = n_tips - nb; + const split_int a_and_B = na - a_and_b; + const split_int A_and_b = nb - a_and_b; + const split_int A_and_B = nA - A_and_b; if ((!a_and_B && !A_and_b) || (!a_and_b && !A_and_B)) { @@ -454,13 +497,13 @@ List mutual_clustering(const RawMatrix &x, const RawMatrix &y, && a_and_b == A_and_B) { score(ai, bi) = max_score; // Avoid rounding errors } else { - const double lg2_nb = lg2[nb]; - const double lg2_nB = lg2[nB]; + const double lg2_nb = TreeDist::lg2_lookup(nb); + const double lg2_nB = TreeDist::lg2_lookup(nB); const double ic_sum = - a_and_b * (lg2[a_and_b] + offset_a - lg2_nb) + - a_and_B * (lg2[a_and_B] + offset_a - lg2_nB) + - A_and_b * (lg2[A_and_b] + offset_A - lg2_nb) + - A_and_B * (lg2[A_and_B] + offset_A - lg2_nB); + a_and_b * (TreeDist::lg2_lookup(a_and_b) + offset_a - lg2_nb) + + a_and_B * (TreeDist::lg2_lookup(a_and_B) + offset_a - lg2_nB) + + A_and_b * (TreeDist::lg2_lookup(A_and_b) + offset_A - lg2_nb) + + A_and_B * (TreeDist::lg2_lookup(A_and_B) + offset_A - lg2_nB); // Division by n_tips converts n(A&B) to P(A&B) for each ic_element score(ai, bi) = max_score - static_cast(ic_sum * max_over_tips); @@ -478,7 +521,7 @@ List mutual_clustering(const RawMatrix &x, const RawMatrix &y, _["matching"] = a_match); } - const int16 lap_dim = most_splits - exact_matches; + const split_int lap_dim = most_splits - exact_matches; ASSERT(lap_dim > 0); std::vector rowsol; std::vector colsol; @@ -487,11 +530,11 @@ List mutual_clustering(const RawMatrix &x, const RawMatrix &y, cost_matrix small_score(lap_dim); if (exact_matches) { - int16 a_pos = 0; - for (int16 ai = 0; ai < a.n_splits; ++ai) { + split_int a_pos = 0; + for (split_int ai = 0; ai < a.n_splits; ++ai) { if (a_match[ai]) continue; - int16 b_pos = 0; - for (int16 bi = 0; bi < b.n_splits; ++bi) { + split_int b_pos = 0; + for (split_int bi = 0; bi < b.n_splits; ++bi) { if (b_match[bi]) continue; small_score(a_pos, b_pos) = score(ai, bi); b_pos++; @@ -505,9 +548,9 @@ List mutual_clustering(const RawMatrix &x, const RawMatrix &y, lap(lap_dim, small_score, rowsol, colsol)) * over_max_score; const double final_score = lap_score + (exact_match_score / n_tips); - std::unique_ptr lap_decode = std::make_unique(lap_dim); - int16 fuzzy_match = 0; - for (int16 bi = 0; bi < b.n_splits; ++bi) { + std::unique_ptr lap_decode = std::make_unique(lap_dim); + split_int fuzzy_match = 0; + for (split_int bi = 0; bi < b.n_splits; ++bi) { if (!b_match[bi]) { assert(fuzzy_match < lap_dim); lap_decode[fuzzy_match++] = bi + 1; @@ -517,13 +560,13 @@ List mutual_clustering(const RawMatrix &x, const RawMatrix &y, fuzzy_match = 0; std::vector final_matching; TreeDist::resize_uninitialized(final_matching, a.n_splits); - for (int16 i = 0; i < a.n_splits; ++i) { + for (split_int i = 0; i < a.n_splits; ++i) { if (a_match[i]) { final_matching[i] = a_match[i]; } else { assert(fuzzy_match < lap_dim); - const int16 row_idx = fuzzy_match++; - const int16 col_idx = rowsol[row_idx]; + const split_int row_idx = fuzzy_match++; + const split_int col_idx = rowsol[row_idx]; final_matching[i] = (col_idx >= lap_dim - a_extra_splits) ? NA_INTEGER : lap_decode[col_idx]; } @@ -533,8 +576,8 @@ List mutual_clustering(const RawMatrix &x, const RawMatrix &y, _["matching"] = final_matching); } else { - for (int16 ai = a.n_splits; ai < most_splits; ++ai) { - for (int16 bi = 0; bi < most_splits; ++bi) { + for (split_int ai = a.n_splits; ai < most_splits; ++ai) { + for (split_int bi = 0; bi < most_splits; ++bi) { score(ai, bi) = max_score; } } @@ -545,7 +588,7 @@ List mutual_clustering(const RawMatrix &x, const RawMatrix &y, std::vector final_matching; final_matching.reserve(a.n_splits); - for (int16 i = 0; i < a.n_splits; ++i) { + for (split_int i = 0; i < a.n_splits; ++i) { const int match = (rowsol[i] < b.n_splits) ? static_cast(rowsol[i]) + 1 : NA_INTEGER; @@ -560,12 +603,16 @@ List mutual_clustering(const RawMatrix &x, const RawMatrix &y, inline List shared_phylo (const RawMatrix &x, const RawMatrix &y, const int32 n_tips) { const SplitList a(x), b(y); - const int16 most_splits = std::max(a.n_splits, b.n_splits); - const int16 overlap_a = int16(n_tips + 1) / 2; // avoids promotion to int + const split_int most_splits = std::max(a.n_splits, b.n_splits); + const split_int overlap_a = split_int(n_tips + 1) / 2; // avoids promotion to int + const bool use_lookup_table = TreeDist::can_use_lookup_table(n_tips); constexpr cost max_score = BIG; - const double lg2_unrooted_n = lg2_unrooted[n_tips]; - const double best_overlap = TreeDist::one_overlap(overlap_a, n_tips / 2, n_tips); + const double lg2_unrooted_n = use_lookup_table ? TreeDist::lg2_unrooted[n_tips] : + TreeDist::lg2_unrooted_lookup(n_tips); + const double best_overlap = use_lookup_table + ? TreeDist::one_overlap_table(overlap_a, n_tips / 2, n_tips) + : TreeDist::one_overlap(overlap_a, n_tips / 2, n_tips); const double max_possible = lg2_unrooted_n - best_overlap; const double score_over_possible = max_score / max_possible; const double possible_over_score = max_possible / max_score; @@ -574,16 +621,32 @@ inline List shared_phylo (const RawMatrix &x, const RawMatrix &y, // In/out direction [i.e. 1/0 bit] is arbitrary. cost_matrix score(most_splits); - for (int16 ai = a.n_splits; ai--; ) { - for (int16 bi = b.n_splits; bi--; ) { - const double spi_over = TreeDist::spi_overlap( - a.state[ai], b.state[bi], n_tips, a.in_split[ai], b.in_split[bi], - a.n_bins); - - score(ai, bi) = spi_over == 0 ? max_score : - cost((spi_over - best_overlap) * score_over_possible); + if (use_lookup_table) { + for (split_int ai = a.n_splits; ai--; ) { + if ((ai & 1023) == 0) Rcpp::checkUserInterrupt(); + for (split_int bi = b.n_splits; bi--; ) { + const double spi_over = TreeDist::spi_overlap_table( + a.state[ai], b.state[bi], n_tips, a.in_split[ai], b.in_split[bi], + a.n_bins); + + score(ai, bi) = spi_over == 0 ? max_score : + cost((spi_over - best_overlap) * score_over_possible); + } + score.padRowAfterCol(ai, b.n_splits, max_score); + } + } else { + for (split_int ai = a.n_splits; ai--; ) { + if ((ai & 1023) == 0) Rcpp::checkUserInterrupt(); + for (split_int bi = b.n_splits; bi--; ) { + const double spi_over = TreeDist::spi_overlap( + a.state[ai], b.state[bi], n_tips, a.in_split[ai], b.in_split[bi], + a.n_bins); + + score(ai, bi) = spi_over == 0 ? max_score : + cost((spi_over - best_overlap) * score_over_possible); + } + score.padRowAfterCol(ai, b.n_splits, max_score); } - score.padRowAfterCol(ai, b.n_splits, max_score); } score.padAfterRow(a.n_splits, max_score); @@ -601,7 +664,7 @@ inline List shared_phylo (const RawMatrix &x, const RawMatrix &y, std::vector final_matching; final_matching.reserve(a.n_splits); - for (int16 i = 0; i < a.n_splits; ++i) { + for (split_int i = 0; i < a.n_splits; ++i) { const int match = (rowsol[i] < b.n_splits) ? static_cast(rowsol[i]) + 1 : NA_INTEGER; diff --git a/src/tree_distances.h b/src/tree_distances.h index 8d4f7b4f..bead1aaf 100644 --- a/src/tree_distances.h +++ b/src/tree_distances.h @@ -19,60 +19,105 @@ constexpr splitbit ALL_ONES = (std::numeric_limits::max)(); namespace TreeDist { // Re-exported from mutual_clustering.h: - // ic_matching(int16 a, int16 b, int16 n) + // ic_matching(split_int a, split_int b, split_int n) + + void check_ntip(const int32 n); // See equation 16 in Meila 2007 (k' denoted K). // nkK is converted to pkK in the calling function when divided by n. - inline void add_ic_element(double& ic_sum, const int16 nkK, const int16 nk, - const int16 nK, const int16 n_tips, + inline void add_ic_element(double& ic_sum, const split_int nkK, + const split_int nk, const split_int nK, + const split_int n_tips, const double lg2_n) noexcept { if (nkK && nk && nK) { + ASSERT(n_tips > 0 && "n_tips must be positive"); + ASSERT(n_tips <= (std::numeric_limits::max)() / n_tips && + "nTip too large for int32 products in add_ic_element"); assert(!(nkK == nk && nkK == nK && nkK << 1 == n_tips)); const int32 numerator = nkK * n_tips; const int32 denominator = nk * nK; if (numerator != denominator) { - ic_sum += nkK * (lg2[nkK] + lg2_n - lg2[nk] - lg2[nK]); + ic_sum += nkK * (lg2_lookup(nkK) + lg2_n - lg2_lookup(nk) - + lg2_lookup(nK)); } } } + [[nodiscard]] inline bool can_use_lookup_table(const split_int n_tips) noexcept { + return n_tips <= static_cast(SL_MAX_TIPS); + } + // Returns lg2_unrooted[x] - lg2_trees_matching_split(y, x - y) - [[nodiscard]] inline double mmsi_pair_score(const int16 x, const int16 y) noexcept { - assert(SL_MAX_TIPS + 2 <= std::numeric_limits::max()); // verify int16 ok - + [[nodiscard]] inline double mmsi_pair_score_table(const split_int x, + const split_int y) noexcept { + ASSERT(can_use_lookup_table(x)); return lg2_unrooted[x] - (lg2_rooted[y] + lg2_rooted[x - y]); } - [[nodiscard]] inline double mmsi_score(const int16 n_same, const int16 n_a_and_b, - const int16 n_different, const int16 n_a_only) noexcept { + [[nodiscard]] inline double mmsi_pair_score(const split_int x, + const split_int y) noexcept { + return lg2_unrooted_lookup(x) - (lg2_rooted_lookup(y) + + lg2_rooted_lookup(x - y)); + } + + [[nodiscard]] inline double mmsi_score_table(const split_int n_same, + const split_int n_a_and_b, + const split_int n_different, + const split_int n_a_only) noexcept { + if (n_same == 0 || n_same == n_a_and_b) + return mmsi_pair_score_table(n_different, n_a_only); + if (n_different == 0 || n_different == n_a_only) + return mmsi_pair_score_table(n_same, n_a_and_b); + + const double score1 = mmsi_pair_score_table(n_same, n_a_and_b), + score2 = mmsi_pair_score_table(n_different, n_a_only); + + return (score1 > score2) ? score1 : score2; + } + + [[nodiscard]] inline double mmsi_score(const split_int n_same, + const split_int n_a_and_b, + const split_int n_different, + const split_int n_a_only) noexcept { if (n_same == 0 || n_same == n_a_and_b) return mmsi_pair_score(n_different, n_a_only); if (n_different == 0 || n_different == n_a_only) return mmsi_pair_score(n_same, n_a_and_b); - - const double - score1 = mmsi_pair_score(n_same, n_a_and_b), - score2 = mmsi_pair_score(n_different, n_a_only); - + + const double score1 = mmsi_pair_score(n_same, n_a_and_b), + score2 = mmsi_pair_score(n_different, n_a_only); + return (score1 > score2) ? score1 : score2; } - -[[nodiscard]] inline double one_overlap(const int16 a, const int16 b, const int16 n) noexcept { - assert(SL_MAX_TIPS + 2 <= std::numeric_limits::max()); // verify int16 ok + [[nodiscard]] inline double one_overlap_table(const split_int a, const split_int b, + const split_int n) noexcept { + ASSERT(can_use_lookup_table(n)); if (a == b) { return lg2_rooted[a] + lg2_rooted[n - a]; } - // Unify ab via lo/hi: removes an unpredictable branch. - const int16 lo = (a < b) ? a : b; - const int16 hi = (a < b) ? b : a; + const split_int lo = (a < b) ? a : b; + const split_int hi = (a < b) ? b : a; return lg2_rooted[hi] + lg2_rooted[n - lo] - lg2_rooted[hi - lo + 1]; } - - [[nodiscard]] inline double one_overlap_notb(const int16 a, const int16 n_minus_b, const int16 n) noexcept { - assert(SL_MAX_TIPS + 2 <= std::numeric_limits::max()); // verify int16 ok - const int16 b = n - n_minus_b; + + [[nodiscard]] inline double one_overlap(const split_int a, const split_int b, + const split_int n) noexcept { + if (a == b) { + return lg2_rooted_lookup(a) + lg2_rooted_lookup(n - a); + } + const split_int lo = (a < b) ? a : b; + const split_int hi = (a < b) ? b : a; + return lg2_rooted_lookup(hi) + lg2_rooted_lookup(n - lo) - + lg2_rooted_lookup(hi - lo + 1); + } + + [[nodiscard]] inline double one_overlap_notb_table(const split_int a, + const split_int n_minus_b, + const split_int n) noexcept { + ASSERT(can_use_lookup_table(n)); + const split_int b = n - n_minus_b; if (a == b) { return lg2_rooted[b] + lg2_rooted[n_minus_b]; } else if (a < b) { @@ -82,37 +127,64 @@ namespace TreeDist { } } + [[nodiscard]] inline double one_overlap_notb(const split_int a, + const split_int n_minus_b, + const split_int n) noexcept { + const split_int b = n - n_minus_b; + if (a == b) { + return lg2_rooted_lookup(b) + lg2_rooted_lookup(n_minus_b); + } else if (a < b) { + return lg2_rooted_lookup(b) + lg2_rooted_lookup(n - a) - + lg2_rooted_lookup(b - a + 1); + } else { + return lg2_rooted_lookup(a) + lg2_rooted_lookup(n_minus_b) - + lg2_rooted_lookup(a - b + 1); + } + } + + [[nodiscard]] inline double spi_overlap_table(const splitbit* a_state, + const splitbit* b_state, + const split_int n_tips, + const split_int in_a, + const split_int in_b, + const split_int n_bins) noexcept { + ASSERT(can_use_lookup_table(n_tips)); + split_int n_ab = 0; + for (split_int bin = 0; bin < n_bins; ++bin) { + n_ab += TreeTools::count_bits(a_state[bin] & b_state[bin]); + } + + if (n_ab == 0) { + return one_overlap_notb_table(in_a, in_b, n_tips); + } + if (n_ab == in_b || n_ab == in_a) { + return one_overlap_table(in_a, in_b, n_tips); + } + if (in_a + in_b - n_ab == n_tips) { + return one_overlap_notb_table(in_a, in_b, n_tips); + } + + return 0.0; + } -// Popcount-based: single pass over bins replaces 4 sequential boolean scans. + // Popcount-based: single pass over bins replaces 4 sequential boolean scans. // Counts n_ab = |A ∩ B| via hardware POPCNT, then derives all 4 Venn-diagram // region populations from arithmetic on n_ab, in_a, in_b, n_tips. [[nodiscard]] inline double spi_overlap(const splitbit* a_state, const splitbit* b_state, - const int16 n_tips, const int16 in_a, - const int16 in_b, const int16 n_bins) noexcept { - - assert(SL_MAX_BINS <= INT16_MAX); - - int16 n_ab = 0; - for (int16 bin = 0; bin < n_bins; ++bin) { + const split_int n_tips, const split_int in_a, + const split_int in_b, const split_int n_bins) noexcept { + split_int n_ab = 0; + for (split_int bin = 0; bin < n_bins; ++bin) { n_ab += TreeTools::count_bits(a_state[bin] & b_state[bin]); } - // n_a_only = in_a - n_ab (tips in A but not B) - // n_b_only = in_b - n_ab (tips in B but not A) - // n_neither = n_tips - in_a - in_b + n_ab (tips in neither) - // - // Return 0 when all 4 regions are populated (the common case for - // unrelated splits). Otherwise return the appropriate one_overlap score. - if (n_ab == 0) { return one_overlap_notb(in_a, in_b, n_tips); } if (n_ab == in_b || n_ab == in_a) { - // B ⊆ A (n_b_only == 0) or A ⊆ B (n_a_only == 0) return one_overlap(in_a, in_b, n_tips); } if (in_a + in_b - n_ab == n_tips) { - // A ∪ B covers all tips (n_neither == 0) return one_overlap_notb(in_a, in_b, n_tips); } diff --git a/tests/testthat/test-large-trees.R b/tests/testthat/test-large-trees.R new file mode 100644 index 00000000..5706455a --- /dev/null +++ b/tests/testthat/test-large-trees.R @@ -0,0 +1,67 @@ +test_that("Known-answer large-tree near-neighbours (>2048 tips)", { + + # Similar deterministic trees exercise shortcut paths and run quickly. + t1 <- as.phylo(0, 2050) + t2 <- as.phylo(1, 2050) + trees <- structure(list(t1, t2), class = "multiPhylo") + + # Known answer for adjacent `as.phylo()` trees: one non-shared split per tree. + rf <- RobinsonFoulds(t1, t2) + expect_equal(rf, 2) + expect_equal(as.matrix(RobinsonFoulds(trees))[2, 1], 2) + + # Other large-tree metrics should be finite and non-negative. + cid <- ClusteringInfoDistance(t1, t2) + expect_equal(cid, 0.01409, tolerance = 1e-5) + msd <- MatchingSplitDistance(t1, t2) + expect_equal(msd, 2) + irf <- InfoRobinsonFoulds(t1, t2) + expect_equal(irf, 23.999, tolerance = 1e-4) + + # Batch and pairwise paths must agree. + expect_equal(unname(as.matrix(ClusteringInfoDistance(trees))[2, 1]), + unname(cid), tolerance = 1e-8) + expect_equal(unname(as.matrix(MatchingSplitDistance(trees))[2, 1]), + unname(msd), tolerance = 1e-8) + expect_equal(unname(as.matrix(InfoRobinsonFoulds(trees))[2, 1]), + unname(irf), tolerance = 1e-8) +}) + +test_that("Large-tree (>SL_MAX_TIPS) non-table paths: PID, MSID, Jaccard, MCI", { + # 2052 tips: SL_MAX_TIPS (2048) is exceeded, forcing fallback (non-table) + # scoring paths in shared_phylo, msi_distance and the batch equivalents. + # Split size can reach 2051 > SL_MAX_TIPS + 1, triggering lg2_rooted_lookup + # fallback. Near-neighbour trees share most splits so the LAP is tiny. + t1 <- as.phylo(0, 2052) + t2 <- as.phylo(1, 2052) + trees <- structure(list(t1, t2), class = "multiPhylo") + + pid <- PhylogeneticInfoDistance(t1, t2) + expect_type(pid, "double") + expect_true(is.finite(pid)) + expect_gt(pid, 0) + + msid <- MatchingSplitInfoDistance(t1, t2) + expect_type(msid, "double") + expect_true(is.finite(msid)) + expect_gt(msid, 0) + + # Batch and pairwise paths must agree for PID and MSID + expect_equal(unname(as.matrix(PhylogeneticInfoDistance(trees))[2, 1]), + unname(pid), tolerance = 1e-8) + expect_equal(unname(as.matrix(MatchingSplitInfoDistance(trees))[2, 1]), + unname(msid), tolerance = 1e-8) + + # NyeSimilarity exercises jaccard_similarity serial interrupt path + jac <- NyeSimilarity(t1, t2) + expect_type(jac, "double") + expect_true(is.finite(jac)) + expect_gt(jac, 0) + + skip_on_cran() + skip_if_not(getOption("slowMode", FALSE)) + # reportMatching = TRUE forces the serial mutual_clustering path with interrupt + cid_match <- ClusteringInfoDistance(t1, t2, reportMatching = TRUE) + expect_true(is.integer(attr(cid_match, "matching"))) + expect_gt(length(attr(cid_match, "matching")), 0) +}) diff --git a/tests/testthat/test-split_info.R b/tests/testthat/test-split_info.R index a31dc695..887c0f61 100644 --- a/tests/testthat/test-split_info.R +++ b/tests/testthat/test-split_info.R @@ -23,11 +23,11 @@ test_that("Split info calculated", { expect_error(consensus_info(trees, TRUE, -7)) expect_error( consensus_info(list(rtree(20000), rtree(20000)), TRUE, p = 1), - "This many leaves are not yet supported" + "not yet supported for consensus info" ) expect_error( consensus_info(list(rtree(20000), rtree(20000)), FALSE, p = 1), - "This many leaves are not yet supported" + "not yet supported for consensus info" ) diff --git a/tests/testthat/test-tree_distance_nni.R b/tests/testthat/test-tree_distance_nni.R index 20bcda23..c479f4bb 100644 --- a/tests/testthat/test-tree_distance_nni.R +++ b/tests/testthat/test-tree_distance_nni.R @@ -9,27 +9,30 @@ test_that("NNIDist() handles exceptions", { PectinateTree(as.character(1:8)))), "trees must bear identical labels") # R-level guard catches too-many-tips - expect_error(NNIDist(PectinateTree(40000), BalancedTree(40000)), "so many tips") + expect_error(NNIDist(PectinateTree(40000), BalancedTree(40000)), + "not yet supported for NNI") expect_error(NNIDist(BalancedTree(5), RootOnNode(BalancedTree(5), 1))) }) -test_that("NNIDist() at NNI_MAX_TIPS", { - maxTips <- 32768 - more <- maxTips + 1 +test_that("NNIDist() at max tips", { + maxTips <- 32768L + more <- maxTips + 1L expect_error(.NNIDistSingle(PectinateTree(more), BalancedTree(more), more), - "so many tips") - goingQuickly <- TRUE - skip_if(goingQuickly) + "not yet supported for NNI") + skip_if_not(getOption("slowMode", FALSE)) heapTips <- 16384 + 1 + skip_if(maxTips < heapTips) skip_if_not_installed("testthat", "3.2.2") expect_no_error(.NNIDistSingle(PectinateTree(heapTips), BalancedTree(heapTips), heapTips)) + skip_if(maxTips < 32768L) + maxTips <- 32768L n <- .NNIDistSingle(PectinateTree(maxTips), BalancedTree(maxTips), - maxTips) + maxTips) expect_gt(n[["best_upper"]], n[["best_lower"]]) if (!is.na(n[["tight_upper"]])) { expect_gte(n[["tight_upper"]], n[["best_upper"]]) @@ -195,3 +198,18 @@ test_that("NNIDiameter() is sane", { expect_equal(NNIDiameter(as.phylo(0:1, 6)), NNIDiameter(c(6, 6))) }) + +test_that("cpp_nni_distance C++ guards fire for out-of-range nTip", { + # .NNIDistSingle calls .CheckMaxTips first, so the C++ guards inside + # cpp_nni_distance are only reachable via a direct call. + tree1 <- PectinateTree(5) + tree2 <- BalancedTree(5) + edge1 <- Postorder(tree1$edge) + edge2 <- Postorder(tree2$edge) + # Too many tips: fires nTip[0] > NNI_MAX_TIPS in C++ + expect_error(cpp_nni_distance(edge1, edge2, 32769L), + "not yet supported for NNI") + # Negative nTip: fires nTip[0] < 0 in C++ + expect_error(cpp_nni_distance(edge1, edge2, -1L), + "invalid") +}) diff --git a/tests/testthat/test-tree_distance_utilities.R b/tests/testthat/test-tree_distance_utilities.R index 61f513d8..12c82bdb 100644 --- a/tests/testthat/test-tree_distance_utilities.R +++ b/tests/testthat/test-tree_distance_utilities.R @@ -32,24 +32,48 @@ test_that("CalculateTreeDistance() errs appropriately", { }) test_that("Tip-count guard is applied consistently", { - expect_no_error(.AssertNtipSupported(1000L)) - expect_no_error(.AssertNtipSupported(32766L)) - expect_no_error(.AssertNtipSupported(32767L)) - expect_error(.AssertNtipSupported(32768L), - "This many tips are not \\(yet\\) supported\\.") + maxTips <- cpp_max_tips() + expect_true(is.numeric(maxTips)) + expect_gt(maxTips, 0) + + expect_no_error(.CheckMaxTips(min(1000L, maxTips))) + expect_no_error(.CheckMaxTips(maxTips)) + expect_error(.CheckMaxTips(as.double(maxTips) + 1), + "Trees with > .* tips are not yet supported") + expect_no_error(.CheckMaxTips(32768L, "NNI")) + expect_error(.CheckMaxTips(32769L, "NNI"), + "not yet supported for NNI") splits8 <- unclass(as.Splits(BalancedTree(8))) - expect_error(cpp_robinson_foulds_distance(splits8, splits8, 32768L), - "This many tips are not \\(yet\\) supported\\.") - expect_error(cpp_robinson_foulds_info(splits8, splits8, 32768L), - "This many tips are not \\(yet\\) supported\\.") + expect_no_error(cpp_robinson_foulds_distance(splits8, splits8, 8L)) + expect_error(cpp_robinson_foulds_distance(splits8, splits8, -1L), "invalid") + expect_no_error(cpp_robinson_foulds_info(splits8, splits8, 8L)) trees <- list(BalancedTree(8), PectinateTree(8)) class(trees) <- "multiPhylo" - expect_error( - .SplitDistanceAllPairs(RobinsonFouldsSplits, trees, letters[1:8], 32768L), - "This many tips are not \\(yet\\) supported\\." + tipLabels <- TipLabels(trees[[1]]) + expect_no_error( + .SplitDistanceAllPairs(RobinsonFouldsSplits, trees, tipLabels, 8L) + ) +}) + +test_that("Interrupt and tip-limit guards are wired in C++ distance paths", { + serial_path <- testthat::test_path("..", "..", "src", "tree_distances.cpp") + batch_path <- testthat::test_path("..", "..", "src", "pairwise_distances.cpp") + skip_if_not( + file.exists(serial_path) && file.exists(batch_path), + "C++ source files unavailable in installed-package checks" ) + + serial_src <- readLines(serial_path, warn = FALSE) + interrupt_lines <- grep("checkUserInterrupt\\(", serial_src, value = TRUE) + throttled_lines <- grep("\\(ai & 1023\\) == 0", serial_src, value = TRUE) + expect_gte(length(interrupt_lines), 7L) + expect_gte(length(throttled_lines), 7L) + + batch_src <- readLines(batch_path, warn = FALSE) + batch_guard_lines <- grep("TreeDist::check_ntip\\(", batch_src, value = TRUE) + expect_gte(length(batch_guard_lines), 14L) }) test_that("CalculateTreeDistance() handles splits appropriately", {