Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,15 @@ if(PIXIE_TESTS)
gtest
gtest_main
${PIXIE_DIAGNOSTICS_LIBS})

add_executable(wavelet_tree_tests
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also add in the converage_report.sh and build-test workflow

src/tests/wavelet_tree_tests.cpp)
target_include_directories(wavelet_tree_tests
PUBLIC include)
target_link_libraries(wavelet_tree_tests
gtest
gtest_main
${PIXIE_DIAGNOSTICS_LIBS})
endif()

# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -192,6 +201,15 @@ if(PIXIE_BENCHMARKS)
benchmark_main
${PIXIE_DIAGNOSTICS_LIBS})

add_executable(wavelet_tree_benchmarks
src/benchmarks/wavelet_tree_benchmarks.cpp)
target_include_directories(wavelet_tree_benchmarks
PUBLIC include)
target_link_libraries(wavelet_tree_benchmarks
benchmark
benchmark_main
${PIXIE_DIAGNOSTICS_LIBS})

add_executable(alignment_comparison
src/benchmarks/alignment_comparison.cpp)
target_include_directories(alignment_comparison
Expand Down
11 changes: 11 additions & 0 deletions include/pixie/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,17 @@ std::vector<uint64_t> adj_to_louds(
return louds;
}

std::vector<uint64_t> generate_random_data(size_t data_size,
size_t alphabet_size,
std::mt19937_64& rng) {
std::vector<uint64_t> data(data_size);
std::uniform_int_distribution<uint64_t> alphabet(0, alphabet_size - 1);
Comment thread
Malkovsky marked this conversation as resolved.
for (size_t i = 0; i < data_size; i++) {
data[i] = alphabet(rng);
}
return data;
}

struct AdjListNode {
size_t number;
};
Expand Down
241 changes: 241 additions & 0 deletions include/pixie/wavelet_tree.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
#pragma once

#include <pixie/bitvector.h>

#include <limits>

namespace pixie {

enum WaveletTreeBuildType { Standard, Huffman };
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is better to use enum class. In current setup it will conflict with other enums using the same values, with enum class it will not be an issue, you'd need to specify as WaveletTreeBuildType::Standard though. Also please use all uppercase for enums, i.e. STANDARD, HUFFMAN


class WaveletTree {
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Docstring description is required for class and public methods

using node_index_t = size_t;
struct WaveletNode {
static constexpr node_index_t kNil =
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please use npos for consistency

std::numeric_limits<node_index_t>::max();
node_index_t parent = kNil, left_child = kNil, right_child = kNil;
uint64_t middle;
std::vector<uint64_t> bit_vector_;
BitVector data;
WaveletNode(uint64_t middle,
std::vector<uint64_t>&& bit_vector,
size_t num_bits)
: middle(middle),
bit_vector_(std::move(bit_vector)),
data(std::span{bit_vector_}, num_bits) {}
};

size_t alphabet_size_, data_size_;
node_index_t root_;
std::vector<WaveletNode> nodes_;
std::vector<node_index_t> leaves_;
std::vector<size_t> permutation_, inverse_permutation_;

node_index_t BuildNode(
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

snake_case for methods, i.e. build_node

size_t begin,
size_t end,
std::span<uint64_t> data,
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the node is not suppose to change data than it should have type std::span<const uint64_t>

node_index_t parent,
const std::function<size_t(node_index_t)>& get_middle = [](auto) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SUGGESTION: Avoid std::function in the hot build path

std::function introduces type-erasure overhead per recursive call. If you can make BuildNode a template on the callable (or pass a small function object by value), you can keep the same behavior while eliminating indirect calls and potential allocations during construction.

return -1ull;
}) {
if (end - begin == 1) {
leaves_[begin] = parent;
return WaveletNode::kNil;
}
if (data.empty()) {
for (size_t symb = begin; symb < end; symb++) {
leaves_[symb] = parent;
}
return WaveletNode::kNil;
}

node_index_t result = nodes_.size();
size_t middle = get_middle(result);
middle = begin + (middle == -1ull ? (end - begin) / 2 : middle);
std::vector<uint64_t> bit_vector;
bit_vector.resize((data.size() + 63) / 64);
for (size_t i = 0; i < data.size(); i++) {
if (permutation_[data[i]] >= middle) {
bit_vector[i / 64] |= 1ull << (i % 64);
}
}

nodes_.emplace_back(middle, std::move(bit_vector), data.size());
auto cut = std::stable_partition(
Comment thread
Malkovsky marked this conversation as resolved.
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please check how SDSL constructs WT: it does it in two passes over the data without intermediate buffers. All bitvectors are concotenated (optional improvement)
https://app.kilo.ai/s/605b1323-63b1-430e-be01-69dda2718f71

data.begin(), data.end(),
[middle, this](uint64_t x) { return permutation_[x] < middle; });
nodes_[result].parent = parent;
nodes_[result].left_child = BuildNode(
begin, middle, std::span{data.begin(), cut}, result, get_middle);
nodes_[result].right_child =
BuildNode(middle, end, std::span{cut, data.end()}, result, get_middle);

return result;
}

void copySegmentContent(node_index_t node,
size_t begin,
size_t end,
std::span<uint64_t> dst,
std::span<uint64_t> tmp) const {
if (begin == end) {
return;
}
size_t rank = nodes_[node].data.rank(begin), rank0 = begin - rank;
size_t right = nodes_[node].data.rank(end) - rank,
left = (end - begin) - right;

if (nodes_[node].left_child == WaveletNode::kNil) {
std::fill(tmp.begin(), tmp.begin() + static_cast<long long>(left),
inverse_permutation_[nodes_[node].middle - 1]);
} else {
copySegmentContent(nodes_[node].left_child, rank0, rank0 + left,
tmp.subspan(0, left), dst.subspan(0, left));
}
if (nodes_[node].right_child == WaveletNode::kNil) {
std::fill(tmp.begin() + static_cast<long long>(left), tmp.end(),
inverse_permutation_[nodes_[node].middle]);
} else {
copySegmentContent(nodes_[node].right_child, rank, rank + right,
tmp.subspan(left, right), dst.subspan(left, right));
}

size_t j = 0, k = left;
const auto& bit_vector = nodes_[node].bit_vector_;
for (size_t i = begin; i < end; i++) {
if ((bit_vector[i / 64] >> (i % 64)) & 1) {
dst[i - begin] = tmp[k++];
} else {
dst[i - begin] = tmp[j++];
}
}
}

public:
WaveletTree(size_t alphabet_size,
std::span<const uint64_t> data,
WaveletTreeBuildType build_type = WaveletTreeBuildType::Standard)
: alphabet_size_(alphabet_size),
data_size_(data.size()),
leaves_(alphabet_size_, WaveletNode::kNil) {
if (alphabet_size == 0) {
root_ = WaveletNode::kNil;
return;
}
nodes_.reserve(alphabet_size_);
if (build_type == WaveletTreeBuildType::Standard) {
permutation_.resize(alphabet_size);
inverse_permutation_.resize(alphabet_size);
std::iota(permutation_.begin(), permutation_.end(), 0);
std::iota(inverse_permutation_.begin(), inverse_permutation_.end(), 0);

std::vector<uint64_t> data_copy(data.begin(), data.end());
root_ = BuildNode(0, alphabet_size_, data_copy, WaveletNode::kNil);
} else {
struct Node {
size_t size, left, right;
};
std::vector<Node> huffman_nodes(alphabet_size_, {0, 0, 0});
for (auto symb : data) {
huffman_nodes[symb].size++;
}

using elem_t = std::pair<size_t, size_t>;
std::priority_queue<elem_t, std::vector<elem_t>, std::greater<>> queue;
for (size_t i = 0; i < alphabet_size_; i++) {
queue.emplace(huffman_nodes[i].size, i);
}
while (queue.size() >= 2) {
auto right = queue.top().second;
queue.pop();
auto left = queue.top().second;
queue.pop();
huffman_nodes.push_back(
{huffman_nodes[left].size + huffman_nodes[right].size, left + 1,
right + 1});
queue.emplace(huffman_nodes.back().size, huffman_nodes.size() - 1);
}

std::vector<size_t> nodes_structure;
std::function<size_t(size_t)> enumerate = [&](size_t index) -> size_t {
const auto& [size, left, right] = huffman_nodes[index];
if (left == 0 || right == 0) {
permutation_[index] = inverse_permutation_.size();
inverse_permutation_.push_back(index);
return 1;
}
size_t ind = nodes_structure.size(), subtree = 0;
if (size > 0) {
nodes_structure.push_back(0);
}
subtree += enumerate(left - 1);
if (size > 0) {
nodes_structure[ind] = subtree;
}
subtree += enumerate(right - 1);
return subtree;
};

permutation_.resize(alphabet_size_);
nodes_structure.reserve(alphabet_size_);
inverse_permutation_.reserve(alphabet_size_);
enumerate(huffman_nodes.size() - 1);

std::vector<uint64_t> data_copy(data.begin(), data.end());
root_ =
BuildNode(0, alphabet_size_, data_copy, WaveletNode::kNil,
[&](node_index_t node) { return nodes_structure[node]; });
}
}

size_t rank(uint64_t symbol, size_t pos) const {
if (symbol >= alphabet_size_) [[unlikely]] {
return 0;
}
symbol = permutation_[symbol];
for (node_index_t current = root_; current != WaveletNode::kNil;) {
const WaveletNode& node = nodes_[current];
if (symbol < node.middle) {
pos = node.data.rank0(pos);
current = node.left_child;
} else {
pos = node.data.rank(pos);
current = node.right_child;
}
}
return pos;
}

size_t select(uint64_t symbol, size_t rank) const {
if (symbol >= alphabet_size_ || data_size_ == 0) [[unlikely]] {
return data_size_;
}
symbol = permutation_[symbol];
node_index_t current = leaves_[symbol];
for (; current != WaveletNode::kNil; current = nodes_[current].parent) {
const WaveletNode& node = nodes_[current];
if (symbol < node.middle) {
rank = node.data.select0(rank) + 1;
} else {
rank = node.data.select(rank) + 1;
}
}
return rank - 1;
}

std::vector<uint64_t> getSegment(size_t begin, size_t end) const {
Comment thread
Malkovsky marked this conversation as resolved.
if (alphabet_size_ == 0 || data_size_ == 0) [[unlikely]] {
return {};
}
auto length = static_cast<long long>(end - begin);
std::vector<uint64_t> result(2 * length);
copySegmentContent(root_, begin, end,
std::span{result.begin(), result.begin() + length},
std::span{result.begin() + length, result.end()});
result.resize(length);
return result;
}
};

} // namespace pixie
1 change: 1 addition & 0 deletions scripts/coverage_report.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ cmake --build --preset coverage
"${BUILD_DIR}/unittests"
"${BUILD_DIR}/louds_tree_tests"
"${BUILD_DIR}/test_rmm"
"${BUILD_DIR}/wavelet_tree_tests"

