From 1c0a0c88e50b09c96ee35a165164eb5412d6cddd Mon Sep 17 00:00:00 2001 From: root Date: Mon, 6 Apr 2026 06:13:41 +0300 Subject: [PATCH 1/5] Wavelet rank and select implementation --- include/pixie/wavelet_tree.h | 108 +++++++++++++++++++++++++++++++++++ 1 file changed, 108 insertions(+) create mode 100644 include/pixie/wavelet_tree.h diff --git a/include/pixie/wavelet_tree.h b/include/pixie/wavelet_tree.h new file mode 100644 index 0000000..ac10965 --- /dev/null +++ b/include/pixie/wavelet_tree.h @@ -0,0 +1,108 @@ +#pragma once + +#include +#include + +namespace pixie { + +class WaveletTree { + using node_index_t = size_t; + struct WaveletNode { + static const node_index_t kNil = std::numeric_limits::max(); + node_index_t parent = kNil, left_child = kNil, right_child = kNil; + uint64_t middle; + std::vector bit_vector_; + BitVector data; + WaveletNode(uint64_t middle, + std::vector&& bit_vector, + size_t num_bits) + : middle(middle), + bit_vector_(std::move(bit_vector)), + data(std::span{bit_vector_}, num_bits) {} + }; + + size_t alphabet_size_, data_size_; + node_index_t root_; + std::vector nodes_; + std::vector leaves_; + + node_index_t BuildNode(size_t begin, + size_t end, + std::span data, + node_index_t parent) { + if (end - begin == 1) { + leaves_[begin] = parent; + return WaveletNode::kNil; + } + + size_t middle = begin + (end - begin) / 2; + std::vector bit_vector; + bit_vector.resize((data.size() + 63) / 64); + for (size_t i = 0; i < data.size(); i++) { + if (data[i] >= middle) { + bit_vector[i / 64] |= 1ull << (i % 64); + } + } + + node_index_t result = nodes_.size(); + nodes_.emplace_back(middle, std::move(bit_vector), data.size()); + auto cut = std::stable_partition( + data.begin(), data.end(), [middle](uint64_t x) { return x < middle; }); + nodes_[result].parent = parent; + nodes_[result].left_child = + BuildNode(begin, middle, std::span{data.begin(), cut}, result); + nodes_[result].right_child = + BuildNode(middle, end, std::span{cut, data.end()}, result); + + return result; + } + + public: + WaveletTree(size_t alphabet_size, std::span data) + : alphabet_size_(alphabet_size), + data_size_(data.size()), + leaves_(alphabet_size_, WaveletNode::kNil) { + if (alphabet_size > 0) { + std::vector data_copy(data.begin(), data.end()); + nodes_.reserve(alphabet_size_); + root_ = BuildNode(0, alphabet_size_, data_copy, WaveletNode::kNil); + } else { + root_ = WaveletNode::kNil; + } + } + + size_t rank(uint64_t symbol, size_t pos) const { + if (symbol >= alphabet_size_) [[unlikely]] { + return 0; + } + for (node_index_t current = root_; current != WaveletNode::kNil;) { + const WaveletNode& node = nodes_[current]; + if (symbol < node.middle) { + pos = node.data.rank0(pos); + current = node.left_child; + } else { + pos = node.data.rank(pos); + current = node.right_child; + } + } + return pos; + } + + size_t select(uint64_t symbol, size_t rank) { + if (symbol >= alphabet_size_) [[unlikely]] { + return data_size_; + } + node_index_t current = leaves_[symbol]; + for (; current != WaveletNode::kNil; current = nodes_[current].parent) { + const WaveletNode& node = nodes_[current]; + if (symbol < node.middle) { + rank = node.data.select0(rank) + 1; + } else { + rank = node.data.select(rank) + 1; + } + } + return rank - 1; + } +}; + +} // namespace pixie From 6e9a2f3173fe5bb0bcea3d42828fc2d55059b848 Mon Sep 17 00:00:00 2001 From: NikolayMishukov Date: Mon, 20 Apr 2026 07:54:40 +0300 Subject: [PATCH 2/5] add wavelet-tree tests --- CMakeLists.txt | 9 +++ include/pixie/utils.h | 11 ++++ include/pixie/wavelet_tree.h | 2 +- src/tests/wavelet_tree_tests.cpp | 103 +++++++++++++++++++++++++++++++ 4 files changed, 124 insertions(+), 1 deletion(-) create mode 100644 src/tests/wavelet_tree_tests.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index c91128d..bd63586 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -148,6 +148,15 @@ if(PIXIE_TESTS) gtest gtest_main ${PIXIE_DIAGNOSTICS_LIBS}) + + add_executable(wavelet_tree_tests + src/tests/wavelet_tree_tests.cpp) + target_include_directories(wavelet_tree_tests + PUBLIC include) + target_link_libraries(wavelet_tree_tests + gtest + gtest_main + ${PIXIE_DIAGNOSTICS_LIBS}) endif() # --------------------------------------------------------------------------- diff --git a/include/pixie/utils.h b/include/pixie/utils.h index 24116d5..08e722b 100644 --- a/include/pixie/utils.h +++ b/include/pixie/utils.h @@ -59,6 +59,17 @@ std::vector adj_to_louds( return louds; } +std::vector generate_random_data(size_t data_size, + size_t alphabet_size, + std::mt19937_64& rng) { + std::vector data(data_size); + std::uniform_int_distribution alphabet(0, alphabet_size - 1); + for (size_t i = 0; i < data_size; i++) { + data[i] = alphabet(rng); + } + return data; +} + struct AdjListNode { size_t number; }; diff --git a/include/pixie/wavelet_tree.h b/include/pixie/wavelet_tree.h index ac10965..a563bda 100644 --- a/include/pixie/wavelet_tree.h +++ b/include/pixie/wavelet_tree.h @@ -8,7 +8,7 @@ namespace pixie { class WaveletTree { using node_index_t = size_t; struct WaveletNode { - static const node_index_t kNil = std::numeric_limits::max(); + static constexpr node_index_t kNil = std::numeric_limits::max(); node_index_t parent = kNil, left_child = kNil, right_child = kNil; uint64_t middle; std::vector bit_vector_; diff --git a/src/tests/wavelet_tree_tests.cpp b/src/tests/wavelet_tree_tests.cpp new file mode 100644 index 0000000..fa148f9 --- /dev/null +++ b/src/tests/wavelet_tree_tests.cpp @@ -0,0 +1,103 @@ +#include +#include +#include + +#include +#include +#include +#include + +using pixie::WaveletTree; + +TEST(WaveletTreeTest, BasicSelect) { + const std::vector data = {3, 2, 0, 3, 1, 1, 2}; + size_t data_size = 7, alphabet_size = 4; + + std::vector> rank(alphabet_size); + for (size_t i = 0; i < data_size; i++) { + rank[data[i]].push_back(i); + } + + WaveletTree wavelet_tree(alphabet_size, data); + + for (uint64_t symb = 0; symb < alphabet_size; symb++) { + for (size_t i = 0; i <= rank[symb].size(); i++) { + uint64_t exp = i == rank[symb].size() ? data_size : rank[symb][i]; + uint64_t act = wavelet_tree.select(symb, i + 1); + EXPECT_EQ(act, exp); + } + } +} + +TEST(WaveletTreeTest, BasicRank) { + const std::vector data = {3, 2, 0, 3, 1, 1, 2}; + size_t data_size = 7, alphabet_size = 4; + + std::vector count(alphabet_size); + + WaveletTree wavelet_tree(alphabet_size, data); + for (size_t i = 0; i <= data_size; i++) { + for (uint64_t symb = 0; symb < alphabet_size; symb++) { + uint64_t exp = count[symb]; + uint64_t act = wavelet_tree.rank(symb, i); + EXPECT_EQ(act, exp); + } + + if (i == data_size) { + break; + } + count[data[i]]++; + } +} + +TEST(WaveletTreeTest, SmokeSelect) { + std::vector> rank; + for (size_t data_size = 8; data_size < (1 << 22); data_size <<= 1) { + size_t alphabet_size = std::min(1024ull, data_size); + std::mt19937_64 rng(239); + std::vector data = + generate_random_data(data_size, alphabet_size, rng); + + rank.assign(alphabet_size, {}); + for (size_t i = 0; i < data_size; i++) { + rank[data[i]].push_back(i); + } + + WaveletTree wavelet_tree(alphabet_size, data); + + for (uint64_t symb = 0; symb < alphabet_size; symb++) { + for (size_t i = 0; i <= rank[symb].size(); i++) { + uint64_t exp = i == rank[symb].size() ? data_size : rank[symb][i]; + uint64_t act = wavelet_tree.select(symb, i + 1); + EXPECT_EQ(act, exp); + } + } + } +} + +TEST(WaveletTreeTest, SmokeRank) { + std::vector count; + for (size_t data_size = 8; data_size < (1 << 22); data_size <<= 1) { + size_t alphabet_size = std::min(1024ull, data_size); + std::mt19937_64 rng(239); + std::vector data = + generate_random_data(data_size, alphabet_size, rng); + std::vector query = + generate_random_data(data_size + 1, alphabet_size, rng); + + count.assign(alphabet_size, 0); + + WaveletTree wavelet_tree(alphabet_size, data); + for (size_t i = 0; i <= data_size; i++) { + uint64_t symb = query[i]; + uint64_t exp = count[symb]; + uint64_t act = wavelet_tree.rank(symb, i); + EXPECT_EQ(act, exp); + + if (i == data_size) { + break; + } + count[data[i]]++; + } + } +} \ No newline at end of file From 532ae3d10126efaa25bd8f45767d84717d77889e Mon Sep 17 00:00:00 2001 From: NikolayMishukov Date: Mon, 20 Apr 2026 12:51:23 +0300 Subject: [PATCH 3/5] add wavelet-tree benchmarks --- CMakeLists.txt | 9 +++ include/pixie/wavelet_tree.h | 8 +- src/benchmarks/wavelet_tree_benchmarks.cpp | 89 ++++++++++++++++++++++ src/tests/wavelet_tree_tests.cpp | 8 +- 4 files changed, 108 insertions(+), 6 deletions(-) create mode 100644 src/benchmarks/wavelet_tree_benchmarks.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index bd63586..f2502c3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -201,6 +201,15 @@ if(PIXIE_BENCHMARKS) benchmark_main ${PIXIE_DIAGNOSTICS_LIBS}) + add_executable(wavelet_tree_benchmarks + src/benchmarks/wavelet_tree_benchmarks.cpp) + target_include_directories(wavelet_tree_benchmarks + PUBLIC include) + target_link_libraries(wavelet_tree_benchmarks + benchmark + benchmark_main + ${PIXIE_DIAGNOSTICS_LIBS}) + add_executable(alignment_comparison src/benchmarks/alignment_comparison.cpp) target_include_directories(alignment_comparison diff --git a/include/pixie/wavelet_tree.h b/include/pixie/wavelet_tree.h index a563bda..3d60a1a 100644 --- a/include/pixie/wavelet_tree.h +++ b/include/pixie/wavelet_tree.h @@ -34,6 +34,12 @@ class WaveletTree { leaves_[begin] = parent; return WaveletNode::kNil; } + if(data.empty()){ + for(size_t symb = begin; symb < end; symb++){ + leaves_[symb] = parent; + } + return WaveletNode::kNil; + } size_t middle = begin + (end - begin) / 2; std::vector bit_vector; @@ -88,7 +94,7 @@ class WaveletTree { return pos; } - size_t select(uint64_t symbol, size_t rank) { + size_t select(uint64_t symbol, size_t rank) const { if (symbol >= alphabet_size_) [[unlikely]] { return data_size_; } diff --git a/src/benchmarks/wavelet_tree_benchmarks.cpp b/src/benchmarks/wavelet_tree_benchmarks.cpp new file mode 100644 index 0000000..b25d175 --- /dev/null +++ b/src/benchmarks/wavelet_tree_benchmarks.cpp @@ -0,0 +1,89 @@ +#include +#include +#include + +#include + +using pixie::WaveletTree; + +static void BM_WaveletTreeSelect(benchmark::State& state) { + size_t data_size = state.range(0), alphabet_size = 1024, query = data_size; + std::mt19937_64 rng(239); + + for (auto _ : state) { + state.PauseTiming(); + + std::vector data = + generate_random_data(data_size, alphabet_size, rng); + std::vector query_symbol = + generate_random_data(query, alphabet_size, rng); + std::vector count(alphabet_size), query_pos(query); + for (auto symb : data) { + count[symb]++; + } + for (size_t i = 0; i < query; i++) { + query_pos[i] = 1 + std::uniform_int_distribution( + 0, count[query_symbol[i]])(rng); + } + + state.ResumeTiming(); + + WaveletTree wavelet_tree(alphabet_size, data); + benchmark::DoNotOptimize(wavelet_tree); + + for (size_t i = 0; i < query; i++) { + size_t select = wavelet_tree.select(query_symbol[i], query_pos[i]); + benchmark::DoNotOptimize(select); + } + } +} + +static void BM_WaveletTreeRank(benchmark::State& state) { + size_t data_size = state.range(0), alphabet_size = 1024, query = data_size; + std::mt19937_64 rng(239); + + for (auto _ : state) { + state.PauseTiming(); + + std::vector data = + generate_random_data(data_size, alphabet_size, rng); + std::vector query_symbol = + generate_random_data(query, alphabet_size, rng), + query_pos = + generate_random_data(query, data_size + 1, rng); + + state.ResumeTiming(); + + WaveletTree wavelet_tree(alphabet_size, data); + benchmark::DoNotOptimize(wavelet_tree); + + for (size_t i = 0; i < query; i++) { + size_t rank = wavelet_tree.rank(query_symbol[i], query_pos[i]); + benchmark::DoNotOptimize(rank); + } + } +} + +BENCHMARK(BM_WaveletTreeSelect) + ->ArgNames({"data_size"}) + ->RangeMultiplier(2) + ->Range(1ull << 8, 1ull << 18) + ->Iterations(100); + +BENCHMARK(BM_WaveletTreeSelect) + ->ArgNames({"data_size"}) + ->RangeMultiplier(2) + ->Range(1ull << 18, 1ull << 26) + ->Iterations(10); + +BENCHMARK(BM_WaveletTreeRank) + ->ArgNames({"data_size"}) + ->RangeMultiplier(2) + ->Range(1ull << 8, 1ull << 18) + ->Iterations(100); + +BENCHMARK(BM_WaveletTreeRank) + ->ArgNames({"data_size"}) + ->RangeMultiplier(2) + ->Range(1ull << 18, 1ull << 26) + ->Iterations(10); diff --git a/src/tests/wavelet_tree_tests.cpp b/src/tests/wavelet_tree_tests.cpp index fa148f9..87e58c0 100644 --- a/src/tests/wavelet_tree_tests.cpp +++ b/src/tests/wavelet_tree_tests.cpp @@ -2,10 +2,7 @@ #include #include -#include -#include #include -#include using pixie::WaveletTree; @@ -53,7 +50,7 @@ TEST(WaveletTreeTest, BasicRank) { TEST(WaveletTreeTest, SmokeSelect) { std::vector> rank; for (size_t data_size = 8; data_size < (1 << 22); data_size <<= 1) { - size_t alphabet_size = std::min(1024ull, data_size); + size_t alphabet_size = 1024; std::mt19937_64 rng(239); std::vector data = generate_random_data(data_size, alphabet_size, rng); @@ -78,7 +75,7 @@ TEST(WaveletTreeTest, SmokeSelect) { TEST(WaveletTreeTest, SmokeRank) { std::vector count; for (size_t data_size = 8; data_size < (1 << 22); data_size <<= 1) { - size_t alphabet_size = std::min(1024ull, data_size); + size_t alphabet_size = 1024; std::mt19937_64 rng(239); std::vector data = generate_random_data(data_size, alphabet_size, rng); @@ -88,6 +85,7 @@ TEST(WaveletTreeTest, SmokeRank) { count.assign(alphabet_size, 0); WaveletTree wavelet_tree(alphabet_size, data); + for (size_t i = 0; i <= data_size; i++) { uint64_t symb = query[i]; uint64_t exp = count[symb]; From 6da207196dd5d65775be4d7400bbc14f89459d05 Mon Sep 17 00:00:00 2001 From: NikolayMishukov Date: Mon, 20 Apr 2026 14:26:13 +0300 Subject: [PATCH 4/5] add get-segment --- include/pixie/wavelet_tree.h | 58 +++++++++++++++++++++++++++++--- src/tests/wavelet_tree_tests.cpp | 38 +++++++++++++++++++++ 2 files changed, 92 insertions(+), 4 deletions(-) diff --git a/include/pixie/wavelet_tree.h b/include/pixie/wavelet_tree.h index 3d60a1a..ab8cb55 100644 --- a/include/pixie/wavelet_tree.h +++ b/include/pixie/wavelet_tree.h @@ -1,6 +1,7 @@ #pragma once #include + #include namespace pixie { @@ -8,7 +9,8 @@ namespace pixie { class WaveletTree { using node_index_t = size_t; struct WaveletNode { - static constexpr node_index_t kNil = std::numeric_limits::max(); + static constexpr node_index_t kNil = + std::numeric_limits::max(); node_index_t parent = kNil, left_child = kNil, right_child = kNil; uint64_t middle; std::vector bit_vector_; @@ -34,8 +36,8 @@ class WaveletTree { leaves_[begin] = parent; return WaveletNode::kNil; } - if(data.empty()){ - for(size_t symb = begin; symb < end; symb++){ + if (data.empty()) { + for (size_t symb = begin; symb < end; symb++) { leaves_[symb] = parent; } return WaveletNode::kNil; @@ -63,6 +65,44 @@ class WaveletTree { return result; } + void copySegmentContent(node_index_t node, + size_t begin, + size_t end, + std::span dst, + std::span tmp) const { + if (begin == end) { + return; + } + size_t rank = nodes_[node].data.rank(begin), rank0 = begin - rank; + size_t right = nodes_[node].data.rank(end) - rank, + left = (end - begin) - right; + + if (nodes_[node].left_child == WaveletNode::kNil) { + std::fill(tmp.begin(), tmp.begin() + static_cast(left), + nodes_[node].middle - 1); + } else { + copySegmentContent(nodes_[node].left_child, rank0, rank0 + left, + tmp.subspan(0, left), dst.subspan(0, left)); + } + if (nodes_[node].right_child == WaveletNode::kNil) { + std::fill(tmp.begin() + static_cast(left), tmp.end(), + nodes_[node].middle); + } else { + copySegmentContent(nodes_[node].right_child, rank, rank + right, + tmp.subspan(left, right), dst.subspan(left, right)); + } + + size_t j = 0, k = left; + const auto& bit_vector = nodes_[node].bit_vector_; + for (size_t i = begin; i < end; i++) { + if ((bit_vector[i / 64] >> (i % 64)) & 1) { + dst[i - begin] = tmp[k++]; + } else { + dst[i - begin] = tmp[j++]; + } + } + } + public: WaveletTree(size_t alphabet_size, std::span data) : alphabet_size_(alphabet_size), @@ -95,7 +135,7 @@ class WaveletTree { } size_t select(uint64_t symbol, size_t rank) const { - if (symbol >= alphabet_size_) [[unlikely]] { + if (symbol >= alphabet_size_ || data_size_ == 0) [[unlikely]] { return data_size_; } node_index_t current = leaves_[symbol]; @@ -109,6 +149,16 @@ class WaveletTree { } return rank - 1; } + + std::vector getSegment(size_t begin, size_t end) const { + auto length = static_cast(end - begin); + std::vector result(2 * length); + copySegmentContent(root_, begin, end, + std::span{result.begin(), result.begin() + length}, + std::span{result.begin() + length, result.end()}); + result.resize(length); + return result; + } }; } // namespace pixie diff --git a/src/tests/wavelet_tree_tests.cpp b/src/tests/wavelet_tree_tests.cpp index 87e58c0..61372c2 100644 --- a/src/tests/wavelet_tree_tests.cpp +++ b/src/tests/wavelet_tree_tests.cpp @@ -47,6 +47,23 @@ TEST(WaveletTreeTest, BasicRank) { } } +TEST(WaveletTreeTest, BasicSegment) { + const std::vector data = {3, 2, 0, 3, 1, 1, 2}; + size_t data_size = 7, alphabet_size = 4; + + WaveletTree wavelet_tree(alphabet_size, data); + + for (size_t begin = 0; begin <= data_size; begin++) { + for (size_t end = begin; end <= data_size; end++) { + auto segment = wavelet_tree.getSegment(begin, end); + EXPECT_EQ(segment.size(), end - begin); + for(size_t i = 0; i < end - begin; i++){ + EXPECT_EQ(segment[i], data[begin + i]); + } + } + } +} + TEST(WaveletTreeTest, SmokeSelect) { std::vector> rank; for (size_t data_size = 8; data_size < (1 << 22); data_size <<= 1) { @@ -98,4 +115,25 @@ TEST(WaveletTreeTest, SmokeRank) { count[data[i]]++; } } +} + + +TEST(WaveletTreeTest, SmokeSegment) { + size_t data_size = 256, alphabet_size = 100; + + std::mt19937_64 rng(239); + std::vector data = + generate_random_data(data_size, alphabet_size, rng); + + WaveletTree wavelet_tree(alphabet_size, data); + + for (size_t begin = 0; begin <= data_size; begin++) { + for (size_t end = begin; end <= data_size; end++) { + auto segment = wavelet_tree.getSegment(begin, end); + EXPECT_EQ(segment.size(), end - begin); + for(size_t i = 0; i < end - begin; i++){ + EXPECT_EQ(segment[i], data[begin + i]); + } + } + } } \ No newline at end of file From e4bd7a5c89d3ac9f8adba722887308c6515ed8fa Mon Sep 17 00:00:00 2001 From: NikolayMishukov Date: Mon, 20 Apr 2026 22:21:00 +0300 Subject: [PATCH 5/5] add Huffman build --- include/pixie/wavelet_tree.h | 113 ++++++++++++++++++++++++++----- scripts/coverage_report.sh | 1 + src/tests/wavelet_tree_tests.cpp | 63 +++++++++-------- 3 files changed, 131 insertions(+), 46 deletions(-) diff --git a/include/pixie/wavelet_tree.h b/include/pixie/wavelet_tree.h index ab8cb55..f9d2bda 100644 --- a/include/pixie/wavelet_tree.h +++ b/include/pixie/wavelet_tree.h @@ -6,6 +6,8 @@ namespace pixie { +enum WaveletTreeBuildType { Standard, Huffman }; + class WaveletTree { using node_index_t = size_t; struct WaveletNode { @@ -27,11 +29,16 @@ class WaveletTree { node_index_t root_; std::vector nodes_; std::vector leaves_; - - node_index_t BuildNode(size_t begin, - size_t end, - std::span data, - node_index_t parent) { + std::vector permutation_, inverse_permutation_; + + node_index_t BuildNode( + size_t begin, + size_t end, + std::span data, + node_index_t parent, + const std::function& get_middle = [](auto) { + return -1ull; + }) { if (end - begin == 1) { leaves_[begin] = parent; return WaveletNode::kNil; @@ -43,24 +50,26 @@ class WaveletTree { return WaveletNode::kNil; } - size_t middle = begin + (end - begin) / 2; + node_index_t result = nodes_.size(); + size_t middle = get_middle(result); + middle = begin + (middle == -1ull ? (end - begin) / 2 : middle); std::vector bit_vector; bit_vector.resize((data.size() + 63) / 64); for (size_t i = 0; i < data.size(); i++) { - if (data[i] >= middle) { + if (permutation_[data[i]] >= middle) { bit_vector[i / 64] |= 1ull << (i % 64); } } - node_index_t result = nodes_.size(); nodes_.emplace_back(middle, std::move(bit_vector), data.size()); auto cut = std::stable_partition( - data.begin(), data.end(), [middle](uint64_t x) { return x < middle; }); + data.begin(), data.end(), + [middle, this](uint64_t x) { return permutation_[x] < middle; }); nodes_[result].parent = parent; - nodes_[result].left_child = - BuildNode(begin, middle, std::span{data.begin(), cut}, result); + nodes_[result].left_child = BuildNode( + begin, middle, std::span{data.begin(), cut}, result, get_middle); nodes_[result].right_child = - BuildNode(middle, end, std::span{cut, data.end()}, result); + BuildNode(middle, end, std::span{cut, data.end()}, result, get_middle); return result; } @@ -79,14 +88,14 @@ class WaveletTree { if (nodes_[node].left_child == WaveletNode::kNil) { std::fill(tmp.begin(), tmp.begin() + static_cast(left), - nodes_[node].middle - 1); + inverse_permutation_[nodes_[node].middle - 1]); } else { copySegmentContent(nodes_[node].left_child, rank0, rank0 + left, tmp.subspan(0, left), dst.subspan(0, left)); } if (nodes_[node].right_child == WaveletNode::kNil) { std::fill(tmp.begin() + static_cast(left), tmp.end(), - nodes_[node].middle); + inverse_permutation_[nodes_[node].middle]); } else { copySegmentContent(nodes_[node].right_child, rank, rank + right, tmp.subspan(left, right), dst.subspan(left, right)); @@ -104,16 +113,79 @@ class WaveletTree { } public: - WaveletTree(size_t alphabet_size, std::span data) + WaveletTree(size_t alphabet_size, + std::span data, + WaveletTreeBuildType build_type = WaveletTreeBuildType::Standard) : alphabet_size_(alphabet_size), data_size_(data.size()), leaves_(alphabet_size_, WaveletNode::kNil) { - if (alphabet_size > 0) { + if (alphabet_size == 0) { + root_ = WaveletNode::kNil; + return; + } + nodes_.reserve(alphabet_size_); + if (build_type == WaveletTreeBuildType::Standard) { + permutation_.resize(alphabet_size); + inverse_permutation_.resize(alphabet_size); + std::iota(permutation_.begin(), permutation_.end(), 0); + std::iota(inverse_permutation_.begin(), inverse_permutation_.end(), 0); + std::vector data_copy(data.begin(), data.end()); - nodes_.reserve(alphabet_size_); root_ = BuildNode(0, alphabet_size_, data_copy, WaveletNode::kNil); } else { - root_ = WaveletNode::kNil; + struct Node { + size_t size, left, right; + }; + std::vector huffman_nodes(alphabet_size_, {0, 0, 0}); + for (auto symb : data) { + huffman_nodes[symb].size++; + } + + using elem_t = std::pair; + std::priority_queue, std::greater<>> queue; + for (size_t i = 0; i < alphabet_size_; i++) { + queue.emplace(huffman_nodes[i].size, i); + } + while (queue.size() >= 2) { + auto right = queue.top().second; + queue.pop(); + auto left = queue.top().second; + queue.pop(); + huffman_nodes.push_back( + {huffman_nodes[left].size + huffman_nodes[right].size, left + 1, + right + 1}); + queue.emplace(huffman_nodes.back().size, huffman_nodes.size() - 1); + } + + std::vector nodes_structure; + std::function enumerate = [&](size_t index) -> size_t { + const auto& [size, left, right] = huffman_nodes[index]; + if (left == 0 || right == 0) { + permutation_[index] = inverse_permutation_.size(); + inverse_permutation_.push_back(index); + return 1; + } + size_t ind = nodes_structure.size(), subtree = 0; + if (size > 0) { + nodes_structure.push_back(0); + } + subtree += enumerate(left - 1); + if (size > 0) { + nodes_structure[ind] = subtree; + } + subtree += enumerate(right - 1); + return subtree; + }; + + permutation_.resize(alphabet_size_); + nodes_structure.reserve(alphabet_size_); + inverse_permutation_.reserve(alphabet_size_); + enumerate(huffman_nodes.size() - 1); + + std::vector data_copy(data.begin(), data.end()); + root_ = + BuildNode(0, alphabet_size_, data_copy, WaveletNode::kNil, + [&](node_index_t node) { return nodes_structure[node]; }); } } @@ -121,6 +193,7 @@ class WaveletTree { if (symbol >= alphabet_size_) [[unlikely]] { return 0; } + symbol = permutation_[symbol]; for (node_index_t current = root_; current != WaveletNode::kNil;) { const WaveletNode& node = nodes_[current]; if (symbol < node.middle) { @@ -138,6 +211,7 @@ class WaveletTree { if (symbol >= alphabet_size_ || data_size_ == 0) [[unlikely]] { return data_size_; } + symbol = permutation_[symbol]; node_index_t current = leaves_[symbol]; for (; current != WaveletNode::kNil; current = nodes_[current].parent) { const WaveletNode& node = nodes_[current]; @@ -151,6 +225,9 @@ class WaveletTree { } std::vector getSegment(size_t begin, size_t end) const { + if (alphabet_size_ == 0 || data_size_ == 0) [[unlikely]] { + return {}; + } auto length = static_cast(end - begin); std::vector result(2 * length); copySegmentContent(root_, begin, end, diff --git a/scripts/coverage_report.sh b/scripts/coverage_report.sh index 0820856..f346112 100755 --- a/scripts/coverage_report.sh +++ b/scripts/coverage_report.sh @@ -10,6 +10,7 @@ cmake --build --preset coverage "${BUILD_DIR}/unittests" "${BUILD_DIR}/louds_tree_tests" "${BUILD_DIR}/test_rmm" +"${BUILD_DIR}/wavelet_tree_tests" cd "${BUILD_DIR}" find . -name "*.gcda" > gcov_files.txt diff --git a/src/tests/wavelet_tree_tests.cpp b/src/tests/wavelet_tree_tests.cpp index 61372c2..d578c3b 100644 --- a/src/tests/wavelet_tree_tests.cpp +++ b/src/tests/wavelet_tree_tests.cpp @@ -57,7 +57,7 @@ TEST(WaveletTreeTest, BasicSegment) { for (size_t end = begin; end <= data_size; end++) { auto segment = wavelet_tree.getSegment(begin, end); EXPECT_EQ(segment.size(), end - begin); - for(size_t i = 0; i < end - begin; i++){ + for (size_t i = 0; i < end - begin; i++) { EXPECT_EQ(segment[i], data[begin + i]); } } @@ -77,13 +77,16 @@ TEST(WaveletTreeTest, SmokeSelect) { rank[data[i]].push_back(i); } - WaveletTree wavelet_tree(alphabet_size, data); + for (auto build_type : {pixie::WaveletTreeBuildType::Standard, + pixie::WaveletTreeBuildType::Huffman}) { + WaveletTree wavelet_tree(alphabet_size, data, build_type); - for (uint64_t symb = 0; symb < alphabet_size; symb++) { - for (size_t i = 0; i <= rank[symb].size(); i++) { - uint64_t exp = i == rank[symb].size() ? data_size : rank[symb][i]; - uint64_t act = wavelet_tree.select(symb, i + 1); - EXPECT_EQ(act, exp); + for (uint64_t symb = 0; symb < alphabet_size; symb++) { + for (size_t i = 0; i <= rank[symb].size(); i++) { + uint64_t exp = i == rank[symb].size() ? data_size : rank[symb][i]; + uint64_t act = wavelet_tree.select(symb, i + 1); + EXPECT_EQ(act, exp); + } } } } @@ -99,25 +102,26 @@ TEST(WaveletTreeTest, SmokeRank) { std::vector query = generate_random_data(data_size + 1, alphabet_size, rng); - count.assign(alphabet_size, 0); + for (auto build_type : {pixie::WaveletTreeBuildType::Standard, + pixie::WaveletTreeBuildType::Huffman}) { + count.assign(alphabet_size, 0); + WaveletTree wavelet_tree(alphabet_size, data, build_type); - WaveletTree wavelet_tree(alphabet_size, data); - - for (size_t i = 0; i <= data_size; i++) { - uint64_t symb = query[i]; - uint64_t exp = count[symb]; - uint64_t act = wavelet_tree.rank(symb, i); - EXPECT_EQ(act, exp); + for (size_t i = 0; i <= data_size; i++) { + uint64_t symb = query[i]; + uint64_t exp = count[symb]; + uint64_t act = wavelet_tree.rank(symb, i); + EXPECT_EQ(act, exp); - if (i == data_size) { - break; + if (i == data_size) { + break; + } + count[data[i]]++; } - count[data[i]]++; } } } - TEST(WaveletTreeTest, SmokeSegment) { size_t data_size = 256, alphabet_size = 100; @@ -125,15 +129,18 @@ TEST(WaveletTreeTest, SmokeSegment) { std::vector data = generate_random_data(data_size, alphabet_size, rng); - WaveletTree wavelet_tree(alphabet_size, data); - - for (size_t begin = 0; begin <= data_size; begin++) { - for (size_t end = begin; end <= data_size; end++) { - auto segment = wavelet_tree.getSegment(begin, end); - EXPECT_EQ(segment.size(), end - begin); - for(size_t i = 0; i < end - begin; i++){ - EXPECT_EQ(segment[i], data[begin + i]); + for (auto build_type : {pixie::WaveletTreeBuildType::Standard, + pixie::WaveletTreeBuildType::Huffman}) { + WaveletTree wavelet_tree(alphabet_size, data, build_type); + + for (size_t begin = 0; begin <= data_size; begin++) { + for (size_t end = begin; end <= data_size; end++) { + auto segment = wavelet_tree.getSegment(begin, end); + EXPECT_EQ(segment.size(), end - begin); + for (size_t i = 0; i < end - begin; i++) { + EXPECT_EQ(segment[i], data[begin + i]); + } } } } -} \ No newline at end of file +}