diff --git a/CMakeLists.txt b/CMakeLists.txt
index c91128d..f2502c3 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -148,6 +148,15 @@ if(PIXIE_TESTS)
                         gtest
                         gtest_main
                         ${PIXIE_DIAGNOSTICS_LIBS})
+
+  add_executable(wavelet_tree_tests
+                 src/tests/wavelet_tree_tests.cpp)
+  target_include_directories(wavelet_tree_tests
+                             PUBLIC include)
+  target_link_libraries(wavelet_tree_tests
+                        gtest
+                        gtest_main
+                        ${PIXIE_DIAGNOSTICS_LIBS})
 endif()
 
 # ---------------------------------------------------------------------------
@@ -192,6 +201,15 @@ if(PIXIE_BENCHMARKS)
                         benchmark_main
                         ${PIXIE_DIAGNOSTICS_LIBS})
 
+  add_executable(wavelet_tree_benchmarks
+                 src/benchmarks/wavelet_tree_benchmarks.cpp)
+  target_include_directories(wavelet_tree_benchmarks
+                             PUBLIC include)
+  target_link_libraries(wavelet_tree_benchmarks
+                        benchmark
+                        benchmark_main
+                        ${PIXIE_DIAGNOSTICS_LIBS})
+
   add_executable(alignment_comparison
                  src/benchmarks/alignment_comparison.cpp)
   target_include_directories(alignment_comparison
diff --git a/include/pixie/utils.h b/include/pixie/utils.h
index 24116d5..08e722b 100644
--- a/include/pixie/utils.h
+++ b/include/pixie/utils.h
@@ -59,6 +59,25 @@ std::vector adj_to_louds(
   return louds;
 }
 
+/**
+ * @brief Generates `data_size` pseudo-random symbols drawn uniformly from
+ *        [0, alphabet_size - 1].
+ *
+ * `alphabet_size` must be non-zero, otherwise the upper bound
+ * `alphabet_size - 1` wraps around. The function is `inline` because it is
+ * defined in a header that several translation units include.
+ */
+inline std::vector<uint64_t> generate_random_data(size_t data_size,
+                                                  size_t alphabet_size,
+                                                  std::mt19937_64& rng) {
+  std::vector<uint64_t> data(data_size);
+  std::uniform_int_distribution<uint64_t> alphabet(0, alphabet_size - 1);
+  for (auto& symbol : data) {
+    symbol = alphabet(rng);
+  }
+  return data;
+}
+
 struct AdjListNode {
   size_t number;
 };
diff --git a/include/pixie/wavelet_tree.h b/include/pixie/wavelet_tree.h
new file mode 100644
index 0000000..f9d2bda
--- /dev/null
+++ b/include/pixie/wavelet_tree.h
@@ -0,0 +1,345 @@
+#pragma once
+
+// NOTE(review): assumed path of the header that declares BitVector; confirm.
+#include <pixie/bitvector.h>
+
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+#include <functional>
+#include <limits>
+#include <numeric>
+#include <queue>
+#include <span>
+#include <utility>
+#include <vector>
+
+namespace pixie {
+
+/** Tree shape used by WaveletTree: halved symbol ranges or Huffman-shaped. */
+enum WaveletTreeBuildType { Standard, Huffman };
+
+/**
+ * @brief Wavelet tree over a sequence of symbols from [0, alphabet_size).
+ *
+ * Each node splits its range of permuted symbols at `middle` and stores one
+ * bit per element routed through it: 0 for symbols below `middle`, 1
+ * otherwise. Symbols are looked up through `permutation_`, which is the
+ * identity for the Standard build and a frequency-derived leaf order for the
+ * Huffman build.
+ *
+ * Every symbol passed to the constructor must be smaller than
+ * `alphabet_size`; this is not checked. Instances can be moved but not copied
+ * (see WaveletNode).
+ */
+class WaveletTree {
+  using node_index_t = size_t;
+
+  struct WaveletNode {
+    static constexpr node_index_t kNil =
+        std::numeric_limits<node_index_t>::max();
+    node_index_t parent = kNil, left_child = kNil, right_child = kNil;
+    uint64_t middle;
+    std::vector<uint64_t> bit_vector_;
+    BitVector data;
+
+    WaveletNode(uint64_t middle,
+                std::vector<uint64_t>&& bit_vector,
+                size_t num_bits)
+        : middle(middle),
+          bit_vector_(std::move(bit_vector)),
+          data(std::span{bit_vector_}, num_bits) {}
+
+    // `data` is built from a span over `bit_vector_`. A memberwise copy would
+    // pair a new `bit_vector_` with a `data` built over the source node's
+    // words, so copying is disabled. Moving is kept: std::vector's move
+    // operations hand over the word buffer instead of reallocating it.
+    // NOTE(review): assumes BitVector keeps referring to that span; confirm.
+    WaveletNode(const WaveletNode&) = delete;
+    WaveletNode& operator=(const WaveletNode&) = delete;
+    WaveletNode(WaveletNode&&) = default;
+    WaveletNode& operator=(WaveletNode&&) = default;
+  };
+
+  size_t alphabet_size_, data_size_;
+  node_index_t root_ = WaveletNode::kNil;
+  std::vector<WaveletNode> nodes_;
+  // Node from which select() starts for each permuted symbol (kNil if none).
+  std::vector<node_index_t> leaves_;
+  std::vector<uint64_t> permutation_, inverse_permutation_;
+
+  /**
+   * @brief Recursively builds the subtree for permuted symbols [begin, end).
+   *
+   * @param begin First permuted symbol of the range.
+   * @param end One past the last permuted symbol of the range.
+   * @param data Elements routed to this subtree; reordered in place.
+   * @param parent Index of the parent node, or kNil for the root.
+   * @param get_middle Maps a node index to its split offset relative to
+   *        `begin`; -1 requests a split in the middle of the range.
+   * @return Index of the new node, or kNil when none is created (the range
+   *         holds one symbol or `data` is empty).
+   */
+  node_index_t BuildNode(
+      size_t begin,
+      size_t end,
+      std::span<uint64_t> data,
+      node_index_t parent,
+      const std::function<size_t(node_index_t)>& get_middle = [](auto) {
+        return -1ull;
+      }) {
+    if (end - begin == 1) {
+      leaves_[begin] = parent;
+      return WaveletNode::kNil;
+    }
+    if (data.empty()) {
+      for (size_t symb = begin; symb < end; symb++) {
+        leaves_[symb] = parent;
+      }
+      return WaveletNode::kNil;
+    }
+
+    node_index_t result = nodes_.size();
+    size_t middle = get_middle(result);
+    middle = begin + (middle == -1ull ? (end - begin) / 2 : middle);
+    std::vector<uint64_t> bit_vector;
+    bit_vector.resize((data.size() + 63) / 64);
+    for (size_t i = 0; i < data.size(); i++) {
+      if (permutation_[data[i]] >= middle) {
+        bit_vector[i / 64] |= 1ull << (i % 64);
+      }
+    }
+
+    nodes_.emplace_back(middle, std::move(bit_vector), data.size());
+    auto cut = std::stable_partition(
+        data.begin(), data.end(),
+        [middle, this](uint64_t x) { return permutation_[x] < middle; });
+    nodes_[result].parent = parent;
+    nodes_[result].left_child = BuildNode(
+        begin, middle, std::span{data.begin(), cut}, result, get_middle);
+    nodes_[result].right_child =
+        BuildNode(middle, end, std::span{cut, data.end()}, result, get_middle);
+
+    return result;
+  }
+
+  /**
+   * @brief Writes elements [begin, end) of `node`'s sequence into `dst`.
+   *
+   * `tmp` is scratch space of the same length as `dst`; the two buffers swap
+   * roles at each level of the recursion.
+   */
+  void copySegmentContent(node_index_t node,
+                          size_t begin,
+                          size_t end,
+                          std::span<uint64_t> dst,
+                          std::span<uint64_t> tmp) const {
+    if (begin == end) {
+      return;
+    }
+    size_t rank = nodes_[node].data.rank(begin), rank0 = begin - rank;
+    size_t right = nodes_[node].data.rank(end) - rank,
+           left = (end - begin) - right;
+
+    if (nodes_[node].left_child == WaveletNode::kNil) {
+      std::fill(tmp.begin(), tmp.begin() + static_cast<std::ptrdiff_t>(left),
+                inverse_permutation_[nodes_[node].middle - 1]);
+    } else {
+      copySegmentContent(nodes_[node].left_child, rank0, rank0 + left,
+                         tmp.subspan(0, left), dst.subspan(0, left));
+    }
+    if (nodes_[node].right_child == WaveletNode::kNil) {
+      std::fill(tmp.begin() + static_cast<std::ptrdiff_t>(left), tmp.end(),
+                inverse_permutation_[nodes_[node].middle]);
+    } else {
+      copySegmentContent(nodes_[node].right_child, rank, rank + right,
+                         tmp.subspan(left, right), dst.subspan(left, right));
+    }
+
+    // Interleave the two halves of `tmp` back into sequence order.
+    size_t j = 0, k = left;
+    const auto& bit_vector = nodes_[node].bit_vector_;
+    for (size_t i = begin; i < end; i++) {
+      if ((bit_vector[i / 64] >> (i % 64)) & 1) {
+        dst[i - begin] = tmp[k++];
+      } else {
+        dst[i - begin] = tmp[j++];
+      }
+    }
+  }
+
+ public:
+  /**
+   * @brief Builds the tree over `data`.
+   *
+   * @param alphabet_size Number of distinct symbols; every value in `data`
+   *        must be smaller than this.
+   * @param data Input sequence. It is copied; the caller's buffer is neither
+   *        modified nor referenced afterwards.
+   * @param build_type Standard halves each symbol range; Huffman derives the
+   *        tree shape from symbol frequencies.
+   */
+  WaveletTree(size_t alphabet_size,
+              std::span<const uint64_t> data,
+              WaveletTreeBuildType build_type = WaveletTreeBuildType::Standard)
+      : alphabet_size_(alphabet_size),
+        data_size_(data.size()),
+        leaves_(alphabet_size_, WaveletNode::kNil) {
+    if (alphabet_size == 0) {
+      root_ = WaveletNode::kNil;
+      return;
+    }
+    nodes_.reserve(alphabet_size_);
+    if (build_type == WaveletTreeBuildType::Standard) {
+      permutation_.resize(alphabet_size);
+      inverse_permutation_.resize(alphabet_size);
+      std::iota(permutation_.begin(), permutation_.end(), 0);
+      std::iota(inverse_permutation_.begin(), inverse_permutation_.end(), 0);
+
+      std::vector<uint64_t> data_copy(data.begin(), data.end());
+      root_ = BuildNode(0, alphabet_size_, data_copy, WaveletNode::kNil);
+    } else {
+      // The first `alphabet_size_` entries are leaves (left == right == 0);
+      // merged nodes store child index + 1 so that 0 can mark a leaf.
+      struct Node {
+        size_t size, left, right;
+      };
+      std::vector<Node> huffman_nodes(alphabet_size_, {0, 0, 0});
+      for (auto symb : data) {
+        huffman_nodes[symb].size++;
+      }
+
+      using elem_t = std::pair<size_t, size_t>;
+      std::priority_queue<elem_t, std::vector<elem_t>, std::greater<>> queue;
+      for (size_t i = 0; i < alphabet_size_; i++) {
+        queue.emplace(huffman_nodes[i].size, i);
+      }
+      while (queue.size() >= 2) {
+        auto right = queue.top().second;
+        queue.pop();
+        auto left = queue.top().second;
+        queue.pop();
+        huffman_nodes.push_back(
+            {huffman_nodes[left].size + huffman_nodes[right].size, left + 1,
+             right + 1});
+        queue.emplace(huffman_nodes.back().size, huffman_nodes.size() - 1);
+      }
+
+      // Pre-order walk that numbers the leaves (the permutation) and records,
+      // for each merged node with a non-zero count, the number of leaves in
+      // its left subtree. Returns the number of leaves below `index`.
+      std::vector<size_t> nodes_structure;
+      std::function<size_t(size_t)> enumerate = [&](size_t index) -> size_t {
+        const auto& [size, left, right] = huffman_nodes[index];
+        if (left == 0 || right == 0) {
+          permutation_[index] = inverse_permutation_.size();
+          inverse_permutation_.push_back(index);
+          return 1;
+        }
+        size_t ind = nodes_structure.size(), subtree = 0;
+        if (size > 0) {
+          nodes_structure.push_back(0);
+        }
+        subtree += enumerate(left - 1);
+        if (size > 0) {
+          nodes_structure[ind] = subtree;
+        }
+        subtree += enumerate(right - 1);
+        return subtree;
+      };
+
+      permutation_.resize(alphabet_size_);
+      nodes_structure.reserve(alphabet_size_);
+      inverse_permutation_.reserve(alphabet_size_);
+      enumerate(huffman_nodes.size() - 1);
+
+      std::vector<uint64_t> data_copy(data.begin(), data.end());
+      root_ =
+          BuildNode(0, alphabet_size_, data_copy, WaveletNode::kNil,
+                    [&](node_index_t node) { return nodes_structure[node]; });
+    }
+  }
+
+  /**
+   * @brief Counts occurrences of `symbol` among the first `pos` elements.
+   *
+   * @param symbol Symbol to count; values >= alphabet_size yield 0.
+   * @param pos Prefix length; values above the data size are clamped to it.
+   * @return Number of occurrences in positions [0, pos).
+   */
+  size_t rank(uint64_t symbol, size_t pos) const {
+    if (symbol >= alphabet_size_) [[unlikely]] {
+      return 0;
+    }
+    pos = std::min(pos, data_size_);
+    symbol = permutation_[symbol];
+    for (node_index_t current = root_; current != WaveletNode::kNil;) {
+      const WaveletNode& node = nodes_[current];
+      if (symbol < node.middle) {
+        pos = node.data.rank0(pos);
+        current = node.left_child;
+      } else {
+        pos = node.data.rank(pos);
+        current = node.right_child;
+      }
+    }
+    return pos;
+  }
+
+  /**
+   * @brief Finds the position of the `rank`-th occurrence of `symbol`.
+   *
+   * @param symbol Symbol to look for.
+   * @param rank 1-based occurrence number.
+   * @return 0-based position of that occurrence, or the data size when
+   *         `symbol` is outside the alphabet, `rank` is 0, or there are fewer
+   *         than `rank` occurrences.
+   */
+  size_t select(uint64_t symbol, size_t rank) const {
+    if (symbol >= alphabet_size_ || data_size_ == 0 || rank == 0)
+        [[unlikely]] {
+      return data_size_;
+    }
+    symbol = permutation_[symbol];
+    node_index_t current = leaves_[symbol];
+    for (; current != WaveletNode::kNil; current = nodes_[current].parent) {
+      const WaveletNode& node = nodes_[current];
+      if (symbol < node.middle) {
+        rank = node.data.select0(rank) + 1;
+      } else {
+        rank = node.data.select(rank) + 1;
+      }
+    }
+    // With a one-symbol alphabet the loop never runs, so an out-of-range
+    // `rank` has to be caught here.
+    return std::min(rank - 1, data_size_);
+  }
+
+  /**
+   * @brief Returns the elements at positions [begin, end).
+   *
+   * `end` is clamped to the data size; an empty vector is returned when the
+   * resulting range is empty.
+   */
+  std::vector<uint64_t> getSegment(size_t begin, size_t end) const {
+    end = std::min(end, data_size_);
+    if (alphabet_size_ == 0 || begin >= end) [[unlikely]] {
+      return {};
+    }
+    const size_t length = end - begin;
+    if (root_ == WaveletNode::kNil) {
+      // One-symbol alphabet: no nodes were built, so every element is that
+      // symbol.
+      return std::vector<uint64_t>(length, inverse_permutation_[0]);
+    }
+    const auto half = static_cast<std::ptrdiff_t>(length);
+    std::vector<uint64_t> result(2 * length);
+    copySegmentContent(root_, begin, end,
+                       std::span{result.begin(), result.begin() + half},
+                       std::span{result.begin() + half, result.end()});
+    result.resize(length);
+    return result;
+  }
+};
+
+}  // namespace pixie
diff --git a/scripts/coverage_report.sh b/scripts/coverage_report.sh
index 0820856..f346112 100755
--- a/scripts/coverage_report.sh
+++ b/scripts/coverage_report.sh
@@ -10,6 +10,7 @@ cmake --build --preset coverage
 "${BUILD_DIR}/unittests"
 "${BUILD_DIR}/louds_tree_tests"
 "${BUILD_DIR}/test_rmm"
+"${BUILD_DIR}/wavelet_tree_tests"
 
 cd "${BUILD_DIR}"
 find . -name "*.gcda" > gcov_files.txt
diff --git a/src/benchmarks/wavelet_tree_benchmarks.cpp b/src/benchmarks/wavelet_tree_benchmarks.cpp
new file mode 100644
index 0000000..b25d175
--- /dev/null
+++ b/src/benchmarks/wavelet_tree_benchmarks.cpp
@@ -0,0 +1,97 @@
+#include <benchmark/benchmark.h>
+#include <pixie/utils.h>
+#include <pixie/wavelet_tree.h>
+
+#include <cstddef>
+#include <cstdint>
+#include <random>
+#include <vector>
+
+using pixie::WaveletTree;
+
+// Runs `data_size` select queries against a tree over random data. The tree is
+// constructed after ResumeTiming(), so construction is timed as well.
+static void BM_WaveletTreeSelect(benchmark::State& state) {
+  size_t data_size = state.range(0), alphabet_size = 1024, query = data_size;
+  std::mt19937_64 rng(239);
+
+  for (auto _ : state) {
+    state.PauseTiming();
+
+    std::vector<uint64_t> data =
+        generate_random_data(data_size, alphabet_size, rng);
+    std::vector<uint64_t> query_symbol =
+        generate_random_data(query, alphabet_size, rng);
+    std::vector<size_t> count(alphabet_size), query_pos(query);
+    for (auto symb : data) {
+      count[symb]++;
+    }
+    for (size_t i = 0; i < query; i++) {
+      // Ranks fall in [1, count + 1]; count + 1 has no matching occurrence.
+      query_pos[i] = 1 + std::uniform_int_distribution<size_t>(
+                             0, count[query_symbol[i]])(rng);
+    }
+
+    state.ResumeTiming();
+
+    WaveletTree wavelet_tree(alphabet_size, data);
+    benchmark::DoNotOptimize(wavelet_tree);
+
+    for (size_t i = 0; i < query; i++) {
+      size_t select = wavelet_tree.select(query_symbol[i], query_pos[i]);
+      benchmark::DoNotOptimize(select);
+    }
+  }
+}
+
+// Runs `data_size` rank queries against a tree over random data. The tree is
+// constructed after ResumeTiming(), so construction is timed as well.
+static void BM_WaveletTreeRank(benchmark::State& state) {
+  size_t data_size = state.range(0), alphabet_size = 1024, query = data_size;
+  std::mt19937_64 rng(239);
+
+  for (auto _ : state) {
+    state.PauseTiming();
+
+    std::vector<uint64_t> data =
+        generate_random_data(data_size, alphabet_size, rng);
+    std::vector<uint64_t> query_symbol =
+        generate_random_data(query, alphabet_size, rng);
+    std::vector<uint64_t> query_pos =
+        generate_random_data(query, data_size + 1, rng);
+
+    state.ResumeTiming();
+
+    WaveletTree wavelet_tree(alphabet_size, data);
+    benchmark::DoNotOptimize(wavelet_tree);
+
+    for (size_t i = 0; i < query; i++) {
+      size_t rank = wavelet_tree.rank(query_symbol[i], query_pos[i]);
+      benchmark::DoNotOptimize(rank);
+    }
+  }
+}
+
+BENCHMARK(BM_WaveletTreeSelect)
+    ->ArgNames({"data_size"})
+    ->RangeMultiplier(2)
+    ->Range(1ull << 8, 1ull << 18)
+    ->Iterations(100);
+
+BENCHMARK(BM_WaveletTreeSelect)
+    ->ArgNames({"data_size"})
+    ->RangeMultiplier(2)
+    ->Range(1ull << 18, 1ull << 26)
+    ->Iterations(10);
+
+BENCHMARK(BM_WaveletTreeRank)
+    ->ArgNames({"data_size"})
+    ->RangeMultiplier(2)
+    ->Range(1ull << 8, 1ull << 18)
+    ->Iterations(100);
+
+BENCHMARK(BM_WaveletTreeRank)
+    ->ArgNames({"data_size"})
+    ->RangeMultiplier(2)
+    ->Range(1ull << 18, 1ull << 26)
+    ->Iterations(10);
diff --git a/src/tests/wavelet_tree_tests.cpp b/src/tests/wavelet_tree_tests.cpp
new file mode 100644
index 0000000..d578c3b
--- /dev/null
+++ b/src/tests/wavelet_tree_tests.cpp
@@ -0,0 +1,164 @@
+#include <gtest/gtest.h>
+#include <pixie/utils.h>
+#include <pixie/wavelet_tree.h>
+
+#include <cstddef>
+#include <cstdint>
+#include <random>
+#include <vector>
+
+using pixie::WaveletTree;
+
+TEST(WaveletTreeTest, BasicSelect) {
+  const std::vector<uint64_t> data = {3, 2, 0, 3, 1, 1, 2};
+  size_t data_size = 7, alphabet_size = 4;
+
+  std::vector<std::vector<size_t>> rank(alphabet_size);
+  for (size_t i = 0; i < data_size; i++) {
+    rank[data[i]].push_back(i);
+  }
+
+  WaveletTree wavelet_tree(alphabet_size, data);
+
+  for (uint64_t symb = 0; symb < alphabet_size; symb++) {
+    for (size_t i = 0; i <= rank[symb].size(); i++) {
+      uint64_t exp = i == rank[symb].size() ? data_size : rank[symb][i];
+      uint64_t act = wavelet_tree.select(symb, i + 1);
+      EXPECT_EQ(act, exp);
+    }
+  }
+}
+
+TEST(WaveletTreeTest, BasicRank) {
+  const std::vector<uint64_t> data = {3, 2, 0, 3, 1, 1, 2};
+  size_t data_size = 7, alphabet_size = 4;
+
+  std::vector<size_t> count(alphabet_size);
+
+  WaveletTree wavelet_tree(alphabet_size, data);
+  for (size_t i = 0; i <= data_size; i++) {
+    for (uint64_t symb = 0; symb < alphabet_size; symb++) {
+      uint64_t exp = count[symb];
+      uint64_t act = wavelet_tree.rank(symb, i);
+      EXPECT_EQ(act, exp);
+    }
+
+    if (i == data_size) {
+      break;
+    }
+    count[data[i]]++;
+  }
+}
+
+TEST(WaveletTreeTest, BasicSegment) {
+  const std::vector<uint64_t> data = {3, 2, 0, 3, 1, 1, 2};
+  size_t data_size = 7, alphabet_size = 4;
+
+  WaveletTree wavelet_tree(alphabet_size, data);
+
+  for (size_t begin = 0; begin <= data_size; begin++) {
+    for (size_t end = begin; end <= data_size; end++) {
+      auto segment = wavelet_tree.getSegment(begin, end);
+      EXPECT_EQ(segment.size(), end - begin);
+      for (size_t i = 0; i < end - begin; i++) {
+        EXPECT_EQ(segment[i], data[begin + i]);
+      }
+    }
+  }
+}
+
+TEST(WaveletTreeTest, SingleSymbolAlphabet) {
+  // With one symbol the tree has no nodes; queries must still be answered.
+  const std::vector<uint64_t> data(5, 0);
+
+  for (auto build_type : {pixie::WaveletTreeBuildType::Standard,
+                          pixie::WaveletTreeBuildType::Huffman}) {
+    WaveletTree wavelet_tree(1, data, build_type);
+
+    EXPECT_EQ(wavelet_tree.rank(0, 3), 3u);
+    EXPECT_EQ(wavelet_tree.select(0, 2), 1u);
+    EXPECT_EQ(wavelet_tree.select(0, data.size() + 2), data.size());
+    EXPECT_EQ(wavelet_tree.getSegment(1, 4), std::vector<uint64_t>(3, 0));
+  }
+}
+
+TEST(WaveletTreeTest, SmokeSelect) {
+  std::vector<std::vector<size_t>> rank;
+  for (size_t data_size = 8; data_size < (1 << 22); data_size <<= 1) {
+    size_t alphabet_size = 1024;
+    std::mt19937_64 rng(239);
+    std::vector<uint64_t> data =
+        generate_random_data(data_size, alphabet_size, rng);
+
+    rank.assign(alphabet_size, {});
+    for (size_t i = 0; i < data_size; i++) {
+      rank[data[i]].push_back(i);
+    }
+
+    for (auto build_type : {pixie::WaveletTreeBuildType::Standard,
+                            pixie::WaveletTreeBuildType::Huffman}) {
+      WaveletTree wavelet_tree(alphabet_size, data, build_type);
+
+      for (uint64_t symb = 0; symb < alphabet_size; symb++) {
+        for (size_t i = 0; i <= rank[symb].size(); i++) {
+          uint64_t exp = i == rank[symb].size() ? data_size : rank[symb][i];
+          uint64_t act = wavelet_tree.select(symb, i + 1);
+          EXPECT_EQ(act, exp);
+        }
+      }
+    }
+  }
+}
+
+TEST(WaveletTreeTest, SmokeRank) {
+  std::vector<size_t> count;
+  for (size_t data_size = 8; data_size < (1 << 22); data_size <<= 1) {
+    size_t alphabet_size = 1024;
+    std::mt19937_64 rng(239);
+    std::vector<uint64_t> data =
+        generate_random_data(data_size, alphabet_size, rng);
+    std::vector<uint64_t> query =
+        generate_random_data(data_size + 1, alphabet_size, rng);
+
+    for (auto build_type : {pixie::WaveletTreeBuildType::Standard,
+                            pixie::WaveletTreeBuildType::Huffman}) {
+      count.assign(alphabet_size, 0);
+      WaveletTree wavelet_tree(alphabet_size, data, build_type);
+
+      for (size_t i = 0; i <= data_size; i++) {
+        uint64_t symb = query[i];
+        uint64_t exp = count[symb];
+        uint64_t act = wavelet_tree.rank(symb, i);
+        EXPECT_EQ(act, exp);
+
+        if (i == data_size) {
+          break;
+        }
+        count[data[i]]++;
+      }
+    }
+  }
+}
+
+TEST(WaveletTreeTest, SmokeSegment) {
+  size_t data_size = 256, alphabet_size = 100;
+
+  std::mt19937_64 rng(239);
+  std::vector<uint64_t> data =
+      generate_random_data(data_size, alphabet_size, rng);
+
+  for (auto build_type : {pixie::WaveletTreeBuildType::Standard,
+                          pixie::WaveletTreeBuildType::Huffman}) {
+    WaveletTree wavelet_tree(alphabet_size, data, build_type);
+
+    for (size_t begin = 0; begin <= data_size; begin++) {
+      for (size_t end = begin; end <= data_size; end++) {
+        auto segment = wavelet_tree.getSegment(begin, end);
+        EXPECT_EQ(segment.size(), end - begin);
+        for (size_t i = 0; i < end - begin; i++) {
+          EXPECT_EQ(segment[i], data[begin + i]);
+        }
+      }
+    }
+  }
+}