cd "${BUILD_DIR}"
find . -name "*.gcda" > gcov_files.txt
Expand Down
89 changes: 89 additions & 0 deletions src/benchmarks/wavelet_tree_benchmarks.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
#include <benchmark/benchmark.h>
#include <pixie/utils.h>
#include <pixie/wavelet_tree.h>

#include <random>

using pixie::WaveletTree;

static void BM_WaveletTreeSelect(benchmark::State& state) {
size_t data_size = state.range(0), alphabet_size = 1024, query = data_size;
std::mt19937_64 rng(239);

for (auto _ : state) {
state.PauseTiming();

std::vector<uint64_t> data =
generate_random_data(data_size, alphabet_size, rng);
std::vector<uint64_t> query_symbol =
generate_random_data(query, alphabet_size, rng);
std::vector<size_t> count(alphabet_size), query_pos(query);
for (auto symb : data) {
count[symb]++;
}
for (size_t i = 0; i < query; i++) {
query_pos[i] = 1 + std::uniform_int_distribution<size_t>(
0, count[query_symbol[i]])(rng);
}

state.ResumeTiming();

WaveletTree wavelet_tree(alphabet_size, data);
benchmark::DoNotOptimize(wavelet_tree);

for (size_t i = 0; i < query; i++) {
size_t select = wavelet_tree.select(query_symbol[i], query_pos[i]);
benchmark::DoNotOptimize(select);
}
}
}

