diff --git a/.github/workflows/kilo-dispatch.yml b/.github/workflows/kilo-dispatch.yml index 6c0e6c8..9585ce5 100644 --- a/.github/workflows/kilo-dispatch.yml +++ b/.github/workflows/kilo-dispatch.yml @@ -21,7 +21,7 @@ on: timeout_minutes: description: Timeout in minutes for job and Kilo run required: false - default: "30" + default: "120" type: string model: description: Gateway model ID for kilocode provider @@ -57,6 +57,11 @@ jobs: steps: - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Fetch remote branches for baseline resolution + run: git fetch --no-tags origin +refs/heads/*:refs/remotes/origin/* - name: Setup Node.js uses: actions/setup-node@v4 @@ -66,6 +71,21 @@ jobs: - name: Install Kilo CLI run: npm install -g @kilocode/cli + - name: Install performance dependencies + if: ${{ inputs.command == 'perf-review' }} + run: | + sudo apt-get update + sudo apt-get install -y --no-install-recommends \ + build-essential cmake ninja-build pkg-config \ + python3 python3-pip \ + python3-numpy python3-scipy \ + libpfm4 libpfm4-dev + cmake --version + python3 --version + python3 -c "import numpy, scipy; print(f'numpy {numpy.__version__}'); print(f'scipy {scipy.__version__}')" + dpkg-query -W -f='${Package} ${Version}\n' libpfm4 libpfm4-dev + test -f /usr/include/perfmon/pfmlib.h + - name: Verify KILO_API_TOKEN secret run: | if [ -z "${{ secrets.KILO_API_TOKEN }}" ]; then @@ -90,6 +110,85 @@ jobs: owner: ${{ github.repository_owner }} repositories: ${{ github.event.repository.name }} + - name: Post start status as GitHub App comment + if: ${{ always() && inputs.pr_number != '' && steps.app-token.outputs.token != '' }} + uses: actions/github-script@v7 + with: + github-token: ${{ steps.app-token.outputs.token }} + script: | + const prNumber = Number(${{ toJSON(inputs.pr_number) }}); + const command = ${{ toJSON(inputs.command) }}; + const commandArgs = ${{ toJSON(inputs.command_args) }}; + const prompt = ${{ toJSON(inputs.prompt) }}; + const 
parentCommentIdRaw = ((${{ toJSON(inputs.parent_comment_id) }} || ${{ toJSON(inputs.review_comment_id) }} || '') + '').trim(); + const parentCommentId = parentCommentIdRaw ? Number(parentCommentIdRaw) : null; + const runUrl = `${process.env.GITHUB_SERVER_URL}/${context.repo.owner}/${context.repo.repo}/actions/runs/${process.env.GITHUB_RUN_ID}`; + + let requestSummary = ''; + if ((command || '').trim()) { + requestSummary = `/${command}${(commandArgs || '').trim() ? ` ${commandArgs}` : ''}`; + } else { + requestSummary = (prompt || '').trim(); + } + + requestSummary = requestSummary.replace(/\s+/g, ' ').trim(); + if (!requestSummary) { + requestSummary = '(no request text provided)'; + } + if (requestSummary.length > 180) { + requestSummary = `${requestSummary.slice(0, 177)}...`; + } + + const body = `Started Kilo run: ${requestSummary}\nRun: ${runUrl}`; + + if (parentCommentId) { + try { + await github.rest.pulls.createReplyForReviewComment({ + owner: context.repo.owner, + repo: context.repo.repo, + pull_number: prNumber, + comment_id: parentCommentId, + body + }); + return; + } catch (error) { + const status = error?.status; + if (status !== 404 && status !== 422) { + throw error; + } + + let prefix = ''; + try { + const parent = await github.rest.issues.getComment({ + owner: context.repo.owner, + repo: context.repo.repo, + comment_id: parentCommentId + }); + const login = parent?.data?.user?.login; + if (login) { + prefix = `@${login} `; + } + } catch { + // Fallback to plain issue comment if parent lookup fails. 
+ } + + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: prNumber, + body: `${prefix}${body}` + }); + return; + } + } + + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: prNumber, + body + }); + - name: Run Kilo CLI env: KILO_API_TOKEN: ${{ secrets.KILO_API_TOKEN }} @@ -129,13 +228,23 @@ jobs: "${kilo_args[@]}" 2>&1 | tee -a kilo-events.log | tee -a kilo-run.log fi - node -e "const fs=require('fs');const strip=(s)=>s.replace(/\x1B\[[0-9;]*[A-Za-z]/g,'');const lines=fs.readFileSync('kilo-events.log','utf8').split(/\r?\n/);const texts=[];for(const line of lines){if(!line.trim())continue;try{const evt=JSON.parse(line);if(evt&&evt.type==='text'&&evt.part&&typeof evt.part.text==='string'){const t=evt.part.text.trim();if(t)texts.push(t);}}catch{}}let out='';if(texts.length){out=texts[texts.length-1];}else{const fallback=strip(fs.readFileSync('kilo-events.log','utf8')).split(/\r?\n/).map(x=>x.trim()).filter(Boolean);out=fallback.length?fallback[fallback.length-1]:'';}fs.writeFileSync('kilo-output.log',out+(out.endsWith('\n')?'':'\n'));" + node --experimental-strip-types scripts/kilo-postprocess.ts kilo-events.log kilo-output.log kilo-readable.log - - name: Post result as GitHub App comment + - name: Create GitHub App token for final comment if: ${{ always() && inputs.pr_number != '' }} + id: app-token-final + uses: actions/create-github-app-token@v1 + with: + app-id: ${{ secrets.NIKOLAY_REVIEWER_APP_ID }} + private-key: ${{ secrets.NIKOLAY_REVIEWER_PRIVATE_KEY }} + owner: ${{ github.repository_owner }} + repositories: ${{ github.event.repository.name }} + + - name: Post result as GitHub App comment + if: ${{ always() && inputs.pr_number != '' && steps.app-token-final.outputs.token != '' }} uses: actions/github-script@v7 with: - github-token: ${{ steps.app-token.outputs.token }} + github-token: ${{ steps.app-token-final.outputs.token }} script: | 
const fs = require('fs'); const prNumber = Number('${{ inputs.pr_number }}'); @@ -203,4 +312,6 @@ jobs: uses: actions/upload-artifact@v4 with: name: kilo-run-log - path: kilo-run.log + path: | + kilo-run.log + kilo-readable.log diff --git a/.kilo/command/perf-review.md b/.kilo/command/perf-review.md index 3da280a..3fff142 100644 --- a/.kilo/command/perf-review.md +++ b/.kilo/command/perf-review.md @@ -19,7 +19,10 @@ Non-negotiable requirements: If arguments are omitted: - Default target branch to PR base branch from `gh pr view --json baseRefName` when available. - Fall back target branch to `main`. -- Default filter to empty (run full selected benchmark suites). +- Default filter must be **targeted**, not full-suite: + - Derive from changed files and changed symbols. + - If `include/pixie/bitvector.h` changed in select path, default to `BM_Select` and add `BM_RankNonInterleaved` as control. + - Run full selected suites only as last resort when mapping fails. ## Step 1 - Resolve Branches and Revisions @@ -43,7 +46,7 @@ Map file paths to benchmark binaries: | Changed path pattern | Benchmark binary | Coverage | |---|---|---| -| `include/bit_vector*`, `include/interleaved*` | `benchmarks` | BitVector rank/select | +| `include/pixie/bitvector*`, `include/*bit_vector*`, `include/interleaved*` | `benchmarks` | BitVector rank/select | | `include/rmm*` | `bench_rmm` | RmM tree operations | | `include/louds*` | `louds_tree_benchmarks` | LOUDS traversal | | `include/simd*`, `include/aligned*` | `alignment_comparison` | SIMD and alignment | @@ -57,10 +60,15 @@ Available benchmark binaries: - `louds_tree_benchmarks` - `alignment_comparison` -If the mapping is ambiguous, run all benchmark binaries. +If the mapping is ambiguous, run all benchmark binaries but still apply a focused filter first. If `--filter` is provided, pass it through as `--benchmark_filter`. Print selected binaries and why they were selected. 
+Execution guardrails: +- Do not use background jobs (`nohup`, `&`) for benchmark runs in CI. +- Do not interleave multiple benchmark runs into one shell command stream. +- Run one benchmark command at a time and wait for completion. + ## Step 3 - Build Both Revisions (Timing and Profiling Builds) Use isolated build directories per short hash. @@ -98,19 +106,38 @@ If a required binary is missing, report failure and stop with a blocked verdict. ## Step 5 - Run Timing Comparison (Primary Judgment) -Locate compare script from baseline timing build: - -`build/benchmarks-all_bench_/_deps/googlebenchmark-src/tools/compare.py` - -For each selected benchmark binary, run: - -`python3 benchmarks [--benchmark_filter=""]` - -Capture full output for each binary and keep it for report details. +Use a deterministic JSON-first workflow. Do not rely on long-running `compare.py` binary-vs-binary mode. + +1. Verify Python benchmark tooling once before runs: + - `python3 -c "import numpy, scipy"` +2. For each selected benchmark binary, run baseline then contender sequentially, each with explicit JSON out: + - `--benchmark_filter="<filter>"` + - `--benchmark_format=json` + - `--benchmark_out=<output>.json` + - `--benchmark_report_aggregates_only=true` + - `--benchmark_display_aggregates_only=true` +3. Suppress benchmark stdout/stderr noise when generating JSON artifacts so files stay valid: + - `> <output>.log 2>&1` +4. Validate both JSON files before comparison: + - `python3 -m json.tool <output>.json > /dev/null` +5. Compare using one of: + - `python3 <compare.py> -a benchmarks <baseline>.json <contender>.json` + - or a deterministic local Python diff script over aggregate means. +6. Keep raw JSON files and comparison output for auditability. + +Timeout and retry policy: +- Use command timeouts that match benchmark scope. +- If a run times out once, narrow filter immediately and retry once. +- Maximum retry count per benchmark group: 1. +- If still timing out, produce a blocked/partial verdict with explicit scope limitations. 
## Step 6 - Collect Hardware Counter Profiles (Linux Only) -If Linux profiling build is available, run both baseline and contender diagnostic binaries with counter output: +Run a preflight first to avoid wasted attempts: +1. Execute one tiny benchmark with perf counters (e.g. one benchmark case) and inspect output for counter availability. +2. If output includes warnings like `Failed to get a file descriptor for performance counter`, mark counters unavailable and skip counter collection. + +If preflight passes and Linux profiling build is available, run both baseline and contender diagnostic binaries with counter output: - `--benchmark_counters_tabular=true` - `--benchmark_format=json` @@ -128,7 +155,7 @@ Compute derived metrics when denominators are non-zero: - Cache miss rate = cache-misses / cache-references - Branch mispredict rate = branch-misses / branches -If profiling is unavailable (non-Linux or libpfm not available), continue with timing-only review and explicitly mark profiling as unavailable in the report. +If profiling is unavailable (non-Linux, libpfm missing, or perf permissions blocked), continue with timing-only review and explicitly mark profiling as unavailable in the report. ## Step 7 - Analyze Timing and Counter Data @@ -150,6 +177,10 @@ Judgment priority: - Base verdict primarily on benchmark timing comparison. - Use counter data as explanatory evidence and confidence signal. +Noise-control expectations: +- Include at least one control benchmark family expected to be unaffected by the code change. +- Treat isolated swings without pattern as noise unless reproduced across related sizes/fill ratios. + ## Step 8 - Produce Final Markdown Report Return a structured markdown report with this shape: @@ -197,3 +228,4 @@ Verdict rules: - If required builds fail or timing comparison cannot run, output a blocked review with exact failure points and no misleading verdict. 
- If only profiling fails, continue with timing-based verdict and explicitly list profiling limitation. +- If JSON output is invalid/truncated, discard it and rerun that benchmark command once with tighter filter and explicit output redirection. diff --git a/.kilo/skills/benchmarks-compare-revisions/SKILL.md b/.kilo/skills/benchmarks-compare-revisions/SKILL.md index ae8d7e7..23996a5 100644 --- a/.kilo/skills/benchmarks-compare-revisions/SKILL.md +++ b/.kilo/skills/benchmarks-compare-revisions/SKILL.md @@ -39,30 +39,63 @@ git checkout ${CONTENDER} ## Step 2 — Compare using compare.py -Use Google Benchmark’s compare.py to run both binaries and compute a statistical comparison. +Use Google Benchmark compare tooling with a JSON-first flow to avoid long-running binary-vs-binary retries. Locate compare.py from the Google Benchmark dependency (installed under the build tree): ```bash COMPARE_PY=build/benchmarks-all_bench_${BASELINE}/_deps/googlebenchmark-src/tools/compare.py ``` -Run the comparison (benchmarks mode): +Verify Python deps once (compare.py imports numpy/scipy): ```bash -python3 ${COMPARE_PY} benchmarks \ - build/benchmarks-all_bench_${BASELINE}/benchmarks \ - build/benchmarks-all_bench_${CONTENDER}/benchmarks +python3 -c "import numpy, scipy" ``` -### Optional: filter to reduce noise +Generate baseline/contender JSON sequentially with explicit file outputs: +```bash +BASE_JSON=/tmp/bench_${BASELINE}.json +CONT_JSON=/tmp/bench_${CONTENDER}.json + +build/benchmarks-all_bench_${BASELINE}/benchmarks \ + --benchmark_report_aggregates_only=true \ + --benchmark_display_aggregates_only=true \ + --benchmark_format=json \ + --benchmark_out=${BASE_JSON} > /tmp/bench_${BASELINE}.log 2>&1 + +build/benchmarks-all_bench_${CONTENDER}/benchmarks \ + --benchmark_report_aggregates_only=true \ + --benchmark_display_aggregates_only=true \ + --benchmark_format=json \ + --benchmark_out=${CONT_JSON} > /tmp/bench_${CONTENDER}.log 2>&1 +``` + +Validate JSON before comparing: +```bash 
+python3 -m json.tool ${BASE_JSON} > /dev/null +python3 -m json.tool ${CONT_JSON} > /dev/null +``` -Pass benchmark options after the binaries so compare.py forwards them: +Run the comparison: ```bash -python3 ${COMPARE_PY} benchmarks \ - build/benchmarks-all_bench_${BASELINE}/benchmarks \ - build/benchmarks-all_bench_${CONTENDER}/benchmarks \ - --benchmark_filter="BM_Rank" +python3 ${COMPARE_PY} -a benchmarks ${BASE_JSON} ${CONT_JSON} ``` +### Optional: filter to reduce noise and runtime + +Pass the filter when generating the JSON files: +```bash +FILTER="BM_Rank" +build/benchmarks-all_bench_${BASELINE}/benchmarks --benchmark_filter="${FILTER}" --benchmark_report_aggregates_only=true --benchmark_display_aggregates_only=true ... +build/benchmarks-all_bench_${CONTENDER}/benchmarks --benchmark_filter="${FILTER}" --benchmark_report_aggregates_only=true --benchmark_display_aggregates_only=true ... +``` + +## Retry and Timeout Policy + +1. Run benchmarks sequentially; do not background with `nohup`/`&`. +2. If a run times out, narrow filter and retry once. +3. Maximum retries per benchmark group: 1. +4. If still failing, emit blocked/partial findings instead of repeated attempts. + ## Step 3 — Record findings Capture the compare.py output (terminal transcript or redirected file) and note any statistically significant regressions or wins. diff --git a/.kilo/skills/benchmarks/SKILL.md b/.kilo/skills/benchmarks/SKILL.md index 9d5a45e..354db03 100644 --- a/.kilo/skills/benchmarks/SKILL.md +++ b/.kilo/skills/benchmarks/SKILL.md @@ -25,6 +25,8 @@ BUILD_SUFFIX=local ## Step 1 — Build +If the benchmarks affected by the changes are easy to identify, build only the related targets. 
+ **Pure timing (benchmarks-all, Release):** ```bash cmake -B build/benchmarks-all_${BUILD_SUFFIX} -DCMAKE_BUILD_TYPE=Release -DPIXIE_BENCHMARKS=ON @@ -39,6 +41,13 @@ cmake --build build/benchmarks-diagnostic_${BUILD_SUFFIX} --config RelWithDebInf ## Step 2 — Run +Prefer running benchmarks with a filter that selects only the benchmarks expected to be affected. + +Execution guardrails: +- Run benchmark commands sequentially in CI. +- Avoid background jobs (`nohup`, `&`) for benchmark collection. +- Always write machine-readable results with `--benchmark_out` when data is later parsed. + ### Available benchmark binaries | Binary | What it covers | @@ -103,10 +112,17 @@ build/benchmarks-diagnostic_${BUILD_SUFFIX}/RelWithDebInfo/benchmarks \ ```bash build/benchmarks-all_${BUILD_SUFFIX}/Release/benchmarks \ --benchmark_filter="${FILTER}" \ + --benchmark_report_aggregates_only=true \ + --benchmark_display_aggregates_only=true \ --benchmark_format=json \ --benchmark_out=results.json ``` +Validate output before consuming: +```bash +python3 -m json.tool results.json > /dev/null +``` + ## Step 3 — Profile with perf (Linux only) Use when hardware counters alone are not enough and you need a full call-graph profile for post-processing. @@ -157,3 +173,6 @@ perf script -F +pid > perf.data.txt 5. **Pin CPU frequency** before timing runs: `sudo cpupower frequency-set -g performance` 6. **Filter to reduce noise**: narrow the filter regex to the benchmark under investigation 7. **Save JSON output** when comparing before/after changes: use `--benchmark_out` and diff the files +8. **Fail fast on environment issues**: precheck Python deps used by compare tooling (`numpy`, `scipy`) +9. **Use explicit retry limits**: on timeout, narrow scope and retry once; avoid repeated full-suite attempts +10. 
**Preflight perf counters**: run a tiny counter-enabled benchmark first; if counters unavailable, skip counter workflow diff --git a/include/pixie/bitvector.h b/include/pixie/bitvector.h index b0d8e0d..b486736 100644 --- a/include/pixie/bitvector.h +++ b/include/pixie/bitvector.h @@ -1,785 +1,785 @@ -#pragma once - -#include -#include - -#include -#include -#include -#include -#include -#include - -#ifdef PIXIE_DIAGNOSTICS -#include -#endif - -namespace pixie { - -/** - * @brief Non-interleaved, non-owning bit vector with rank and select. - * - * - * @details - * This is a two-level rank/select index for a bit vector stored - * externally as - * 64-bit words. The layout follows ideas from: - * - * {1} - * "SPIDER: Improved Succinct Rank and Select Performance" - * Matthew D. Laws, - * Jocelyn Bliven, Kit Conklin, Elyes Laalai, Samuel McCauley, - * Zach S. - * Sturdevant - * https://github.com/williams-cs/spider - * - * {2} "Engineering - * compact data structures for rank and select queries on - * bit vectors" - * Kurpicz F. - * https://github.com/pasta-toolbox/bit_vector - * - * Structure - * overview: - * - Super blocks of 2^16 bits with 64-bit ranks (~0.98% - * overhead). - * - Basic blocks of 512 bits with 16-bit ranks (~3.125% - * overhead). - * - Select samples every 16384 bits (~0.39% overhead). - * - * - * Rank: 2 table lookups plus SIMD popcount in the 512-bit block. - * - * Select: - - * * - Start from a sampled super block. - * - SIMD linear scan to find the super - * block. - * - SIMD linear scan to find the basic block. - * - * This variant does - * not interleave data and index, favoring simpler scans. 
- */ -class BitVector { - private: - constexpr static size_t kWordSize = 64; - constexpr static size_t kSuperBlockRankIntSize = 64; - constexpr static size_t kBasicBlockRankIntSize = 16; - constexpr static size_t kBasicBlockSize = 512; - constexpr static size_t kWordsPerBlock = 8; - constexpr static size_t kSuperBlockSize = 65536; - constexpr static size_t kBlocksPerSuperBlock = 128; - constexpr static size_t kSelectSampleFrequency = 16384; - - alignas(64) uint64_t delta_super[8]; - alignas(64) uint16_t delta_basic[32]; - - AlignedStorage super_block_rank_; // 64-bit global prefix sums - AlignedStorage basic_block_rank_; // 16-bit local prefix sums - AlignedStorage select1_samples_; // 64-bit global positions - AlignedStorage select0_samples_; // 64-bit global positions - const size_t num_bits_; - const size_t padded_size_; - size_t max_rank_; - - std::span bits_; - - /** - * @brief Precompute rank for fast queries. - */ - void build_rank() { - size_t num_superblocks = 8 + (padded_size_ - 1) / kSuperBlockSize; - // Add more blocks to ease SIMD processing - // num_basicblocks to fully cover superblock, i.e. 
128 - // This reduces branching in select - num_superblocks = ((num_superblocks + 7) / 8) * 8; - size_t num_basicblocks = num_superblocks * kBlocksPerSuperBlock; - super_block_rank_.resize(num_superblocks * 64); - basic_block_rank_.resize(num_basicblocks * 16); - - auto super_block_rank = super_block_rank_.As64BitInts(); - auto basic_block_rank = basic_block_rank_.As16BitInts(); - - uint64_t super_block_sum = 0; - uint16_t basic_block_sum = 0; - - for (size_t i = 0; i / kBasicBlockSize < basic_block_rank.size(); - i += kWordSize) { - if (i % kSuperBlockSize == 0) { - super_block_sum += basic_block_sum; - super_block_rank[i / kSuperBlockSize] = super_block_sum; - basic_block_sum = 0; - } - if (i % kBasicBlockSize == 0) { - basic_block_rank[i / kBasicBlockSize] = basic_block_sum; - } - if (i / kWordSize < bits_.size()) { - basic_block_sum += std::popcount(bits_[i / kWordSize]); - } - } - max_rank_ = super_block_sum + basic_block_sum; - } - - /** - * @brief Calculate select samples. - */ - void build_select() { - uint64_t milestone = kSelectSampleFrequency; - uint64_t milestone0 = kSelectSampleFrequency; - uint64_t rank = 0; - uint64_t rank0 = 0; - - size_t num_one_samples = - 1 + (max_rank_ + kSelectSampleFrequency - 1) / kSelectSampleFrequency; - size_t num_zero_samples = - 1 + (num_bits_ - max_rank_ + kSelectSampleFrequency - 1) / - kSelectSampleFrequency; - - select1_samples_.resize(num_one_samples * 64); - select0_samples_.resize(num_zero_samples * 64); - auto select1_samples = select1_samples_.As64BitInts(); - auto select0_samples = select0_samples_.As64BitInts(); - - select1_samples[0] = 0; - select0_samples[0] = 0; - - size_t num_zeros = 1, num_ones = 1; - - for (size_t i = 0; i < bits_.size(); ++i) { - auto ones = std::popcount(bits_[i]); - auto zeros = 64 - ones; - if (rank + ones >= milestone) { - auto pos = select_64(bits_[i], milestone - rank - 1); - // TODO: try including global rank into select samples to save - // a cache miss on global rank scan - 
select1_samples[num_ones++] = (64 * i + pos) / kSuperBlockSize; - milestone += kSelectSampleFrequency; - } - if (rank0 + zeros >= milestone0) { - auto pos = select_64(~bits_[i], milestone0 - rank0 - 1); - select0_samples[num_zeros++] = (64 * i + pos) / kSuperBlockSize; - milestone0 += kSelectSampleFrequency; - } - rank += ones; - rank0 += zeros; - } - - for (size_t i = 0; i < 8; ++i) { - delta_super[i] = i * kSuperBlockSize; - } - for (size_t i = 0; i < 32; ++i) { - delta_basic[i] = i * kBasicBlockSize; - } - } - - /** - * @brief First step of the select operation. - * @param rank 1-based - * rank of the 1-bit to locate. - */ - uint64_t find_superblock(uint64_t rank) const { - auto select1_samples = select1_samples_.AsConst64BitInts(); - auto super_block_rank = super_block_rank_.AsConst64BitInts(); - - uint64_t left = select1_samples[rank / kSelectSampleFrequency]; - - while (left + 7 < super_block_rank.size()) { - auto len = lower_bound_8x64(&super_block_rank[left], rank); - if (len < 8) { - return left + len - 1; - } - left += 8; - } - if (left + 3 < super_block_rank.size()) { - auto len = lower_bound_4x64(&super_block_rank[left], rank); - if (len < 4) { - return left + len - 1; - } - left += 4; - } - while (left < super_block_rank.size() && super_block_rank[left] < rank) { - left++; - } - return left - 1; - } - - /** - * @brief First step of the select0 operation. - * @param rank0 1-based - * rank of the 0-bit to locate. 
- */ - uint64_t find_superblock_zeros(uint64_t rank0) const { - auto select0_samples = select0_samples_.AsConst64BitInts(); - auto super_block_rank = super_block_rank_.AsConst64BitInts(); - - uint64_t left = select0_samples[rank0 / kSelectSampleFrequency]; - - while (left + 7 < super_block_rank.size()) { - auto len = lower_bound_delta_8x64(&super_block_rank[left], rank0, - delta_super, kSuperBlockSize * left); - if (len < 8) { - return left + len - 1; - } - left += 8; - } - if (left + 3 < super_block_rank.size()) { - auto len = lower_bound_delta_4x64(&super_block_rank[left], rank0, - delta_super, kSuperBlockSize * left); - if (len < 4) { - return left + len - 1; - } - left += 4; - } - while (left < super_block_rank.size() && - kSuperBlockSize * left - super_block_rank[left] < rank0) { - left++; - } - return left - 1; - } - - /** - * @brief SIMD-optimized linear scan. - * @param local_rank Rank within - * the super block. - * @param s_block Super block index. - * @details - * - * Processes 32 16-bit entries at once (full cache line), so there is at most - - * * 4 iterations. - */ - uint64_t find_basicblock(uint16_t local_rank, uint64_t s_block) const { - auto basic_block_rank = basic_block_rank_.AsConst16BitInts(); - - for (size_t pos = 0; pos < kBlocksPerSuperBlock; pos += 32) { - auto count = lower_bound_32x16( - &basic_block_rank[kBlocksPerSuperBlock * s_block + pos], local_rank); - if (count < 32) { - return kBlocksPerSuperBlock * s_block + pos + count - 1; - } - } - return kBlocksPerSuperBlock * s_block + kBlocksPerSuperBlock - 1; - } - - /** - * @brief SIMD-optimized linear scan. - * @param local_rank0 Rank of - * zeros within the super block. - * @param s_block Super block index. - * - * @details - * Processes 32 16-bit entries at once (full cache line), so - * there is at most - * 4 iterations. 
- */ - uint64_t find_basicblock_zeros(uint16_t local_rank0, uint64_t s_block) const { - auto basic_block_rank = basic_block_rank_.AsConst16BitInts(); - for (size_t pos = 0; pos < kBlocksPerSuperBlock; pos += 32) { - auto count = lower_bound_delta_32x16( - &basic_block_rank[kBlocksPerSuperBlock * s_block + pos], local_rank0, - delta_basic, kBasicBlockSize * pos); - if (count < 32) { - return kBlocksPerSuperBlock * s_block + pos + count - 1; - } - } - return kBlocksPerSuperBlock * s_block + kBlocksPerSuperBlock - 1; - } - - /** - * @brief Interpolation search with SIMD optimization. - * @param - * local_rank Rank within the super block. - * @param s_block Super block - * index. - * @details - * Similar to find_basicblock but initial guess is - * based on linear - * interpolation, for random data it should make initial - * guess correct - * most of the times, we start from the 32 wide block with - * interpolation - * guess at the center, if we see that select result lie in - * lower blocks - * we backoff to find_basicblock - */ - uint64_t find_basicblock_is(uint16_t local_rank, uint64_t s_block) const { - auto super_block_rank = super_block_rank_.AsConst64BitInts(); - auto basic_block_rank = basic_block_rank_.AsConst16BitInts(); - - auto lower = super_block_rank[s_block]; - auto upper = super_block_rank[s_block + 1]; - - uint64_t pos = kBlocksPerSuperBlock * local_rank / (upper - lower); - pos = pos + 16 < 32 ? 0 : (pos - 16); - pos = pos > 96 ? 
96 : pos; - while (pos < 96) { - auto count = lower_bound_32x16( - &basic_block_rank[kBlocksPerSuperBlock * s_block + pos], local_rank); - if (count == 0) { - return find_basicblock(local_rank, s_block); - } - if (count < 32) { - return kBlocksPerSuperBlock * s_block + pos + count - 1; - } - pos += 32; - } - pos = 96; - auto count = lower_bound_32x16( - &basic_block_rank[kBlocksPerSuperBlock * s_block + pos], local_rank); - if (count == 0) { - return find_basicblock(local_rank, s_block); - } - return kBlocksPerSuperBlock * s_block + pos + count - 1; - } - - /** - * @brief Interpolation search with SIMD optimization. - * @param - * local_rank0 Rank of zeros within the super block. - * @param s_block Super - * block index. - * @details - * Similar to find_basicblock_zeros but - * initial guess is based on linear - * interpolation, for random data it - * should make initial guess correct - * most of the times, we start from the - * 32 wide block with interpolation - * guess at the center, if we see that - * select result lie in lower blocks - * we backoff to find_basicblock_zeros - - */ - uint64_t find_basicblock_is_zeros(uint16_t local_rank0, - uint64_t s_block) const { - auto super_block_rank = super_block_rank_.AsConst64BitInts(); - auto basic_block_rank = basic_block_rank_.AsConst16BitInts(); - - auto lower = kSuperBlockSize * s_block - super_block_rank[s_block]; - auto upper = - kSuperBlockSize * (s_block + 1) - super_block_rank[s_block + 1]; - - uint64_t pos = kBlocksPerSuperBlock * local_rank0 / (upper - lower); - pos = pos + 16 < 32 ? 0 : (pos - 16); - pos = pos > 96 ? 
96 : pos; - while (pos < 96) { - auto count = lower_bound_delta_32x16( - &basic_block_rank[kBlocksPerSuperBlock * s_block + pos], local_rank0, - delta_basic, kBasicBlockSize * pos); - if (count == 0) { - return find_basicblock_zeros(local_rank0, s_block); - } - if (count < 32) { - return kBlocksPerSuperBlock * s_block + pos + count - 1; - } - pos += 32; - } - pos = 96; - auto count = lower_bound_delta_32x16( - &basic_block_rank[kBlocksPerSuperBlock * s_block + pos], local_rank0, - delta_basic, kBasicBlockSize * pos); - if (count == 0) { - return find_basicblock_zeros(local_rank0, s_block); - } - return kBlocksPerSuperBlock * s_block + pos + count - 1; - } - - public: -#ifdef PIXIE_DIAGNOSTICS - struct DiagnosticsBytes { - size_t source_bitvector_bytes = 0; - size_t super_block_rank_bytes = 0; - size_t basic_block_rank_bytes = 0; - size_t select1_samples_bytes = 0; - size_t select0_samples_bytes = 0; - size_t total_bytes = 0; - }; - - /** - * @brief Returns the number of bytes used by each internal component. - */ - DiagnosticsBytes diagnostics_bytes() const { - DiagnosticsBytes result; - result.source_bitvector_bytes = (num_bits_ + 7) / 8; - result.super_block_rank_bytes = super_block_rank_.AsConstBytes().size(); - result.basic_block_rank_bytes = basic_block_rank_.AsConstBytes().size(); - result.select1_samples_bytes = select1_samples_.AsConstBytes().size(); - result.select0_samples_bytes = select0_samples_.AsConstBytes().size(); - result.total_bytes = - result.super_block_rank_bytes + result.basic_block_rank_bytes + - result.select1_samples_bytes + result.select0_samples_bytes; - return result; - } - - /** - * @brief Log memory usage of internal components. - */ - void memory_report() const { - const auto diagnostics = diagnostics_bytes(); - const double source_bytes = - static_cast(diagnostics.source_bitvector_bytes); - const auto log_bytes = [&](std::string_view label, size_t bytes) { - const double percentage = - source_bytes > 0.0 ? 
100.0 * static_cast(bytes) / source_bytes - : 0.0; - spdlog::info("BitVector {}: {} bytes ({:.2f}% of source)", label, bytes, - percentage); - }; - log_bytes("source_bitvector", diagnostics.source_bitvector_bytes); - log_bytes("super_block_rank", diagnostics.super_block_rank_bytes); - log_bytes("basic_block_rank", diagnostics.basic_block_rank_bytes); - log_bytes("select1_samples", diagnostics.select1_samples_bytes); - log_bytes("select0_samples", diagnostics.select0_samples_bytes); - log_bytes("total", diagnostics.total_bytes); - } -#endif - /** - * @brief Construct from an external array of 64-bit words. - * @param - * bit_vector Backing data, not owned. - * @param num_bits Number of valid - * bits in the vector. - */ - explicit BitVector(std::span bit_vector, size_t num_bits) - : num_bits_(std::min(num_bits, bit_vector.size() * kWordSize)), - padded_size_(((num_bits_ + kWordSize - 1) / kWordSize) * kWordSize), - bits_(bit_vector) { - build_rank(); - build_select(); - } - - /** - * @brief Returns the number of valid bits. - */ - size_t size() const { return num_bits_; } - - /** - * @brief Returns the bit at the given position. - * @param pos Bit - * index in [0, size()). - */ - int operator[](size_t pos) const { - size_t word_idx = pos / kWordSize; - size_t bit_off = pos % kWordSize; - - return (bits_[word_idx] >> bit_off) & 1; - } - - /** - * @brief Rank of 1s up to position pos (exclusive). - * @param pos Bit - * index in [0, size()]. - * @return Number of 1s in [0, pos). 
- */ - uint64_t rank(size_t pos) const { - if (pos >= bits_.size() * kWordSize) [[unlikely]] { - return max_rank_; - } - - auto super_block_rank = super_block_rank_.AsConst64BitInts(); - auto basic_block_rank = basic_block_rank_.AsConst16BitInts(); - - uint64_t b_block = pos / kBasicBlockSize; - uint64_t s_block = pos / kSuperBlockSize; - // Precomputed rank - uint64_t result = super_block_rank[s_block] + basic_block_rank[b_block]; - // Basic block tail - result += rank_512(&bits_[b_block * kWordsPerBlock], - pos - (b_block * kBasicBlockSize)); - return result; - } - - /** - * @brief Rank of 0s up to position pos (exclusive). - * @param pos Bit - * index in [0, size()]. - * @return Number of 0s in [0, pos). - */ - uint64_t rank0(size_t pos) const { - if (pos >= bits_.size() * kWordSize) [[unlikely]] { - return bits_.size() * kWordSize - max_rank_; - } - return pos - rank(pos); - } - - /** - * @brief Select the position of the rank-th 1-bit (1-indexed). - * - * @param rank 1-based rank of the 1-bit to select. - * @return Bit index, or - * size() if rank is out of range. - */ - uint64_t select(size_t rank) const { - if (rank > max_rank_) [[unlikely]] { - return num_bits_; - } - if (rank == 0) [[unlikely]] { - return 0; - } - auto super_block_rank = super_block_rank_.AsConst64BitInts(); - auto basic_block_rank = basic_block_rank_.AsConst16BitInts(); - - uint64_t s_block = find_superblock(rank); - rank -= super_block_rank[s_block]; - auto pos = find_basicblock_is(rank, s_block); - rank -= basic_block_rank[pos]; - pos *= kWordsPerBlock; - - // Final search - if (pos + kWordsPerBlock - 1 < kWordsPerBlock) [[unlikely]] { - size_t ones = std::popcount(bits_[pos]); - while (pos < bits_.size() && ones < rank) { - rank -= ones; - ones = std::popcount(bits_[++pos]); - } - return kWordSize * pos + select_64(bits_[pos], rank - 1); - } - return kWordSize * pos + select_512(&bits_[pos], rank - 1); - } - - /** - * @brief Select the position of the rank0-th 0-bit (1-indexed). 
- * - * @param rank0 1-based rank of the 0-bit to select. - * @return Bit index, - * or size() if rank0 is out of range. - */ - uint64_t select0(size_t rank0) const { - if (rank0 > num_bits_ - max_rank_) [[unlikely]] { - return num_bits_; - } - if (rank0 == 0) [[unlikely]] { - return 0; - } - auto super_block_rank = super_block_rank_.AsConst64BitInts(); - auto basic_block_rank = basic_block_rank_.AsConst16BitInts(); - - uint64_t s_block = find_superblock_zeros(rank0); - rank0 -= kSuperBlockSize * s_block - super_block_rank[s_block]; - auto pos = find_basicblock_is_zeros(rank0, s_block); - auto pos_in_super_block = pos & (kBlocksPerSuperBlock - 1); - rank0 -= kBasicBlockSize * pos_in_super_block - basic_block_rank[pos]; - pos *= kWordsPerBlock; - - // Final search - if (pos + kWordsPerBlock - 1 < kWordsPerBlock) [[unlikely]] { - size_t zeros = std::popcount(~bits_[pos]); - while (pos < bits_.size() && zeros < rank0) { - rank0 -= zeros; - zeros = std::popcount(~bits_[++pos]); - } - return kWordSize * pos + select_64(~bits_[pos], rank0 - 1); - } - return kWordSize * pos + select0_512(&bits_[pos], rank0 - 1); - } - - /** - * @brief Convert to a binary string (debug helper). - */ - std::string to_string() const { - std::string result; - result.reserve(num_bits_); - - for (size_t i = 0; i < num_bits_; i++) { - result.push_back(operator[](i) ? '1' : '0'); - } - - return result; - } -}; - -/** - * @brief Interleaved, owning bit vector with rank and select. - * - * - * @details - * This variant interleaves data with local rank metadata to reduce - * cache - * misses for rank queries. It copies input bits into an interleaved - * layout. - * - * Based on: - * "SPIDER: Improved Succinct Rank and Select - * Performance" - * Matthew D. Laws, Jocelyn Bliven, Kit Conklin, Elyes Laalai, - * Samuel McCauley, - * Zach S. 
Sturdevant - */ -class BitVectorInterleaved { - private: - constexpr static size_t kWordSize = 64; - constexpr static size_t kSuperBlockRankIntSize = 64; - constexpr static size_t kBasicBlockRankIntSize = 16; - /** - * 496 bits data + 16 bit local rank - */ - constexpr static size_t kBasicBlockSize = 496; - /** - * 63488 = 496 * 128, so position of superblock can be obtained - * from the position of basic block by dividing on 128 or - * right shift on 7 bits which is cheaper then performing another - * division. - */ - constexpr static size_t kSuperBlockSize = 63488; - constexpr static size_t kBlocksPerSuperBlock = 128; - constexpr static size_t kWordsPerBlock = 8; - - const size_t num_bits_; - std::vector bits_interleaved; - std::vector super_block_rank_; - - class BitReader { - size_t iterator_64_ = 0; - size_t offset_size_ = 0; - size_t offset_bits_ = 0; - std::span bits_; - - public: - BitReader(std::span bits) : bits_(bits) {} - uint64_t ReadBits64(size_t num_bits) { - if (num_bits > 64) { - num_bits = 64; - } - uint64_t result = offset_bits_ & first_bits_mask(num_bits); - if (offset_size_ >= num_bits) { - offset_bits_ >>= num_bits; - offset_size_ -= num_bits; - return result; - } - uint64_t next = iterator_64_ < bits_.size() ? bits_[iterator_64_++] : 0; - result ^= (next & first_bits_mask(num_bits - offset_size_)) - << offset_size_; - offset_bits_ = (num_bits - offset_size_ == 64) - ? 0 - : next >> (num_bits - offset_size_); - offset_size_ = 64 - (num_bits - offset_size_); - return result; - } - }; - - public: - /** - * @brief Construct from an external array of 64-bit words. - * @param - * bit_vector Backing data to copy and interleave. - * @param num_bits Number - * of valid bits in the vector. - */ - explicit BitVectorInterleaved(std::span bit_vector, - size_t num_bits) - : num_bits_(std::min(num_bits, bit_vector.size() * kWordSize)) { - build_rank_interleaved(bit_vector, num_bits); - } - - /** - * @brief Mask with the lowest num bits set. 
- */ - static inline uint64_t first_bits_mask(size_t num) { - return num >= 64 ? UINT64_MAX : ((1llu << num) - 1); - } - - /** - * @brief Returns the number of valid bits. - */ - size_t size() const { return num_bits_; } - - /** - * @brief Returns the bit at the given position. - * @param pos Bit - * index in [0, size()). - */ - int operator[](size_t pos) const { - size_t block_id = pos / kBasicBlockSize; - size_t block_bit = pos - block_id * kBasicBlockSize; - size_t word_id = block_id * kWordsPerBlock + block_bit / kWordSize; - size_t word_bit = block_bit % kWordSize; - kWordSize; - - return (bits_interleaved[word_id] >> word_bit) & 1; - } - - /** - * @brief Build the interleaved layout and rank index. - * @param bits - * Source bit vector as 64-bit words. - * @param num_bits Number of valid - * bits in the source. - */ - void build_rank_interleaved(std::span bits, size_t num_bits) { - size_t num_superblocks = 1 + (num_bits_ - 1) / kSuperBlockSize; - super_block_rank_.resize(num_superblocks); - size_t num_basicblocks = 1 + (num_bits_ - 1) / kBasicBlockSize; - bits_interleaved.resize(num_basicblocks * (512 / kWordSize)); - - uint64_t super_block_sum = 0; - uint16_t basic_block_sum = 0; - auto bit_reader = BitReader(bits); - - for (size_t i = 0; i * kBasicBlockSize < num_bits; ++i) { - if (i % (kSuperBlockSize / kBasicBlockSize) == 0) { - super_block_sum += basic_block_sum; - super_block_rank_[i / (kSuperBlockSize / kBasicBlockSize)] = - super_block_sum; - basic_block_sum = 0; - } - bits_interleaved[i * (kWordsPerBlock) + 7] = - static_cast(basic_block_sum) << 48; - - for (size_t j = 0; j < 7 && kWordSize * (i + j) < num_bits; ++j) { - bits_interleaved[i * (kWordsPerBlock) + j] = - bit_reader.ReadBits64(std::min( - 64ull, num_bits - i * kBasicBlockSize + j * kWordSize)); - basic_block_sum += - std::popcount(bits_interleaved[i * (kWordsPerBlock) + j]); - } - if ((i + 7) * kWordSize < num_bits) { - auto v = bit_reader.ReadBits64(std::min( - 48ull, num_bits - (i * 
kBasicBlockSize + 7 * kWordSize))); - bits_interleaved[i * (kWordsPerBlock) + 7] ^= v; - basic_block_sum += std::popcount(v); - } - } - } - - /** - * @brief Rank of 1s up to position pos (exclusive). - * @param pos Bit - * index in [0, size()]. - * @return Number of 1s in [0, pos). - */ - uint64_t rank(size_t pos) const { - // Multiplication/devisions - uint64_t b_block = pos / kBasicBlockSize; - uint64_t s_block = b_block / kBlocksPerSuperBlock; - uint64_t b_block_pos = b_block * kWordsPerBlock; - // Super block rank - uint64_t result = super_block_rank_[s_block]; - /** - * Ok, so here's quite the important factor to load 512-bit region - * at &bits_interleaved[b_block_pos], we store local rank as 16 last - * bits of it. Prefetch should guarantee but seems like there is no - * need for it. - */ - // __builtin_prefetch(&bits_interleaved[b_block_pos]); - result += rank_512(&bits_interleaved[b_block_pos], - pos - (b_block * kBasicBlockSize)); - result += bits_interleaved[b_block_pos + 7] >> 48; - return result; - } - - /** - * @brief Convert to a binary string (debug helper). - */ - std::string to_string() const { - std::string result; - result.reserve(num_bits_); - - for (size_t i = 0; i < num_bits_; i++) { - result.push_back(operator[](i) ? '1' : '0'); - } - - return result; - } -}; - -} // namespace pixie +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include + +#ifdef PIXIE_DIAGNOSTICS +#include +#endif + +namespace pixie { + +/** + * @brief Non-interleaved, non-owning bit vector with rank and select. + * + * + * @details + * This is a two-level rank/select index for a bit vector stored + * externally as + * 64-bit words. The layout follows ideas from: + * + * {1} + * "SPIDER: Improved Succinct Rank and Select Performance" + * Matthew D. Laws, + * Jocelyn Bliven, Kit Conklin, Elyes Laalai, Samuel McCauley, + * Zach S. 
+ * Sturdevant + * https://github.com/williams-cs/spider + * + * {2} "Engineering + * compact data structures for rank and select queries on + * bit vectors" + * Kurpicz F. + * https://github.com/pasta-toolbox/bit_vector + * + * Structure + * overview: + * - Super blocks of 2^16 bits with 64-bit ranks (~0.98% + * overhead). + * - Basic blocks of 512 bits with 16-bit ranks (~3.125% + * overhead). + * - Select samples every 16384 bits (~0.39% overhead). + * + * + * Rank: 2 table lookups plus SIMD popcount in the 512-bit block. + * + * Select: + + * * - Start from a sampled super block. + * - SIMD linear scan to find the super + * block. + * - SIMD linear scan to find the basic block. + * + * This variant does + * not interleave data and index, favoring simpler scans. + */ +class BitVector { + private: + constexpr static size_t kWordSize = 64; + constexpr static size_t kSuperBlockRankIntSize = 64; + constexpr static size_t kBasicBlockRankIntSize = 16; + constexpr static size_t kBasicBlockSize = 512; + constexpr static size_t kWordsPerBlock = 8; + constexpr static size_t kSuperBlockSize = 65536; + constexpr static size_t kBlocksPerSuperBlock = 128; + constexpr static size_t kSelectSampleFrequency = 16384; + + alignas(64) uint64_t delta_super[8]; + alignas(64) uint16_t delta_basic[32]; + + AlignedStorage super_block_rank_; // 64-bit global prefix sums + AlignedStorage basic_block_rank_; // 16-bit local prefix sums + AlignedStorage select1_samples_; // 64-bit global positions + AlignedStorage select0_samples_; // 64-bit global positions + const size_t num_bits_; + const size_t padded_size_; + size_t max_rank_; + + std::span bits_; + + /** + * @brief Precompute rank for fast queries. + */ + void build_rank() { + size_t num_superblocks = 8 + (padded_size_ - 1) / kSuperBlockSize; + // Add more blocks to ease SIMD processing + // num_basicblocks to fully cover superblock, i.e. 
128 + // This reduces branching in select + num_superblocks = ((num_superblocks + 7) / 8) * 8; + size_t num_basicblocks = num_superblocks * kBlocksPerSuperBlock; + super_block_rank_.resize(num_superblocks * 64); + basic_block_rank_.resize(num_basicblocks * 16); + + auto super_block_rank = super_block_rank_.As64BitInts(); + auto basic_block_rank = basic_block_rank_.As16BitInts(); + + uint64_t super_block_sum = 0; + uint16_t basic_block_sum = 0; + + for (size_t i = 0; i / kBasicBlockSize < basic_block_rank.size(); + i += kWordSize) { + if (i % kSuperBlockSize == 0) { + super_block_sum += basic_block_sum; + super_block_rank[i / kSuperBlockSize] = super_block_sum; + basic_block_sum = 0; + } + if (i % kBasicBlockSize == 0) { + basic_block_rank[i / kBasicBlockSize] = basic_block_sum; + } + if (i / kWordSize < bits_.size()) { + basic_block_sum += std::popcount(bits_[i / kWordSize]); + } + } + max_rank_ = super_block_sum + basic_block_sum; + } + + /** + * @brief Calculate select samples. + */ + void build_select() { + uint64_t milestone = kSelectSampleFrequency; + uint64_t milestone0 = kSelectSampleFrequency; + uint64_t rank = 0; + uint64_t rank0 = 0; + + size_t num_one_samples = + 1 + (max_rank_ + kSelectSampleFrequency - 1) / kSelectSampleFrequency; + size_t num_zero_samples = + 1 + (num_bits_ - max_rank_ + kSelectSampleFrequency - 1) / + kSelectSampleFrequency; + + select1_samples_.resize(num_one_samples * 64); + select0_samples_.resize(num_zero_samples * 64); + auto select1_samples = select1_samples_.As64BitInts(); + auto select0_samples = select0_samples_.As64BitInts(); + + select1_samples[0] = 0; + select0_samples[0] = 0; + + size_t num_zeros = 1, num_ones = 1; + + for (size_t i = 0; i < bits_.size(); ++i) { + auto ones = std::popcount(bits_[i]); + auto zeros = 64 - ones; + if (rank + ones >= milestone) { + auto pos = select_64(bits_[i], milestone - rank - 1); + // TODO: try including global rank into select samples to save + // a cache miss on global rank scan + 
select1_samples[num_ones++] = (64 * i + pos) / kSuperBlockSize; + milestone += kSelectSampleFrequency; + } + if (rank0 + zeros >= milestone0) { + auto pos = select_64(~bits_[i], milestone0 - rank0 - 1); + select0_samples[num_zeros++] = (64 * i + pos) / kSuperBlockSize; + milestone0 += kSelectSampleFrequency; + } + rank += ones; + rank0 += zeros; + } + + for (size_t i = 0; i < 8; ++i) { + delta_super[i] = i * kSuperBlockSize; + } + for (size_t i = 0; i < 32; ++i) { + delta_basic[i] = i * kBasicBlockSize; + } + } + + /** + * @brief First step of the select operation. + * @param rank 1-based + * rank of the 1-bit to locate. + */ + uint64_t find_superblock(uint64_t rank) const { + auto select1_samples = select1_samples_.AsConst64BitInts(); + auto super_block_rank = super_block_rank_.AsConst64BitInts(); + + uint64_t left = select1_samples[rank / kSelectSampleFrequency]; + + while (left + 7 < super_block_rank.size()) { + auto len = lower_bound_8x64(&super_block_rank[left], rank); + if (len < 8) { + return left + len - 1; + } + left += 8; + } + if (left + 3 < super_block_rank.size()) { + auto len = lower_bound_4x64(&super_block_rank[left], rank); + if (len < 4) { + return left + len - 1; + } + left += 4; + } + while (left < super_block_rank.size() && super_block_rank[left] < rank) { + left++; + } + return left - 1; + } + + /** + * @brief First step of the select0 operation. + * @param rank0 1-based + * rank of the 0-bit to locate. 
+ */ + uint64_t find_superblock_zeros(uint64_t rank0) const { + auto select0_samples = select0_samples_.AsConst64BitInts(); + auto super_block_rank = super_block_rank_.AsConst64BitInts(); + + uint64_t left = select0_samples[rank0 / kSelectSampleFrequency]; + + while (left + 7 < super_block_rank.size()) { + auto len = lower_bound_delta_8x64(&super_block_rank[left], rank0, + delta_super, kSuperBlockSize * left); + if (len < 8) { + return left + len - 1; + } + left += 8; + } + if (left + 3 < super_block_rank.size()) { + auto len = lower_bound_delta_4x64(&super_block_rank[left], rank0, + delta_super, kSuperBlockSize * left); + if (len < 4) { + return left + len - 1; + } + left += 4; + } + while (left < super_block_rank.size() && + kSuperBlockSize * left - super_block_rank[left] < rank0) { + left++; + } + return left - 1; + } + + /** + * @brief SIMD-optimized linear scan. + * @param local_rank Rank within + * the super block. + * @param s_block Super block index. + * @details + * + * Processes 32 16-bit entries at once (full cache line), so there is at most + + * * 4 iterations. + */ + uint64_t find_basicblock(uint16_t local_rank, uint64_t s_block) const { + auto basic_block_rank = basic_block_rank_.AsConst16BitInts(); + + for (size_t pos = 0; pos < kBlocksPerSuperBlock; pos += 32) { + auto count = lower_bound_32x16( + &basic_block_rank[kBlocksPerSuperBlock * s_block + pos], local_rank); + if (count < 32) { + return kBlocksPerSuperBlock * s_block + pos + count - 1; + } + } + return kBlocksPerSuperBlock * s_block + kBlocksPerSuperBlock - 1; + } + + /** + * @brief SIMD-optimized linear scan. + * @param local_rank0 Rank of + * zeros within the super block. + * @param s_block Super block index. + * + * @details + * Processes 32 16-bit entries at once (full cache line), so + * there is at most + * 4 iterations. 
+ */ + uint64_t find_basicblock_zeros(uint16_t local_rank0, uint64_t s_block) const { + auto basic_block_rank = basic_block_rank_.AsConst16BitInts(); + for (size_t pos = 0; pos < kBlocksPerSuperBlock; pos += 32) { + auto count = lower_bound_delta_32x16( + &basic_block_rank[kBlocksPerSuperBlock * s_block + pos], local_rank0, + delta_basic, kBasicBlockSize * pos); + if (count < 32) { + return kBlocksPerSuperBlock * s_block + pos + count - 1; + } + } + return kBlocksPerSuperBlock * s_block + kBlocksPerSuperBlock - 1; + } + + /** + * @brief Interpolation search with SIMD optimization. + * @param + * local_rank Rank within the super block. + * @param s_block Super block + * index. + * @details + * Similar to find_basicblock but initial guess is + * based on linear + * interpolation, for random data it should make initial + * guess correct + * most of the times, we start from the 32 wide block with + * interpolation + * guess at the center, if we see that select result lie in + * lower blocks + * we backoff to find_basicblock + */ + uint64_t find_basicblock_is(uint16_t local_rank, uint64_t s_block) const { + auto super_block_rank = super_block_rank_.AsConst64BitInts(); + auto basic_block_rank = basic_block_rank_.AsConst16BitInts(); + + auto lower = super_block_rank[s_block]; + auto upper = super_block_rank[s_block + 1]; + + uint64_t pos = kBlocksPerSuperBlock * local_rank / (upper - lower); + pos = pos + 16 < 32 ? 0 : (pos - 16); + pos = pos > 96 ? 
96 : pos; + while (pos < 96) { + auto count = lower_bound_32x16( + &basic_block_rank[kBlocksPerSuperBlock * s_block + pos], local_rank); + if (count == 0) { + return find_basicblock(local_rank, s_block); + } + if (count < 32) { + return kBlocksPerSuperBlock * s_block + pos + count - 1; + } + pos += 32; + } + pos = 96; + auto count = lower_bound_32x16( + &basic_block_rank[kBlocksPerSuperBlock * s_block + pos], local_rank); + if (count == 0) { + return find_basicblock(local_rank, s_block); + } + return kBlocksPerSuperBlock * s_block + pos + count - 1; + } + + /** + * @brief Interpolation search with SIMD optimization. + * @param + * local_rank0 Rank of zeros within the super block. + * @param s_block Super + * block index. + * @details + * Similar to find_basicblock_zeros but + * initial guess is based on linear + * interpolation, for random data it + * should make initial guess correct + * most of the times, we start from the + * 32 wide block with interpolation + * guess at the center, if we see that + * select result lie in lower blocks + * we backoff to find_basicblock_zeros + + */ + uint64_t find_basicblock_is_zeros(uint16_t local_rank0, + uint64_t s_block) const { + auto super_block_rank = super_block_rank_.AsConst64BitInts(); + auto basic_block_rank = basic_block_rank_.AsConst16BitInts(); + + auto lower = kSuperBlockSize * s_block - super_block_rank[s_block]; + auto upper = + kSuperBlockSize * (s_block + 1) - super_block_rank[s_block + 1]; + + uint64_t pos = kBlocksPerSuperBlock * local_rank0 / (upper - lower); + pos = pos + 16 < 32 ? 0 : (pos - 16); + pos = pos > 96 ? 
96 : pos; + while (pos < 96) { + auto count = lower_bound_delta_32x16( + &basic_block_rank[kBlocksPerSuperBlock * s_block + pos], local_rank0, + delta_basic, kBasicBlockSize * pos); + if (count == 0) { + return find_basicblock_zeros(local_rank0, s_block); + } + if (count < 32) { + return kBlocksPerSuperBlock * s_block + pos + count - 1; + } + pos += 32; + } + pos = 96; + auto count = lower_bound_delta_32x16( + &basic_block_rank[kBlocksPerSuperBlock * s_block + pos], local_rank0, + delta_basic, kBasicBlockSize * pos); + if (count == 0) { + return find_basicblock_zeros(local_rank0, s_block); + } + return kBlocksPerSuperBlock * s_block + pos + count - 1; + } + + public: +#ifdef PIXIE_DIAGNOSTICS + struct DiagnosticsBytes { + size_t source_bitvector_bytes = 0; + size_t super_block_rank_bytes = 0; + size_t basic_block_rank_bytes = 0; + size_t select1_samples_bytes = 0; + size_t select0_samples_bytes = 0; + size_t total_bytes = 0; + }; + + /** + * @brief Returns the number of bytes used by each internal component. + */ + DiagnosticsBytes diagnostics_bytes() const { + DiagnosticsBytes result; + result.source_bitvector_bytes = (num_bits_ + 7) / 8; + result.super_block_rank_bytes = super_block_rank_.AsConstBytes().size(); + result.basic_block_rank_bytes = basic_block_rank_.AsConstBytes().size(); + result.select1_samples_bytes = select1_samples_.AsConstBytes().size(); + result.select0_samples_bytes = select0_samples_.AsConstBytes().size(); + result.total_bytes = + result.super_block_rank_bytes + result.basic_block_rank_bytes + + result.select1_samples_bytes + result.select0_samples_bytes; + return result; + } + + /** + * @brief Log memory usage of internal components. + */ + void memory_report() const { + const auto diagnostics = diagnostics_bytes(); + const double source_bytes = + static_cast(diagnostics.source_bitvector_bytes); + const auto log_bytes = [&](std::string_view label, size_t bytes) { + const double percentage = + source_bytes > 0.0 ? 
100.0 * static_cast(bytes) / source_bytes + : 0.0; + spdlog::info("BitVector {}: {} bytes ({:.2f}% of source)", label, bytes, + percentage); + }; + log_bytes("source_bitvector", diagnostics.source_bitvector_bytes); + log_bytes("super_block_rank", diagnostics.super_block_rank_bytes); + log_bytes("basic_block_rank", diagnostics.basic_block_rank_bytes); + log_bytes("select1_samples", diagnostics.select1_samples_bytes); + log_bytes("select0_samples", diagnostics.select0_samples_bytes); + log_bytes("total", diagnostics.total_bytes); + } +#endif + /** + * @brief Construct from an external array of 64-bit words. + * @param + * bit_vector Backing data, not owned. + * @param num_bits Number of valid + * bits in the vector. + */ + explicit BitVector(std::span bit_vector, size_t num_bits) + : num_bits_(std::min(num_bits, bit_vector.size() * kWordSize)), + padded_size_(((num_bits_ + kWordSize - 1) / kWordSize) * kWordSize), + bits_(bit_vector) { + build_rank(); + build_select(); + } + + /** + * @brief Returns the number of valid bits. + */ + size_t size() const { return num_bits_; } + + /** + * @brief Returns the bit at the given position. + * @param pos Bit + * index in [0, size()). + */ + int operator[](size_t pos) const { + size_t word_idx = pos / kWordSize; + size_t bit_off = pos % kWordSize; + + return (bits_[word_idx] >> bit_off) & 1; + } + + /** + * @brief Rank of 1s up to position pos (exclusive). + * @param pos Bit + * index in [0, size()]. + * @return Number of 1s in [0, pos). 
+ */ + uint64_t rank(size_t pos) const { + if (pos >= bits_.size() * kWordSize) [[unlikely]] { + return max_rank_; + } + + auto super_block_rank = super_block_rank_.AsConst64BitInts(); + auto basic_block_rank = basic_block_rank_.AsConst16BitInts(); + + uint64_t b_block = pos / kBasicBlockSize; + uint64_t s_block = pos / kSuperBlockSize; + // Precomputed rank + uint64_t result = super_block_rank[s_block] + basic_block_rank[b_block]; + // Basic block tail + result += rank_512(&bits_[b_block * kWordsPerBlock], + pos - (b_block * kBasicBlockSize)); + return result; + } + + /** + * @brief Rank of 0s up to position pos (exclusive). + * @param pos Bit + * index in [0, size()]. + * @return Number of 0s in [0, pos). + */ + uint64_t rank0(size_t pos) const { + if (pos >= bits_.size() * kWordSize) [[unlikely]] { + return bits_.size() * kWordSize - max_rank_; + } + return pos - rank(pos); + } + + /** + * @brief Select the position of the rank-th 1-bit (1-indexed). + * + * @param rank 1-based rank of the 1-bit to select. + * @return Bit index, or + * size() if rank is out of range. + */ + uint64_t select(size_t rank) const { + if (rank > max_rank_) [[unlikely]] { + return num_bits_; + } + if (rank == 0) [[unlikely]] { + return 0; + } + auto super_block_rank = super_block_rank_.AsConst64BitInts(); + auto basic_block_rank = basic_block_rank_.AsConst16BitInts(); + + uint64_t s_block = find_superblock(rank); + rank -= super_block_rank[s_block]; + auto pos = find_basicblock_is(rank, s_block); + rank -= basic_block_rank[pos]; + pos *= kWordsPerBlock; + + // Final search + if (pos + kWordsPerBlock - 1 < kWordsPerBlock) [[unlikely]] { + size_t ones = std::popcount(bits_[pos]); + while (pos < bits_.size() && ones < rank) { + rank -= ones; + ones = std::popcount(bits_[++pos]); + } + return kWordSize * pos + select_64(bits_[pos], rank - 1); + } + return kWordSize * pos + select_512(&bits_[pos], rank - 1); + } + + /** + * @brief Select the position of the rank0-th 0-bit (1-indexed). 
+ * + * @param rank0 1-based rank of the 0-bit to select. + * @return Bit index, + * or size() if rank0 is out of range. + */ + uint64_t select0(size_t rank0) const { + if (rank0 > num_bits_ - max_rank_) [[unlikely]] { + return num_bits_; + } + if (rank0 == 0) [[unlikely]] { + return 0; + } + auto super_block_rank = super_block_rank_.AsConst64BitInts(); + auto basic_block_rank = basic_block_rank_.AsConst16BitInts(); + + uint64_t s_block = find_superblock_zeros(rank0); + rank0 -= kSuperBlockSize * s_block - super_block_rank[s_block]; + auto pos = find_basicblock_is_zeros(rank0, s_block); + auto pos_in_super_block = pos & (kBlocksPerSuperBlock - 1); + rank0 -= kBasicBlockSize * pos_in_super_block - basic_block_rank[pos]; + pos *= kWordsPerBlock; + + // Final search + if (pos + kWordsPerBlock - 1 < kWordsPerBlock) [[unlikely]] { + size_t zeros = std::popcount(~bits_[pos]); + while (pos < bits_.size() && zeros < rank0) { + rank0 -= zeros; + zeros = std::popcount(~bits_[++pos]); + } + return kWordSize * pos + select_64(~bits_[pos], rank0 - 1); + } + return kWordSize * pos + select0_512(&bits_[pos], rank0 - 1); + } + + /** + * @brief Convert to a binary string (debug helper). + */ + std::string to_string() const { + std::string result; + result.reserve(num_bits_); + + for (size_t i = 0; i < num_bits_; i++) { + result.push_back(operator[](i) ? '1' : '0'); + } + + return result; + } +}; + +/** + * @brief Interleaved, owning bit vector with rank and select. + * + * + * @details + * This variant interleaves data with local rank metadata to reduce + * cache + * misses for rank queries. It copies input bits into an interleaved + * layout. + * + * Based on: + * "SPIDER: Improved Succinct Rank and Select + * Performance" + * Matthew D. Laws, Jocelyn Bliven, Kit Conklin, Elyes Laalai, + * Samuel McCauley, + * Zach S. 
Sturdevant + */ +class BitVectorInterleaved { + private: + constexpr static size_t kWordSize = 64; + constexpr static size_t kSuperBlockRankIntSize = 64; + constexpr static size_t kBasicBlockRankIntSize = 16; + /** + * 496 bits data + 16 bit local rank + */ + constexpr static size_t kBasicBlockSize = 496; + /** + * 63488 = 496 * 128, so position of superblock can be obtained + * from the position of basic block by dividing on 128 or + * right shift on 7 bits which is cheaper then performing another + * division. + */ + constexpr static size_t kSuperBlockSize = 63488; + constexpr static size_t kBlocksPerSuperBlock = 128; + constexpr static size_t kWordsPerBlock = 8; + + const size_t num_bits_; + std::vector bits_interleaved; + std::vector super_block_rank_; + + class BitReader { + size_t iterator_64_ = 0; + size_t offset_size_ = 0; + size_t offset_bits_ = 0; + std::span bits_; + + public: + BitReader(std::span bits) : bits_(bits) {} + uint64_t ReadBits64(size_t num_bits) { + if (num_bits > 64) { + num_bits = 64; + } + uint64_t result = offset_bits_ & first_bits_mask(num_bits); + if (offset_size_ >= num_bits) { + offset_bits_ >>= num_bits; + offset_size_ -= num_bits; + return result; + } + uint64_t next = iterator_64_ < bits_.size() ? bits_[iterator_64_++] : 0; + result ^= (next & first_bits_mask(num_bits - offset_size_)) + << offset_size_; + offset_bits_ = (num_bits - offset_size_ == 64) + ? 0 + : next >> (num_bits - offset_size_); + offset_size_ = 64 - (num_bits - offset_size_); + return result; + } + }; + + public: + /** + * @brief Construct from an external array of 64-bit words. + * @param + * bit_vector Backing data to copy and interleave. + * @param num_bits Number + * of valid bits in the vector. + */ + explicit BitVectorInterleaved(std::span bit_vector, + size_t num_bits) + : num_bits_(std::min(num_bits, bit_vector.size() * kWordSize)) { + build_rank_interleaved(bit_vector, num_bits); + } + + /** + * @brief Mask with the lowest num bits set. 
+ */ + static inline uint64_t first_bits_mask(size_t num) { + return num >= 64 ? UINT64_MAX : ((1llu << num) - 1); + } + + /** + * @brief Returns the number of valid bits. + */ + size_t size() const { return num_bits_; } + + /** + * @brief Returns the bit at the given position. + * @param pos Bit + * index in [0, size()). + */ + int operator[](size_t pos) const { + size_t block_id = pos / kBasicBlockSize; + size_t block_bit = pos - block_id * kBasicBlockSize; + size_t word_id = block_id * kWordsPerBlock + block_bit / kWordSize; + size_t word_bit = block_bit % kWordSize; + kWordSize; + + return (bits_interleaved[word_id] >> word_bit) & 1; + } + + /** + * @brief Build the interleaved layout and rank index. + * @param bits + * Source bit vector as 64-bit words. + * @param num_bits Number of valid + * bits in the source. + */ + void build_rank_interleaved(std::span bits, size_t num_bits) { + size_t num_superblocks = 1 + (num_bits_ - 1) / kSuperBlockSize; + super_block_rank_.resize(num_superblocks); + size_t num_basicblocks = 1 + (num_bits_ - 1) / kBasicBlockSize; + bits_interleaved.resize(num_basicblocks * (512 / kWordSize)); + + uint64_t super_block_sum = 0; + uint16_t basic_block_sum = 0; + auto bit_reader = BitReader(bits); + + for (size_t i = 0; i * kBasicBlockSize < num_bits; ++i) { + if (i % (kSuperBlockSize / kBasicBlockSize) == 0) { + super_block_sum += basic_block_sum; + super_block_rank_[i / (kSuperBlockSize / kBasicBlockSize)] = + super_block_sum; + basic_block_sum = 0; + } + bits_interleaved[i * (kWordsPerBlock) + 7] = + static_cast(basic_block_sum) << 48; + + for (size_t j = 0; j < 7 && kWordSize * (i + j) < num_bits; ++j) { + bits_interleaved[i * (kWordsPerBlock) + j] = + bit_reader.ReadBits64(std::min( + 64ull, num_bits - i * kBasicBlockSize + j * kWordSize)); + basic_block_sum += + std::popcount(bits_interleaved[i * (kWordsPerBlock) + j]); + } + if ((i + 7) * kWordSize < num_bits) { + auto v = bit_reader.ReadBits64(std::min( + 48ull, num_bits - (i * 
kBasicBlockSize + 7 * kWordSize))); + bits_interleaved[i * (kWordsPerBlock) + 7] ^= v; + basic_block_sum += std::popcount(v); + } + } + } + + /** + * @brief Rank of 1s up to position pos (exclusive). + * @param pos Bit + * index in [0, size()]. + * @return Number of 1s in [0, pos). + */ + uint64_t rank(size_t pos) const { + // Multiplication/devisions + uint64_t b_block = pos / kBasicBlockSize; + uint64_t s_block = b_block / kBlocksPerSuperBlock; + uint64_t b_block_pos = b_block * kWordsPerBlock; + // Super block rank + uint64_t result = super_block_rank_[s_block]; + /** + * Ok, so here's quite the important factor to load 512-bit region + * at &bits_interleaved[b_block_pos], we store local rank as 16 last + * bits of it. Prefetch should guarantee but seems like there is no + * need for it. + */ + // __builtin_prefetch(&bits_interleaved[b_block_pos]); + result += rank_512(&bits_interleaved[b_block_pos], + pos - (b_block * kBasicBlockSize)); + result += bits_interleaved[b_block_pos + 7] >> 48; + return result; + } + + /** + * @brief Convert to a binary string (debug helper). + */ + std::string to_string() const { + std::string result; + result.reserve(num_bits_); + + for (size_t i = 0; i < num_bits_; i++) { + result.push_back(operator[](i) ? '1' : '0'); + } + + return result; + } +}; + +} // namespace pixie diff --git a/scripts/draw_bp_representation.py b/scripts/draw_bp_representation.py new file mode 100644 index 0000000..ff22e0c --- /dev/null +++ b/scripts/draw_bp_representation.py @@ -0,0 +1,625 @@ +#!/usr/bin/env python3 +"""Draw standalone succinct tree encodings with optional token highlighting. + +Supported modes: +- BP: balanced parentheses in DFS order, one parenthesis pair per node. +- LOUDS: level-order unary degree sequence, accepted as bits (1/0) or + equivalent parentheses (() form where '(' = 1 and ')' = 0). +- DFUDS: depth-first unary degree sequence over parentheses with a single leading + sentinel '('. 
Each node with d children contributes d opening parentheses + followed by one closing parenthesis. The shown interval for each node spans + its entire d+1-token block, anchored at the block's first token. + +Examples: + uv run --no-project --with pydot python scripts/draw_bp_representation.py \ + "((()()())(()()))" --mode bp --output report/bp_representation.png + uv run --no-project --with pydot python scripts/draw_bp_representation.py \ + "((()((()(())))))" --mode louds --output report/louds_representation.svg + uv run --no-project --with pydot python scripts/draw_bp_representation.py \ + "((()((())))(()))" --mode dfuds --output report/dfuds_representation.svg +""" + +from __future__ import annotations + +import argparse +from collections import deque +from dataclasses import dataclass, field +from pathlib import Path + +import pydot + + +@dataclass +class TreeNode: + node_id: int + parent_id: int | None + depth: int + children: list[int] = field(default_factory=list) + + +@dataclass +class EncodedTree: + mode: str + sequence: str + nodes: list[TreeNode] + node_token_positions: dict[int, int] + interval_targets: dict[int, int] = field(default_factory=dict) + + +def _normalize_parens(text: str) -> str: + filtered = "".join(ch for ch in text if not ch.isspace()) + invalid = sorted({ch for ch in filtered if ch not in "()"}) + if invalid: + chars = ", ".join(repr(ch) for ch in invalid) + raise ValueError(f"sequence contains invalid parenthesis characters: {chars}") + if not filtered: + raise ValueError("sequence is empty") + return filtered + + +def _normalize_louds(text: str) -> tuple[str, str]: + filtered = "".join(ch for ch in text if not ch.isspace()) + if not filtered: + raise ValueError("sequence is empty") + + if all(ch in "01" for ch in filtered): + return filtered, filtered + if all(ch in "()" for ch in filtered): + bits = "".join("1" if ch == "(" else "0" for ch in filtered) + return bits, filtered + + invalid = sorted({ch for ch in filtered if ch not in 
"01()"}) + if invalid: + chars = ", ".join(repr(ch) for ch in invalid) + raise ValueError(f"LOUDS sequence contains invalid characters: {chars}") + raise ValueError("LOUDS sequence must use either only '0/1' or only '(' and ')' tokens") + + +def _validate_highlight_index(length: int, highlight_index: int | None) -> None: + if highlight_index is None: + return + if highlight_index < 0 or highlight_index >= length: + raise ValueError( + f"highlight index {highlight_index} out of range [0, {length - 1}]" + ) + + +def _make_node(parent_id: int | None, depth: int, nodes: list[TreeNode]) -> int: + node_id = len(nodes) + nodes.append(TreeNode(node_id=node_id, parent_id=parent_id, depth=depth)) + if parent_id is not None: + nodes[parent_id].children.append(node_id) + return node_id + + +def parse_bp_encoded(bp: str) -> EncodedTree: + sequence = _normalize_parens(bp) + stack: list[tuple[int, int]] = [] + nodes: list[TreeNode] = [] + node_token_positions: dict[int, int] = {} + interval_targets: dict[int, int] = {} + + for pos, ch in enumerate(sequence): + if ch == "(": + parent_id = stack[-1][0] if stack else None + depth = len(stack) + node_id = _make_node(parent_id, depth, nodes) + node_token_positions[node_id] = pos + stack.append((node_id, pos)) + else: + if not stack: + raise ValueError(f"Unmatched ')' at position {pos}") + node_id, _ = stack.pop() + interval_targets[node_id] = pos + + if stack: + _, pos = stack[-1] + raise ValueError(f"Unmatched '(' at position {pos}") + + return EncodedTree( + mode="bp", + sequence=sequence, + nodes=nodes, + node_token_positions=node_token_positions, + interval_targets=interval_targets, + ) + + +def parse_louds_encoded(louds: str) -> EncodedTree: + """Parse a LOUDS sequence. + + Convention: the sequence starts with a single sentinel '(' followed by the + node blocks in BFS order. Each node v with d children contributes a block of + d+1 tokens: d opening parentheses followed by one closing parenthesis. A leaf + contributes just ')'. 
The sentinel is NOT part of any node's block. + + Accepts either parentheses ('(' = 1, ')' = 0) or a bit string ('1'/'0'); + in both cases the display sequence preserves the input form. + + The interval shown for node v spans its entire block [block_start, block_end]. + + Example for tree 0->[1,2], 1->[3,4,5], 2->[6,7], leaves 3-7: + ((()((()(())))) + ^ sentinel + ^^ ^ node0: block [1,3] + ^^^ ^ node1: block [4,7] + """ + bit_sequence, display_sequence = _normalize_louds(louds) + n = len(bit_sequence) + if n < 1 or bit_sequence[0] != "1": + raise ValueError("LOUDS sequence must start with a sentinel '1' (or '(' in paren form)") + + nodes: list[TreeNode] = [] + node_token_positions: dict[int, int] = {} + interval_targets: dict[int, int] = {} + + # Blocks appear in BFS order; use a FIFO. + pos = 1 # skip sentinel at position 0 + root_id = _make_node(parent_id=None, depth=0, nodes=nodes) + nodes_queue: deque[int] = deque([root_id]) + + while nodes_queue: + node_id = nodes_queue.popleft() + block_start = pos + + # Read d '1' tokens — each creates one child + while pos < n and bit_sequence[pos] == "1": + child_id = _make_node( + parent_id=node_id, + depth=nodes[node_id].depth + 1, + nodes=nodes, + ) + nodes_queue.append(child_id) + pos += 1 + + # Expect the closing '0' + if pos >= n or bit_sequence[pos] != "0": + raise ValueError( + f"LOUDS node {node_id} block starting at {block_start} " + f"is missing a terminating '0' at position {pos}" + ) + block_end = pos + pos += 1 + + # Shift by +1 for display because we insert a sentinel ')' at index 1 + node_token_positions[node_id] = block_start + 1 + interval_targets[node_id] = block_end + 1 + + if pos != n: + raise ValueError(f"LOUDS sequence has trailing data starting at position {pos}") + + # Build display sequence with an explicit sentinel ')' inserted after the opening '(' + # so the sentinel appears as a '()' pair rather than a bare '('. 
+ sentinel_ch = display_sequence[0] # '(' or '1' + sentinel_close = ")" if sentinel_ch == "(" else "0" + augmented_display = sentinel_ch + sentinel_close + display_sequence[1:] + + return EncodedTree( + mode="louds", + sequence=augmented_display, + nodes=nodes, + node_token_positions=node_token_positions, + interval_targets=interval_targets, + ) + + +def parse_dfuds_encoded(dfuds: str) -> EncodedTree: + """Parse a DFUDS sequence. + + Convention: the sequence starts with a single sentinel '(' followed by the + node blocks in DFS pre-order. Each node v with d children contributes a block + of d+1 tokens: d opening parentheses followed by one closing parenthesis. A + leaf contributes just ')'. The sentinel is NOT part of any node's block. + + The interval shown for node v spans its entire block [block_start, block_end], + where block_start is the first token of the block and block_end is the closing ')'. + For a leaf this is a single-token interval. + + Example for tree 0->[1,5], 1->[2,3,4], 5->[6,7], leaves 2,3,4,6,7: + ((()((())))(())) + ^ sentinel + ^^ ^ node0: block [1,3] + ^^^ ^ node1: block [4,7] + ^ node2: block [8,8] ... + """ + sequence = _normalize_parens(dfuds) + n = len(sequence) + if n < 1 or sequence[0] != "(": + raise ValueError("DFUDS sequence must start with a sentinel '('") + + nodes: list[TreeNode] = [] + node_token_positions: dict[int, int] = {} + interval_targets: dict[int, int] = {} + + # Blocks appear in DFS pre-order. NOTE(review): children are queued FIFO, so each block is + # paired with the next node in BFS order; parent_id/depth look wrong when DFS and BFS orders differ - confirm. 
+ pos = 1 # skip sentinel at position 0 + root_id = _make_node(parent_id=None, depth=0, nodes=nodes) + nodes_queue: deque[int] = deque([root_id]) + + while nodes_queue: + node_id = nodes_queue.popleft() + block_start = pos + + # Read d opening parens — each creates one child + while pos < n and sequence[pos] == "(": + child_id = _make_node( + parent_id=node_id, + depth=nodes[node_id].depth + 1, + nodes=nodes, + ) + nodes_queue.append(child_id) + pos += 1 + + # Expect the closing ')' + if pos >= n or sequence[pos] != ")": + raise ValueError( + f"DFUDS node {node_id} block starting at {block_start} " + f"is missing a terminating ')' at position {pos}" + ) + block_end = pos + pos += 1 + + node_token_positions[node_id] = block_start + interval_targets[node_id] = block_end + + if pos != n: + raise ValueError( + f"DFUDS sequence has trailing data starting at position {pos}" + ) + + return EncodedTree( + mode="dfuds", + sequence=sequence, + nodes=nodes, + node_token_positions=node_token_positions, + interval_targets=interval_targets, + ) + + +def parse_encoded_tree(sequence: str, mode: str) -> EncodedTree: + if mode == "bp": + return parse_bp_encoded(sequence) + if mode == "louds": + return parse_louds_encoded(sequence) + if mode == "dfuds": + return parse_dfuds_encoded(sequence) + raise ValueError(f"Unsupported mode: {mode}") + + +def _mode_label(mode: str) -> str: + return mode.upper() + + +def _mode_token_font(mode: str) -> str: + return "Times New Roman" if mode != "louds" else "Helvetica" + + +def _add_token_row( + graph: pydot.Dot, + sequence: str, + mode: str, + with_label: bool, + highlight_index: int | None, +) -> list[str]: + token_subgraph = pydot.Subgraph(f"{mode}_tokens", rank="sink") + token_names: list[str] = [] + fontname = _mode_token_font(mode) + + for i, ch in enumerate(sequence): + is_highlighted = i == highlight_index + token_name = f"token_{i}" + token_names.append(token_name) + token_subgraph.add_node( + pydot.Node( + token_name, + 
shape="plaintext", + label=ch, + fontname=(f"{fontname} Bold" if is_highlighted else fontname), + fontsize="34" if is_highlighted else "30", + fontcolor="#c62828" if is_highlighted else "black", + group=f"g{i}", + ) + ) + + if with_label: + token_subgraph.add_node( + pydot.Node( + "encoding_label", + shape="plaintext", + label=_mode_label(mode), + fontname="Times New Roman", + fontsize="26", + ) + ) + + graph.add_subgraph(token_subgraph) + + if with_label and token_names: + graph.add_edge( + pydot.Edge( + "encoding_label", + token_names[0], + style="invis", + weight="100", + ) + ) + + for i in range(len(token_names) - 1): + graph.add_edge( + pydot.Edge( + token_names[i], + token_names[i + 1], + style="invis", + weight="100", + ) + ) + + return token_names + + +def _make_graph_base(splines: str = "ortho") -> pydot.Dot: + return pydot.Dot( + graph_type="digraph", + rankdir="TB", + splines=splines, + nodesep="0.12", + ranksep="0.32", + dpi="220", + ) + + +def _add_single_level_intervals( + graph: pydot.Dot, + encoded: EncodedTree, + highlight_index: int | None, + node_shape: str, +) -> None: + ids_subgraph = pydot.Subgraph(f"{encoded.mode}_ids", rank="same") + + for node in encoded.nodes: + start_pos = encoded.node_token_positions[node.node_id] + end_pos = encoded.interval_targets.get(node.node_id) + touches_highlight = highlight_index in {start_pos, end_pos} + id_name = f"id_{node.node_id}" + + ids_subgraph.add_node( + pydot.Node( + id_name, + shape=node_shape, + fixedsize="true", + width="0.34", + height="0.34", + label=str(node.node_id), + fontname="Times New Roman", + fontsize="16", + penwidth="2.2" if touches_highlight else "1.2", + color="#c62828" if touches_highlight else "black", + fontcolor="#c62828" if touches_highlight else "black", + group=f"g{start_pos}", + ) + ) + + graph.add_edge( + pydot.Edge( + id_name, + f"token_{start_pos}", + style="invis", + weight="130", + ) + ) + + if end_pos is None: + continue + + graph.add_edge( + pydot.Edge( + id_name, + 
f"token_{end_pos}", + arrowhead="normal", + arrowsize="0.45", + penwidth="1.1", + color="#c62828" if touches_highlight else "#444444", + constraint="false", + ) + ) + + graph.add_subgraph(ids_subgraph) + + +def make_bp_graph( + encoded: EncodedTree, + with_label: bool, + highlight_index: int | None = None, +) -> pydot.Dot: + _validate_highlight_index(len(encoded.sequence), highlight_index) + graph = _make_graph_base(splines="ortho") + _add_token_row(graph, encoded.sequence, encoded.mode, with_label, highlight_index) + _add_single_level_intervals( + graph, + encoded=encoded, + highlight_index=highlight_index, + node_shape="circle", + ) + + return graph + + +def make_louds_graph( + encoded: EncodedTree, + with_label: bool, + highlight_index: int | None = None, +) -> pydot.Dot: + _validate_highlight_index(len(encoded.sequence), highlight_index) + graph = _make_graph_base(splines="ortho") + _add_token_row(graph, encoded.sequence, encoded.mode, with_label, highlight_index) + _add_single_level_intervals( + graph, + encoded=encoded, + highlight_index=highlight_index, + node_shape="box", + ) + + return graph + + +def make_dfuds_graph( + encoded: EncodedTree, + with_label: bool, + highlight_index: int | None = None, +) -> pydot.Dot: + _validate_highlight_index(len(encoded.sequence), highlight_index) + graph = _make_graph_base(splines="ortho") + _add_token_row(graph, encoded.sequence, encoded.mode, with_label, highlight_index) + _add_single_level_intervals( + graph, + encoded=encoded, + highlight_index=highlight_index, + node_shape="box", + ) + + return graph + + +def make_graph( + encoded: EncodedTree, + with_label: bool, + highlight_index: int | None = None, +) -> pydot.Dot: + if encoded.mode == "bp": + return make_bp_graph(encoded, with_label, highlight_index) + if encoded.mode == "louds": + return make_louds_graph(encoded, with_label, highlight_index) + if encoded.mode == "dfuds": + return make_dfuds_graph(encoded, with_label, highlight_index) + raise 
ValueError(f"Unsupported mode: {encoded.mode}") + + +def write_graph(graph: pydot.Dot, output: Path, fmt: str) -> None: + output.parent.mkdir(parents=True, exist_ok=True) + if fmt == "dot": + graph.write_raw(str(output)) + elif fmt == "png": + graph.write_png(str(output)) + elif fmt == "svg": + graph.write_svg(str(output)) + else: + raise ValueError(f"Unsupported output format: {fmt}") + + +def infer_format(output: Path, explicit_format: str | None) -> str: + if explicit_format: + return explicit_format + ext = output.suffix.lower().lstrip(".") + if ext in {"dot", "png", "svg"}: + return ext + return "png" + + +def write_highlight_sequence( + encoded: EncodedTree, + with_label: bool, + output_dir: Path, + prefix: str, + fmt: str, +) -> None: + output_dir.mkdir(parents=True, exist_ok=True) + digits = max(2, len(str(len(encoded.sequence) - 1))) + + for i in range(len(encoded.sequence)): + graph = make_graph(encoded, with_label, highlight_index=i) + frame_path = output_dir / f"{prefix}_{i:0{digits}d}.{fmt}" + write_graph(graph, frame_path, fmt) + print(f"[saved] {frame_path} ({fmt})") + + print(f"[info] generated_frames={len(encoded.sequence)}") + + +def main() -> int: + parser = argparse.ArgumentParser( + description="Draw a standalone BP, LOUDS, or DFUDS representation." + ) + parser.add_argument( + "sequence", + help=( + "Input sequence interpreted according to --mode. " + "For --mode louds, accepts either 0/1 or equivalent parentheses tokens." + ), + ) + parser.add_argument( + "--mode", + choices=["bp", "louds", "dfuds"], + default="bp", + help="Input encoding mode. Default: %(default)s", + ) + parser.add_argument( + "--output", + type=Path, + default=Path("tree_representation.png"), + help="Output file path. Default: %(default)s", + ) + parser.add_argument( + "--format", + choices=["dot", "png", "svg"], + default=None, + help="Output format. 
If omitted, inferred from --output extension.", + ) + parser.add_argument( + "--no-label", + dest="with_label", + action="store_false", + default=True, + help="Hide the encoding label on the left.", + ) + parser.add_argument( + "--highlight-index", + type=int, + default=None, + help=( + "0-based token index to emphasize in red. " + "Example: --highlight-index 5" + ), + ) + parser.add_argument( + "--sequence-dir", + type=Path, + default=None, + help=( + "Output directory for a full highlight sequence (one frame per token). " + "If set, --output is ignored." + ), + ) + parser.add_argument( + "--sequence-prefix", + default="encoding_step", + help="Filename prefix for sequence frames. Default: %(default)s", + ) + args = parser.parse_args() + + encoded = parse_encoded_tree(args.sequence, args.mode) + + if args.sequence_dir is not None: + fmt = args.format or "png" + write_highlight_sequence( + encoded=encoded, + with_label=args.with_label, + output_dir=args.sequence_dir, + prefix=args.sequence_prefix, + fmt=fmt, + ) + print(f"[info] mode={encoded.mode} length={len(encoded.sequence)} nodes={len(encoded.nodes)}") + return 0 + + graph = make_graph(encoded, args.with_label, args.highlight_index) + fmt = infer_format(args.output, args.format) + write_graph(graph, args.output, fmt) + + print(f"[saved] {args.output} ({fmt})") + if args.highlight_index is not None: + print(f"[info] highlighted_index={args.highlight_index}") + print(f"[info] mode={encoded.mode} length={len(encoded.sequence)} nodes={len(encoded.nodes)}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/draw_bp_tree.py b/scripts/draw_bp_tree.py new file mode 100644 index 0000000..7559cf4 --- /dev/null +++ b/scripts/draw_bp_tree.py @@ -0,0 +1,368 @@ +#!/usr/bin/env python3 +"""Draw a rooted tree from a succinct tree encoding using pydot. + +Supported modes (--mode): +- bp : balanced parentheses in DFS order, one pair per node. 
+- louds : level-order unary degree sequence with a leading sentinel '('. + Each node contributes d opening parentheses + ')'. BFS order. +- dfuds : depth-first unary degree sequence with a leading sentinel '('. + Each node contributes d opening parentheses + ')'. DFS order. + +The --sequence-dir flag produces one frame per node, highlighting that node. +The --highlight-node flag highlights a single node by 0-based id. + +Examples: + uv run --no-project --with pydot python scripts/draw_bp_tree.py \\ + "((()()())(()()))" --mode bp --output bp_tree.svg --format svg + uv run --no-project --with pydot python scripts/draw_bp_tree.py \\ + "((()((()(())))))" --mode louds --sequence-dir out/louds --sequence-prefix tree + uv run --no-project --with pydot python scripts/draw_bp_tree.py \\ + "((()((())))(()))" --mode dfuds --output dfuds_tree.png +""" + +from __future__ import annotations + +import argparse +from collections import deque +from dataclasses import dataclass, field +from pathlib import Path + +import pydot + + +@dataclass +class Node: + node_id: int + parent_id: int | None + depth: int + children: list[int] = field(default_factory=list) + + +# --------------------------------------------------------------------------- +# Normalisation helpers +# --------------------------------------------------------------------------- + +def _normalize_parens(text: str) -> str: + filtered = "".join(ch for ch in text if not ch.isspace()) + invalid = sorted({ch for ch in filtered if ch not in "()"}) + if invalid: + chars = ", ".join(repr(ch) for ch in invalid) + raise ValueError(f"sequence contains invalid characters: {chars}") + if not filtered: + raise ValueError("sequence is empty") + return filtered + + +def _normalize_louds(text: str) -> tuple[str, str]: + """Return (bit_sequence, display_sequence). + + Accepts '0'/'1' bits or '('/')' parentheses (where '(' = 1, ')' = 0). 
+ """ + filtered = "".join(ch for ch in text if not ch.isspace()) + if not filtered: + raise ValueError("sequence is empty") + if all(ch in "01" for ch in filtered): + return filtered, filtered + if all(ch in "()" for ch in filtered): + bits = "".join("1" if ch == "(" else "0" for ch in filtered) + return bits, filtered + invalid = sorted({ch for ch in filtered if ch not in "01()"}) + if invalid: + chars = ", ".join(repr(ch) for ch in invalid) + raise ValueError(f"LOUDS sequence contains invalid characters: {chars}") + raise ValueError("LOUDS sequence must use either only '0/1' or only '(/' and ')' tokens") + + +def _make_node(parent_id: int | None, depth: int, nodes: list[Node]) -> int: + node_id = len(nodes) + nodes.append(Node(node_id=node_id, parent_id=parent_id, depth=depth)) + if parent_id is not None: + nodes[parent_id].children.append(node_id) + return node_id + + +# --------------------------------------------------------------------------- +# Parsers +# --------------------------------------------------------------------------- + +def parse_bp(sequence: str) -> list[Node]: + """Parse a BP string into a node list.""" + seq = _normalize_parens(sequence) + nodes: list[Node] = [] + stack: list[int] = [] + + for pos, ch in enumerate(seq): + if ch == "(": + parent_id = stack[-1] if stack else None + node_id = _make_node(parent_id, len(stack), nodes) + stack.append(node_id) + else: + if not stack: + raise ValueError(f"Unmatched ')' at position {pos}") + stack.pop() + + if stack: + raise ValueError(f"Unmatched '(' — {len(stack)} unclosed node(s)") + return nodes + + +def parse_louds(sequence: str) -> list[Node]: + """Parse a LOUDS parenthesis sequence with a leading sentinel '('.""" + bit_seq, _ = _normalize_louds(sequence) + n = len(bit_seq) + if n < 1 or bit_seq[0] != "1": + raise ValueError("LOUDS sequence must start with a sentinel '1' (or '(' in paren form)") + + nodes: list[Node] = [] + pos = 1 # skip sentinel + root_id = _make_node(None, 0, nodes) + queue: 
deque[int] = deque([root_id]) + + while queue: + node_id = queue.popleft() + while pos < n and bit_seq[pos] == "1": + child_id = _make_node(node_id, nodes[node_id].depth + 1, nodes) + queue.append(child_id) + pos += 1 + if pos >= n or bit_seq[pos] != "0": + raise ValueError(f"LOUDS node {node_id} missing terminating '0' at position {pos}") + pos += 1 # consume '0' + + if pos != n: + raise ValueError(f"LOUDS sequence has trailing data starting at position {pos}") + return nodes + + +def parse_dfuds(sequence: str) -> list[Node]: + """Parse a DFUDS parenthesis sequence with a leading sentinel '('. + + NOTE(review): blocks are handed to nodes in FIFO (BFS) order below, while a + DFUDS sequence stores them in DFS pre-order; when the two orders differ the + children appear to end up under the wrong parent - confirm. + """ + seq = _normalize_parens(sequence) + n = len(seq) + if n < 1 or seq[0] != "(": + raise ValueError("DFUDS sequence must start with a sentinel '('") + + nodes: list[Node] = [] + pos = 1 # skip sentinel + root_id = _make_node(None, 0, nodes) + queue: deque[int] = deque([root_id]) + + while queue: + node_id = queue.popleft() + while pos < n and seq[pos] == "(": + child_id = _make_node(node_id, nodes[node_id].depth + 1, nodes) + queue.append(child_id) + pos += 1 + if pos >= n or seq[pos] != ")": + raise ValueError(f"DFUDS node {node_id} missing terminating ')' at position {pos}") + pos += 1 # consume ')' + + if pos != n: + raise ValueError(f"DFUDS sequence has trailing data starting at position {pos}") + return nodes + + +def parse_nodes(sequence: str, mode: str) -> list[Node]: + if mode == "bp": + return parse_bp(sequence) + if mode == "louds": + return parse_louds(sequence) + if mode == "dfuds": + return parse_dfuds(sequence) + raise ValueError(f"Unsupported mode: {mode}") + + +# --------------------------------------------------------------------------- +# Graph building +# --------------------------------------------------------------------------- + +def _validate_highlight_node(nodes: list[Node], highlight_node: int | None) -> None: + if highlight_node is None: + return + if highlight_node < 0 or highlight_node >= len(nodes): + raise ValueError( + f"highlight node 
{highlight_node} out of range [0, {len(nodes) - 1}]" + ) + + +def make_graph( + nodes: list[Node], + label_mode: str, + highlight_node: int | None = None, +) -> pydot.Dot: + _validate_highlight_node(nodes, highlight_node) + + graph = pydot.Dot( + graph_type="digraph", + rankdir="TB", + dpi="220", + nodesep="0.3", + ranksep="0.45", + splines="true", + outputorder="edgesfirst", + ) + graph.set_node_defaults( + shape="circle", + fontname="Helvetica", + fontsize="20", + width="0.7", + height="0.7", + penwidth="2.0", + margin="0.06,0.04", + ) + graph.set_edge_defaults(color="#444444", penwidth="1.8", arrowsize="0.6", arrowhead="none") + + for node in nodes: + if label_mode == "id": + label = str(node.node_id) + elif label_mode == "depth": + label = f"{node.node_id}\\nd={node.depth}" + else: + label = "" + + node_style: dict[str, str] = {} + if highlight_node == node.node_id: + node_style = { + "color": "#c62828", + "fontcolor": "#c62828", + "fillcolor": "#fde7e7", + "style": "filled,bold", + "penwidth": "3.2", + } + + graph.add_node(pydot.Node(str(node.node_id), label=label, **node_style)) + + for node in nodes: + if node.parent_id is not None: + graph.add_edge(pydot.Edge(str(node.parent_id), str(node.node_id))) + + return graph + + +# --------------------------------------------------------------------------- +# Output helpers +# --------------------------------------------------------------------------- + +def write_graph(graph: pydot.Dot, output: Path, fmt: str) -> None: + output.parent.mkdir(parents=True, exist_ok=True) + if fmt == "dot": + graph.write_raw(str(output)) + elif fmt == "png": + graph.write_png(str(output)) + elif fmt == "svg": + graph.write_svg(str(output)) + else: + raise ValueError(f"Unsupported output format: {fmt}") + + +def infer_format_from_output(output: Path, explicit_format: str | None) -> str: + if explicit_format is not None: + return explicit_format + ext = output.suffix.lower().lstrip(".") + if ext in {"dot", "png", "svg"}: + return ext + 
return "png" + + +def write_highlight_sequence( + nodes: list[Node], + label_mode: str, + output_dir: Path, + prefix: str, + fmt: str, +) -> None: + output_dir.mkdir(parents=True, exist_ok=True) + digits = max(2, len(str(len(nodes) - 1))) + + for node in nodes: + graph = make_graph(nodes, label_mode, highlight_node=node.node_id) + output = output_dir / f"{prefix}_{node.node_id:0{digits}d}.{fmt}" + write_graph(graph, output, fmt) + print(f"[saved] {output} ({fmt})") + + print(f"[info] generated_frames={len(nodes)}") + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + +def main() -> int: + parser = argparse.ArgumentParser( + description="Draw a rooted tree from a succinct tree encoding." + ) + parser.add_argument( + "sequence", + help="Input sequence interpreted according to --mode.", + ) + parser.add_argument( + "--mode", + choices=["bp", "louds", "dfuds"], + default="bp", + help="Input encoding mode. Default: %(default)s", + ) + parser.add_argument( + "--output", + type=Path, + default=Path("tree.png"), + help="Output file path. Default: %(default)s", + ) + parser.add_argument( + "--format", + choices=["dot", "png", "svg"], + default=None, + help="Output format. If omitted, inferred from --output extension.", + ) + parser.add_argument( + "--label-mode", + choices=["none", "id", "depth"], + default="id", + help="Node labels: none, numeric id, or id+depth. Default: %(default)s", + ) + parser.add_argument( + "--highlight-node", + type=int, + default=None, + help="0-based node id to emphasise in red.", + ) + parser.add_argument( + "--sequence-dir", + type=Path, + default=None, + help=( + "Output directory for one-frame-per-node highlighting sequence. " + "If set, --output is ignored." + ), + ) + parser.add_argument( + "--sequence-prefix", + default="tree", + help="Filename prefix for sequence frames. 
Default: %(default)s", + ) + args = parser.parse_args() + + nodes = parse_nodes(args.sequence, args.mode) + + if args.sequence_dir is not None: + fmt = args.format or "png" + write_highlight_sequence( + nodes=nodes, + label_mode=args.label_mode, + output_dir=args.sequence_dir, + prefix=args.sequence_prefix, + fmt=fmt, + ) + print(f"[info] mode={args.mode} nodes={len(nodes)}") + return 0 + + graph = make_graph(nodes, args.label_mode, args.highlight_node) + fmt = infer_format_from_output(args.output, args.format) + write_graph(graph, args.output, fmt) + + print(f"[saved] {args.output} ({fmt})") + if args.highlight_node is not None: + print(f"[info] highlighted_node={args.highlight_node}") + print(f"[info] mode={args.mode} nodes={len(nodes)}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/kilo-postprocess.ts b/scripts/kilo-postprocess.ts new file mode 100644 index 0000000..3957884 --- /dev/null +++ b/scripts/kilo-postprocess.ts @@ -0,0 +1,291 @@ +import fs from "node:fs"; + +type JsonObject = Record; + +const MAX_TOOL_OUTPUT_LINES = 12; +const MAX_TOOL_LINES_IN_FINAL_OUTPUT = 80; + +function stripAnsi(input: string): string { + return input.replace(/\x1B\[[0-9;]*[A-Za-z]/g, ""); +} + +function asObject(value: unknown): JsonObject | null { + if (typeof value === "object" && value !== null) { + return value as JsonObject; + } + return null; +} + +function asText(value: unknown): string { + return typeof value === "string" ? 
value : ""; +} + +function asDisplayText(value: unknown): string { + if (typeof value === "string") { + return value; + } + + if (typeof value === "number" || typeof value === "boolean") { + return String(value); + } + + if (value === null || value === undefined) { + return ""; + } + + try { + return JSON.stringify(value); + } catch { + return ""; + } +} + +function normalizeLines(text: string): string[] { + return stripAnsi(text) + .replace(/\r/g, "") + .split("\n") + .map((line) => line.trimEnd()) + .filter((line) => line.length > 0); +} + +function isToolEvent(event: JsonObject): boolean { + const type = asText(event.type); + if (type === "tool_use" || type === "tool_start" || type === "tool_end" || type === "tool_result") { + return true; + } + + return type === "say" && asText(event.say) === "tool"; +} + +function eventTimestamp(event: JsonObject): string { + const raw = event.timestamp; + if (typeof raw !== "number" || !Number.isFinite(raw)) { + return "????????????"; + } + return new Date(raw).toISOString().slice(11, 23); +} + +function summarizeEvent(event: JsonObject): string { + const type = asText(event.type); + + if (type === "welcome") { + return "session started"; + } + + if (type === "error") { + const message = asText(event.message) || asText(event.content); + return message ? `error: ${message}` : "error"; + } + + if (type === "text") { + const part = asObject(event.part); + const text = part ? asText(part.text).trim() : ""; + if (text) { + return `assistant: ${text}`; + } + return "assistant text"; + } + + if (type === "tool_start") { + const tool = asText(event.tool) || asText(event.name) || "unknown"; + return `tool start: ${tool}`; + } + + if (type === "tool_use") { + const part = asObject(event.part); + const tool = (part && asText(part.tool)) || asText(event.tool) || asText(event.name) || "unknown"; + const state = part ? asObject(part.state) : null; + const status = state ? asText(state.status) : ""; + const input = state ? 
asObject(state.input) : null; + const description = input ? asText(input.description).trim() : ""; + const command = input ? asText(input.command).trim() : ""; + const label = description || command; + + if (status && label) { + return `tool ${tool} ${status}: ${label}`; + } + if (status) { + return `tool ${tool} ${status}`; + } + if (label) { + return `tool ${tool}: ${label}`; + } + return `tool ${tool}`; + } + + if (type === "tool_end" || type === "tool_result") { + const tool = asText(event.tool) || asText(event.name) || "unknown"; + return `tool end: ${tool}`; + } + + if (type === "say") { + const say = asText(event.say); + const content = stripAnsi(asText(event.content)).trim(); + + if (say === "text") { + return content ? `assistant: ${content}` : "assistant text"; + } + + if (say === "reasoning") { + const partial = event.partial === true; + if (partial) { + return "reasoning (partial)"; + } + return content ? `reasoning: ${content}` : "reasoning"; + } + + if (say === "tool") { + return content ? `tool: ${content}` : "tool"; + } + + if (say === "api_req_started") { + const metadata = asObject(event.metadata); + const protocol = metadata ? asText(metadata.apiProtocol) : ""; + return protocol ? `api request started: ${protocol}` : "api request started"; + } + + return say ? `say/${say}` : "say"; + } + + if (type) { + return `event: ${type}`; + } + + return "event"; +} + +function extractToolOutputLines(event: JsonObject): string[] { + if (asText(event.type) !== "tool_use") { + return []; + } + + const part = asObject(event.part); + if (!part) { + return []; + } + + const state = asObject(part.state); + if (!state) { + return []; + } + + const metadata = asObject(state.metadata); + const rawOutput = asDisplayText(state.output) || (metadata ? 
asDisplayText(metadata.output) : ""); + if (!rawOutput.trim()) { + return []; + } + + const outputLines = normalizeLines(rawOutput); + if (outputLines.length === 0) { + return []; + } + + const limited = outputLines.slice(0, MAX_TOOL_OUTPUT_LINES); + const result = ["tool output:"]; + for (const line of limited) { + result.push(` ${line}`); + } + + if (outputLines.length > MAX_TOOL_OUTPUT_LINES) { + result.push(` ... (${outputLines.length - MAX_TOOL_OUTPUT_LINES} more lines)`); + } + + return result; +} + +function extractText(event: JsonObject): string { + const type = asText(event.type); + + if (type === "text") { + const part = asObject(event.part); + return part ? asText(part.text).trim() : ""; + } + + if (type === "say" && asText(event.say) === "text") { + return stripAnsi(asText(event.content)).trim(); + } + + return ""; +} + +function main(): void { + const eventsPath = process.argv[2] ?? "kilo-events.log"; + const outputPath = process.argv[3] ?? "kilo-output.log"; + const readablePath = process.argv[4] ?? 
"kilo-readable.log"; + + const raw = fs.readFileSync(eventsPath, "utf8"); + const lines = raw.split(/\r?\n/); + + const textParts: string[] = []; + const readable: string[] = []; + const toolTrace: string[] = []; + + for (const line of lines) { + if (!line.trim()) { + continue; + } + + try { + const parsed = JSON.parse(line) as unknown; + const event = asObject(parsed); + if (!event) { + continue; + } + + const text = extractText(event); + if (text) { + textParts.push(text); + } + + const timestamp = eventTimestamp(event); + const summary = summarizeEvent(event); + const summaryLine = `[${timestamp}] ${summary}`; + readable.push(summaryLine); + + const toolOutputLines = extractToolOutputLines(event); + for (const line of toolOutputLines) { + readable.push(`[${timestamp}] ${line}`); + } + + if (isToolEvent(event)) { + toolTrace.push(summaryLine); + for (const line of toolOutputLines) { + toolTrace.push(`[${timestamp}] ${line}`); + } + } + } catch { + const stripped = stripAnsi(line).trim(); + if (stripped) { + readable.push(stripped); + } + } + } + + let finalOutput = ""; + if (textParts.length > 0) { + finalOutput = textParts[textParts.length - 1]; + } else { + const fallback = stripAnsi(raw) + .split(/\r?\n/) + .map((item) => item.trim()) + .filter((item) => item.length > 0); + finalOutput = fallback.length > 0 ? fallback[fallback.length - 1] : ""; + } + + if (toolTrace.length > 0) { + const limitedToolTrace = toolTrace.slice(0, MAX_TOOL_LINES_IN_FINAL_OUTPUT); + if (toolTrace.length > MAX_TOOL_LINES_IN_FINAL_OUTPUT) { + limitedToolTrace.push( + `... (${toolTrace.length - MAX_TOOL_LINES_IN_FINAL_OUTPUT} more tool lines)` + ); + } + + const toolSection = [`Tool calls:`, ...limitedToolTrace].join("\n"); + finalOutput = finalOutput ? 
`${finalOutput}\n\n${toolSection}` : toolSection; + } + + fs.writeFileSync(outputPath, `${finalOutput}\n`, "utf8"); + fs.writeFileSync(readablePath, `${readable.join("\n")}\n`, "utf8"); +} + +main(); diff --git a/src/benchmarks/benchmarks.cpp b/src/benchmarks/benchmarks.cpp index a0d06f2..11acbf6 100644 --- a/src/benchmarks/benchmarks.cpp +++ b/src/benchmarks/benchmarks.cpp @@ -11,7 +11,14 @@ #include #include +/** + * In the literature, bitvector length is usually measured up to 2^35; for + * simplicity we measure up to 2^30, where the time is mainly dominated by main + * memory accesses. + */ + constexpr size_t kBenchmarkRandomCopies = 8; +constexpr double warmup_time = 0.5; #ifdef _WIN32 #include @@ -89,10 +96,6 @@ static void BM_RankNonInterleaved(benchmark::State& state) { auto bits_as_words = bits.As64BitInts(); PrepareRandomBits50pFill(bits_as_words); pixie::BitVector bv(bits_as_words, n); -#ifdef PIXIE_DIAGNOSTICS - bv.memory_report(); -#endif - std::mt19937_64 rng(42); for (auto _ : state) { uint64_t pos = rng() % n; @@ -292,90 +295,103 @@ static void BM_SelectZeroNonInterleaved87p5PercentFill( BENCHMARK(BM_RankInterleaved) ->ArgNames({"n"}) ->RangeMultiplier(4) - ->Range(8, 1ull << 34) + ->Range(1ull << 10, 1ull << 30) ->Iterations(10000000) + ->MinWarmUpTime(warmup_time) ->Repetitions(10); BENCHMARK(BM_RankNonInterleaved) ->ArgNames({"n"}) ->RangeMultiplier(4) - ->Range(8, 1ull << 34) + ->Range(1ull << 10, 1ull << 30) ->Iterations(10000000) + ->MinWarmUpTime(warmup_time) ->Repetitions(10); BENCHMARK(BM_RankZeroNonInterleaved) ->ArgNames({"n"}) ->RangeMultiplier(4) - ->Range(8, 1ull << 34) + ->Range(1ull << 10, 1ull << 30) ->Iterations(10000000) + ->MinWarmUpTime(warmup_time) ->Repetitions(10); BENCHMARK(BM_SelectNonInterleaved) ->ArgNames({"n"}) ->RangeMultiplier(4) - ->Range(8, 1ull << 34) + ->Range(1ull << 10, 1ull << 30) ->Iterations(5000000) + ->MinWarmUpTime(warmup_time) ->Repetitions(10); BENCHMARK(BM_SelectZeroNonInterleaved) ->ArgNames({"n"})
->RangeMultiplier(4) - ->Range(8, 1ull << 34) + ->Range(1ull << 10, 1ull << 30) ->Iterations(5000000) + ->MinWarmUpTime(warmup_time) ->Repetitions(10); BENCHMARK(BM_RankNonInterleaved12p5PercentFill) ->ArgNames({"n"}) ->RangeMultiplier(4) - ->Range(8, 1ull << 34) + ->Range(1ull << 10, 1ull << 30) ->Iterations(10000000) + ->MinWarmUpTime(warmup_time) ->Repetitions(10); BENCHMARK(BM_RankZeroNonInterleaved12p5PercentFill) ->ArgNames({"n"}) ->RangeMultiplier(4) - ->Range(8, 1ull << 34) + ->Range(1ull << 10, 1ull << 30) ->Iterations(10000000) + ->MinWarmUpTime(warmup_time) ->Repetitions(10); BENCHMARK(BM_SelectNonInterleaved12p5PercentFill) ->ArgNames({"n"}) ->RangeMultiplier(4) - ->Range(8, 1ull << 34) + ->Range(1ull << 10, 1ull << 30) ->Iterations(5000000) + ->MinWarmUpTime(warmup_time) ->Repetitions(10); BENCHMARK(BM_SelectZeroNonInterleaved12p5PercentFill) ->ArgNames({"n"}) ->RangeMultiplier(4) - ->Range(8, 1ull << 34) + ->Range(1ull << 10, 1ull << 30) ->Iterations(5000000) + ->MinWarmUpTime(warmup_time) ->Repetitions(10); BENCHMARK(BM_RankNonInterleaved87p5PercentFill) ->ArgNames({"n"}) ->RangeMultiplier(4) - ->Range(8, 1ull << 34) + ->Range(1ull << 10, 1ull << 30) ->Iterations(10000000) + ->MinWarmUpTime(warmup_time) ->Repetitions(10); BENCHMARK(BM_RankZeroNonInterleaved87p5PercentFill) ->ArgNames({"n"}) ->RangeMultiplier(4) - ->Range(8, 1ull << 34) + ->Range(1ull << 10, 1ull << 30) ->Iterations(10000000) + ->MinWarmUpTime(warmup_time) ->Repetitions(10); BENCHMARK(BM_SelectNonInterleaved87p5PercentFill) ->ArgNames({"n"}) ->RangeMultiplier(4) - ->Range(8, 1ull << 34) + ->Range(1ull << 10, 1ull << 30) ->Iterations(5000000) + ->MinWarmUpTime(warmup_time) ->Repetitions(10); BENCHMARK(BM_SelectZeroNonInterleaved87p5PercentFill) ->ArgNames({"n"}) ->RangeMultiplier(4) - ->Range(8, 1ull << 34) + ->Range(1ull << 10, 1ull << 30) ->Iterations(5000000) + ->MinWarmUpTime(warmup_time) ->Repetitions(10);