static void BM_WaveletTreeRank(benchmark::State& state) {
size_t data_size = state.range(0), alphabet_size = 1024, query = data_size;
std::mt19937_64 rng(239);

for (auto _ : state) {
state.PauseTiming();

std::vector<uint64_t> data =
generate_random_data(data_size, alphabet_size, rng);
std::vector<uint64_t> query_symbol =
generate_random_data(query, alphabet_size, rng),
query_pos =
generate_random_data(query, data_size + 1, rng);

state.ResumeTiming();

WaveletTree wavelet_tree(alphabet_size, data);
benchmark::DoNotOptimize(wavelet_tree);

for (size_t i = 0; i < query; i++) {
size_t rank = wavelet_tree.rank(query_symbol[i], query_pos[i]);
benchmark::DoNotOptimize(rank);
}
}
}

BENCHMARK(BM_WaveletTreeSelect)
->ArgNames({"data_size"})
->RangeMultiplier(2)
->Range(1ull << 8, 1ull << 18)
->Iterations(100);

BENCHMARK(BM_WaveletTreeSelect)
->ArgNames({"data_size"})
->RangeMultiplier(2)
->Range(1ull << 18, 1ull << 26)
->Iterations(10);

BENCHMARK(BM_WaveletTreeRank)
->ArgNames({"data_size"})
->RangeMultiplier(2)
->Range(1ull << 8, 1ull << 18)
->Iterations(100);

BENCHMARK(BM_WaveletTreeRank)
->ArgNames({"data_size"})
->RangeMultiplier(2)
->Range(1ull << 18, 1ull << 26)
->Iterations(10);
Loading
Loading