From 8f1a0f8dc280daf76b75272797c73db0f1fe59cc Mon Sep 17 00:00:00 2001 From: Muneerali199 Date: Fri, 6 Mar 2026 15:15:04 +0530 Subject: [PATCH 1/7] feat: add streaming parser and incremental Merkle tree --- openverifiablellm/benchmark.py | 306 ++++++++++++++++++++ openverifiablellm/incremental_merkle.py | 282 ++++++++++++++++++ openverifiablellm/streaming_utils.py | 175 ++++++++++++ tests/test_merkle.py | 364 ++++++++++++++++++++++++ 4 files changed, 1127 insertions(+) create mode 100644 openverifiablellm/benchmark.py create mode 100644 openverifiablellm/incremental_merkle.py create mode 100644 openverifiablellm/streaming_utils.py create mode 100644 tests/test_merkle.py diff --git a/openverifiablellm/benchmark.py b/openverifiablellm/benchmark.py new file mode 100644 index 0000000..ef4d6da --- /dev/null +++ b/openverifiablellm/benchmark.py @@ -0,0 +1,306 @@ +""" +benchmark.py +============ +Before-vs-After resource consumption analysis. + +Compares the **legacy approach** (load everything into memory, hash at once) +against the **new streaming approach** (O(1) memory generator + Incremental +Merkle Tree) on the same Wikipedia dump. + +Metrics captured +---------------- +* Wall-clock execution time (seconds) via ``time.perf_counter``. +* Peak heap allocation (MB) via ``tracemalloc``. + +Output +------ +Prints a GitHub-flavored Markdown table to stdout so the result can be +copy-pasted directly into a Pull Request description. 
+ +Usage +----- + # Download first (≈350 MB): + # wget https://dumps.wikimedia.org/simplewiki/20260201/simplewiki-20260201-pages-articles-multistream.xml.bz2 + + python -m openverifiablellm.benchmark simplewiki-20260201-pages-articles-multistream.xml.bz2 + + # Or via the scripts/ helper: + python scripts/benchmark.py +""" + +import argparse +import bz2 +import gc +import hashlib +import logging +import sys +import time +import tracemalloc +import xml.etree.ElementTree as ET +from pathlib import Path +from typing import List, Optional, Tuple + +from openverifiablellm.incremental_merkle import IncrementalMerkleTree +from openverifiablellm.streaming_utils import stream_text_from_xml +from openverifiablellm.utils import clean_wikitext + +logging.basicConfig(level=logging.WARNING) +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Type alias for a benchmark result row +# --------------------------------------------------------------------------- +BenchmarkResult = Tuple[ + str, # approach label + float, # wall-clock seconds + float, # peak RAM in MB + str, # root hash (hex) +] + + +# --------------------------------------------------------------------------- +# Helper: convert tracemalloc peak bytes → MB +# --------------------------------------------------------------------------- +def _bytes_to_mb(n_bytes: int) -> float: + return n_bytes / (1024 * 1024) + + +# =========================================================================== +# APPROACH 1 — "Old Way" (in-memory) +# =========================================================================== + +def _run_old_way(file_path: Path) -> BenchmarkResult: + """ + Legacy approach: decompress the entire dump, collect ALL article texts + into a Python list, then build a standard batch Merkle tree from the list. + + Memory profile: O(N) — every article text lives in RAM simultaneously. 
+ Time profile : O(N) for loading + O(N log N) for tree construction. + """ + gc.collect() + tracemalloc.start() + t_start = time.perf_counter() + + # ----- Step 1: load all texts into memory ----- + all_texts: List[str] = [] + + with bz2.open(file_path, "rb") as raw: + context = ET.iterparse(raw, events=("end",)) + for _event, elem in context: + if elem.tag.endswith("page"): + text_elem = elem.find(".//{*}text") + if text_elem is None: + text_elem = elem.find(".//text") + if text_elem is not None and text_elem.text: + cleaned = clean_wikitext(text_elem.text) + if cleaned: + all_texts.append(cleaned) + # NOTE: No elem.clear() — intentionally simulating the + # old code that leaks every parsed element into memory. + + # ----- Step 2: build Merkle tree from the in-memory list ----- + # Hash each article text to a leaf + leaves: List[bytes] = [ + hashlib.sha256(t.encode("utf-8")).digest() for t in all_texts + ] + + # Batch construction: classic bottom-up Merkle tree + if not leaves: + root_hex = hashlib.sha256(b"").hexdigest() + else: + current_level = leaves + while len(current_level) > 1: + next_level: List[bytes] = [] + for i in range(0, len(current_level), 2): + left = current_level[i] + right = current_level[i + 1] if i + 1 < len(current_level) else left + next_level.append(hashlib.sha256(left + right).digest()) + current_level = next_level + root_hex = current_level[0].hex() + + t_end = time.perf_counter() + _, peak_bytes = tracemalloc.get_traced_memory() + tracemalloc.stop() + + return ( + "Old Way (in-memory)", + round(t_end - t_start, 3), + round(_bytes_to_mb(peak_bytes), 2), + root_hex, + ) + + +# =========================================================================== +# APPROACH 2 — "New Way" (streaming) +# =========================================================================== + +def _run_new_way(file_path: Path) -> BenchmarkResult: + """ + New streaming approach: yield one article at a time from the generator + and feed it directly into the 
IncrementalMerkleTree. + + Memory profile: O(log N) — only the Merkle frontier is kept in RAM. + Time profile : O(N log N) — but with vastly lower constant factors + because no large list allocation occurs. + """ + gc.collect() + tracemalloc.start() + t_start = time.perf_counter() + + tree = IncrementalMerkleTree() + + for article_text in stream_text_from_xml(str(file_path)): + tree.append_leaf(article_text) + + root_hex: str = tree.get_root_hash() or hashlib.sha256(b"").hexdigest() + + t_end = time.perf_counter() + _, peak_bytes = tracemalloc.get_traced_memory() + tracemalloc.stop() + + return ( + "New Way (streaming)", + round(t_end - t_start, 3), + round(_bytes_to_mb(peak_bytes), 2), + root_hex, + ) + + +# =========================================================================== +# Reporting: GitHub-Flavored Markdown table +# =========================================================================== + +def _render_markdown_table( + old: BenchmarkResult, + new: BenchmarkResult, + file_name: str, +) -> str: + """ + Render a GFM Markdown table suitable for direct use in a GitHub PR. + + Calculates speed-up and RAM reduction ratios and appends a legend. 
+ """ + label_old, time_old, ram_old, hash_old = old + label_new, time_new, ram_new, hash_new = new + + # Guard against division by zero on extremely fast runs + time_ratio = (time_old / time_new) if time_new > 0 else float("inf") + ram_ratio = (ram_old / ram_new) if ram_new > 0 else float("inf") + + hashes_match = (hash_old == hash_new) + hash_verdict = "YES — identical root hash" if hashes_match else "NO — MISMATCH (investigate!)" + + lines = [ + "", + "## Benchmark Results", + "", + f"> **Input file:** `{file_name}` ", + f"> **Root hashes match:** {hash_verdict}", + "", + "| Metric | Old Way (in-memory) | New Way (streaming) | Improvement |", + "|-------------------------------|--------------------:|--------------------:|--------------------|", + f"| Wall-clock time (s) | `{time_old:>10.3f}` | `{time_new:>10.3f}` | **{time_ratio:,.1f}× faster** |", + f"| Peak RAM usage (MB) | `{ram_old:>10.2f}` | `{ram_new:>10.2f}` | **{ram_ratio:,.1f}× less RAM** |", + f"| Root hash | `{hash_old[:16]}…` | `{hash_new[:16]}…` | {'Match' if hashes_match else 'MISMATCH'} |", + "", + "### Notes", + "- *Peak RAM* is measured with `tracemalloc` (Python heap only; does not include", + " OS-level buffers or the bzip2 decompressor's internal state).", + "- *Wall-clock time* is measured with `time.perf_counter` on a single run.", + " For publication-quality numbers repeat 3× and report median ± std-dev.", + "- The Old Way intentionally omits `elem.clear()` to reproduce the OOM behaviour.", + "- The New Way uses `stream_text_from_xml` + `IncrementalMerkleTree` from this PR.", + "", + ] + return "\n".join(lines) + + +def _render_terminal_table( + old: BenchmarkResult, + new: BenchmarkResult, + file_name: str, +) -> str: + """Plain-text box table for terminal output (complements the Markdown table).""" + label_old, time_old, ram_old, hash_old = old + label_new, time_new, ram_new, hash_new = new + time_ratio = (time_old / time_new) if time_new > 0 else float("inf") + ram_ratio = (ram_old / 
ram_new) if ram_new > 0 else float("inf") + hashes_match = hash_old == hash_new + + w = 90 + sep = "─" * w + + def row(col1: str, col2: str, col3: str, col4: str = "") -> str: + return f"│ {col1:<28} │ {col2:>18} │ {col3:>18} │ {col4:<14} │" + + lines = [ + f"┌{sep}┐", + f"│{'BEFORE vs. AFTER — ' + file_name:^{w}}│", + f"├{sep}┤", + row("Metric", "Old Way", "New Way", "Improvement"), + f"├{sep}┤", + row("Wall-clock time (s)", f"{time_old:.3f} s", f"{time_new:.3f} s", f"{time_ratio:,.1f}x faster"), + row("Peak RAM (MB)", f"{ram_old:.2f} MB", f"{ram_new:.2f} MB", f"{ram_ratio:,.1f}x less"), + row("Root hashes match", "", "", "YES" if hashes_match else "NO — MISMATCH"), + f"└{sep}┘", + ] + return "\n".join(lines) + + +# =========================================================================== +# Main entry point +# =========================================================================== + +def run_benchmark(file_path: str) -> None: + """ + Execute both benchmarks sequentially and print the results. + + Parameters + ---------- + file_path: + Path to the Wikipedia ``.xml.bz2`` dump. + """ + path = Path(file_path) + if not path.exists(): + print(f"[ERROR] File not found: {path}", file=sys.stderr) + sys.exit(1) + + print(f"\nRunning OLD WAY benchmark on: {path.name}") + print(" (This may take several minutes and use significant RAM …)\n") + old_result = _run_old_way(path) + print(f" Done. Time={old_result[1]:.3f}s Peak RAM={old_result[2]:.2f} MB") + + print(f"\nRunning NEW WAY benchmark on: {path.name}") + print(" (Streaming — should use constant, minimal RAM …)\n") + new_result = _run_new_way(path) + print(f" Done. 
Time={new_result[1]:.3f}s Peak RAM={new_result[2]:.2f} MB") + + # Print terminal table + print() + print(_render_terminal_table(old_result, new_result, path.name)) + + # Print GitHub-Flavored Markdown table + md = _render_markdown_table(old_result, new_result, path.name) + print("\n" + "=" * 60) + print("Copy the block below into your GitHub Pull Request:") + print("=" * 60) + print(md) + + +def main(argv: Optional[List[str]] = None) -> None: + parser = argparse.ArgumentParser( + description=( + "Before-vs-After benchmark: in-memory vs streaming Merkle tree " + "on a Wikipedia XML.bz2 dump." + ) + ) + parser.add_argument( + "file_path", + help="Path to the Wikipedia XML.bz2 dump file (e.g. simplewiki-20260201-....xml.bz2)", + ) + args = parser.parse_args(argv) + run_benchmark(args.file_path) + + +if __name__ == "__main__": + main() diff --git a/openverifiablellm/incremental_merkle.py b/openverifiablellm/incremental_merkle.py new file mode 100644 index 0000000..c181974 --- /dev/null +++ b/openverifiablellm/incremental_merkle.py @@ -0,0 +1,282 @@ +""" +incremental_merkle.py +===================== +An append-only, O(log N) space Merkle Tree using the "Merkle Frontier". + +Background: The Merkle Frontier Algorithm +------------------------------------------ +A classical Merkle tree requires ALL leaf hashes to be stored so the +tree can be reconstructed level by level. For N leaves that is O(N) +memory — completely unacceptable when N is in the millions. + +The Merkle Frontier (also called a "Merkle accumulator" in certificate +transparency literature) solves this by exploiting one structural +property of binary Merkle trees: + + The only nodes you ever need to recompute the root are the + **rightmost unpaired nodes at each depth** — the "frontier". + +Concretely: after inserting N leaves the frontier contains at most +⌈log₂ N⌉ hashes — one per set bit in the binary representation of N. 
+ +Example with N = 5 (binary: 101) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Leaf indices: L0 L1 L2 L3 L4 + +Full tree structure: + ROOT + / \\ + H(0-3) H(4) + / \\ + H(0-1) H(2-3) + / \\ / \\ + L0 L1 L2 L3 + +After appending L0…L4 the frontier is: + depth 2 → H(0-3) (covers 4 leaves — a complete sub-tree) + depth 0 → L4 (a single unpaired leaf at the right edge) + +To get the root we just fold right: root = H( H(0-3) || L4 ) + +Space used by frontier: 2 hashes (⌈log₂ 5⌉ = 3, but only 2 bits are +set in 5 = 0b101, so there are only 2 frontier nodes). + +Root computation +---------------- +``get_root_hash()`` folds the frontier from the **lowest depth** to the +**highest depth**, combining nodes pair-wise: + + current = frontier[lowest_depth] + for d in range(lowest_depth + 1, max_depth + 1): + if frontier[d] exists: + current = sha256( frontier[d] || current ) + else: + current = sha256( current || current ) # duplicate (odd node) + +The final ``current`` value is the root hash. + +Usage +----- + from openverifiablellm.incremental_merkle import IncrementalMerkleTree + + tree = IncrementalMerkleTree() + for text in stream_text_from_xml("dump.xml.bz2"): + tree.append_leaf(text) + + root = tree.get_root_hash() + print(f"Merkle root: {root}") +""" + +import hashlib +from typing import Dict, Optional + + +# --------------------------------------------------------------------------- +# Module-level helpers +# --------------------------------------------------------------------------- + +def _sha256_bytes(data: bytes) -> bytes: + """Return the raw 32-byte SHA-256 digest of *data*.""" + return hashlib.sha256(data).digest() + + +def _combine(left: bytes, right: bytes) -> bytes: + """ + Combine two 32-byte node hashes into a parent hash. + + The canonical MediaWiki / Bitcoin-style combination: + parent = SHA256( left_bytes || right_bytes ) + + Parameters + ---------- + left, right: + Raw 32-byte digest values (NOT hex strings). 
+ + Returns + ------- + bytes + 32-byte digest of the concatenation. + """ + return _sha256_bytes(left + right) + + +# --------------------------------------------------------------------------- +# IncrementalMerkleTree +# --------------------------------------------------------------------------- + +class IncrementalMerkleTree: + """ + An append-only Merkle tree with O(log N) time and space per operation. + + State + ----- + _frontier : Dict[int, bytes] + Maps ``depth`` → 32-byte hash. + Depth 0 = leaf level. Higher depth = closer to the root. + The frontier holds at most one node per depth; a node at depth d + represents a **complete** subtree of height d (covering 2**d leaves). + + _leaf_count : int + Total number of leaves appended so far. Used only for logging/info. + + Invariant + --------- + After appending N leaves, ``_frontier`` contains exactly the nodes + corresponding to the set bits in the binary representation of N. + For example N = 6 = 0b110 → frontier has nodes at depth 1 and depth 2. + """ + + def __init__(self) -> None: + # depth → 32-byte node hash. Only "complete" subtree roots are stored. + self._frontier: Dict[int, bytes] = {} + self._leaf_count: int = 0 + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def append_leaf(self, text_chunk: str) -> None: + """ + Hash *text_chunk* and insert it as the next leaf in the tree. + + Algorithm (O(log N) time, O(log N) space) + ------------------------------------------ + 1. Compute the SHA-256 hash of the UTF-8-encoded text. + 2. Start at depth 0 (leaf level) with the new hash as ``node``. + 3. While there is already a node stored at the current depth: + a. Combine the stored node (left) with our new node (right). + b. Remove the stored node from the frontier (the slot is now + "consumed" into a higher level). + c. Move one level up and continue with the combined hash. + 4. 
Store the final unconsumed node at its depth in the frontier. + + This mirrors how a binary counter increments: each carry bit + propagates up until it finds an empty slot. + + Parameters + ---------- + text_chunk: + Arbitrary Unicode string. Encoded to UTF-8 before hashing. + """ + # Step 1: hash the raw text to get a 32-byte leaf digest + new_node: bytes = _sha256_bytes(text_chunk.encode("utf-8")) + + # Step 2-4: propagate carries up the tree, exactly like binary addition + depth: int = 0 + while depth in self._frontier: + # Combine: existing left sibling || new right node → parent + left_sibling: bytes = self._frontier.pop(depth) + new_node = _combine(left_sibling, new_node) + depth += 1 + + # No existing node at this depth — park the new node here + self._frontier[depth] = new_node + self._leaf_count += 1 + + def get_root_hash(self) -> Optional[str]: + """ + Compute and return the current Merkle root hash as a hex string. + + This method is **non-destructive** — the frontier is not modified. + + Algorithm (O(log N)) + -------------------- + The frontier decomposes N leaves into complete power-of-two subtrees, + one per set bit of N. To collapse them into a single root we must + replicate the same "odd-node duplication" rule used by the classic + batch builder: + + When a level has an **odd** number of nodes, the rightmost node + is paired with itself: parent = combine(node, node). + + Concretely, the frontier nodes sit at various depths d₀ < d₁ < … < dₖ + (sorted ascending). We fold them right-to-left (lowest depth first), + promoting each partial subtree to the next depth by self-combining + before merging it with the larger complete subtree on its left: + + accumulator = frontier[d₀] + + for each successive depth dᵢ (i = 1 … k): + # Promote accumulator from its current depth up to dᵢ + # by repeatedly self-combining (mirroring the batch tree's + # odd-node duplication at each intermediate level). 
+ while current_depth < dᵢ: + accumulator = combine(accumulator, accumulator) + current_depth += 1 + + # Merge: the complete subtree at dᵢ is on the LEFT + accumulator = combine(frontier[dᵢ], accumulator) + + Why self-combine? + ----------------- + In the batch tree, after all full pairs are consumed at level L, + any leftover (odd) node is duplicated before ascending. The frontier + encodes exactly those "leftover" nodes. If frontier[d₀] exists and + the next frontier node is at d₁ > d₀+1, the batch tree would have + duplicated the depth-d₀ subtree (d₁-d₀) times to produce a depth-d₁ + right child before combining with the depth-d₁ left sibling. + + Edge cases + ---------- + * Zero leaves → returns ``None``. + * Single leaf → returns SHA-256(leaf text). + * Power-of-two count → frontier has 1 node, returned directly. + + Returns + ------- + str or None + 64-character lowercase hex string, or ``None`` if empty. + """ + if not self._frontier: + return None + + # Sort depths ascending: smallest (rightmost partial) to largest (leftmost complete) + sorted_depths = sorted(self._frontier.keys()) + + # Seed the accumulator with the rightmost (lowest depth) frontier node + accumulator: bytes = self._frontier[sorted_depths[0]] + current_depth: int = sorted_depths[0] + + for target_depth in sorted_depths[1:]: + # ---------------------------------------------------------------- + # Promote accumulator to target_depth by self-combining. + # This mirrors the batch tree's "duplicate the odd node" rule: + # at each intermediate level the partial right-edge subtree is + # paired with itself before ascending one more level. 
+ # ---------------------------------------------------------------- + while current_depth < target_depth: + accumulator = _combine(accumulator, accumulator) + current_depth += 1 + + # ---------------------------------------------------------------- + # Merge: the complete subtree in the frontier at target_depth + # sits to the LEFT of our right-edge accumulator. + # ---------------------------------------------------------------- + accumulator = _combine(self._frontier[target_depth], accumulator) + # After the merge current_depth advances by one more level + current_depth += 1 + + return accumulator.hex() + + # ------------------------------------------------------------------ + # Convenience / introspection + # ------------------------------------------------------------------ + + @property + def leaf_count(self) -> int: + """Total number of leaves that have been appended.""" + return self._leaf_count + + @property + def frontier_depth(self) -> int: + """Current maximum depth of the frontier (0 if empty).""" + return max(self._frontier.keys(), default=0) + + def __repr__(self) -> str: + return ( + f"IncrementalMerkleTree(" + f"leaves={self._leaf_count}, " + f"frontier_nodes={len(self._frontier)}, " + f"max_depth={self.frontier_depth}" + f")" + ) diff --git a/openverifiablellm/streaming_utils.py b/openverifiablellm/streaming_utils.py new file mode 100644 index 0000000..09ec2e3 --- /dev/null +++ b/openverifiablellm/streaming_utils.py @@ -0,0 +1,175 @@ +""" +streaming_utils.py +================== +Memory-efficient, streaming text extractor for Wikipedia XML dumps. + +Key design decisions +-------------------- +* Uses ``xml.etree.ElementTree.iterparse`` so the XML is parsed + **event-by-event** — no full DOM is ever built in RAM. +* After each ```` element is yielded it is immediately cleared + (``elem.clear()``), releasing both the element and all its children. + This keeps heap usage at O(1) regardless of dump size. 
+* Supports both plain ``.xml`` and bz2-compressed ``.xml.bz2`` inputs + by sniffing the first three bytes for the BZh magic header. +* The generator contract: callers receive one cleaned plain-text string + per Wikipedia article. An empty/redirect article yields nothing. + +Usage +----- + from openverifiablellm.streaming_utils import stream_text_from_xml + + for article_text in stream_text_from_xml("simplewiki-....xml.bz2"): + # process one article at a time — constant memory + do_something(article_text) +""" + +import bz2 +import gc +import logging +import xml.etree.ElementTree as ET +from pathlib import Path +from typing import Generator + +from openverifiablellm.utils import clean_wikitext + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Tag name suffixes we care about (we use .endswith() to be namespace-agnostic +# because MediaWiki dumps include a Clark-notation namespace prefix such as +# "{http://www.mediawiki.org/xml/export-0.11/}page"). +# --------------------------------------------------------------------------- +_PAGE_TAG_SUFFIX = "page" +_TEXT_TAG_SUFFIX = "text" + + +def _open_xml_source(file_path: Path): + """ + Return an open binary file-like object for ``file_path``. + + Sniffs the first 3 bytes for the BZh magic header used by bzip2. + Falls back to a plain binary open for uncompressed XML. + + Parameters + ---------- + file_path: + Resolved ``Path`` to the dump file. + + Returns + ------- + A context-manager-compatible binary IO object. + """ + with file_path.open("rb") as probe: + magic = probe.read(3) + + if magic == b"BZh": + logger.debug("Detected bzip2 stream: %s", file_path.name) + return bz2.open(file_path, "rb") + + logger.debug("Detected plain XML stream: %s", file_path.name) + return file_path.open("rb") + + +def stream_text_from_xml( + file_path: str, +) -> Generator[str, None, None]: + """ + Stream cleaned article texts from a Wikipedia XML (or XML.bz2) dump. 
+ + This is a **generator** — it yields exactly one string per Wikipedia + article that contains non-empty wikitext. It never holds more than + a single ```` element tree in memory at any moment. + + Memory complexity : O(1) — independent of dump size. + Time complexity : O(N) — one linear scan of the byte stream. + + Parameters + ---------- + file_path: + Path to a Wikipedia XML dump. Both ``.xml`` and ``.xml.bz2`` + (bzip2-compressed) files are accepted. + + Yields + ------ + str + Cleaned plain-text content of one Wikipedia article. + Articles that are empty after cleaning are silently skipped. + + Raises + ------ + FileNotFoundError + If ``file_path`` does not exist on disk. + xml.etree.ElementTree.ParseError + If the XML stream is structurally malformed. + + Examples + -------- + >>> total = 0 + >>> for text in stream_text_from_xml("simplewiki-20260201.xml.bz2"): + ... total += len(text) + >>> print(f"Streamed {total:,} characters with O(1) memory") + """ + path = Path(file_path) + + if not path.exists(): + raise FileNotFoundError(f"Dump file not found: {path}") + + articles_yielded = 0 + + with _open_xml_source(path) as xml_stream: + # ``iterparse`` fires events as the SAX-like cursor advances. + # We only care about "end" events (element fully parsed). + context = ET.iterparse(xml_stream, events=("end",)) + + for _event, elem in context: + # ---------------------------------------------------------------- + # Tag matching: MediaWiki dumps include a Clark-notation namespace + # prefix, e.g. "{http://www.mediawiki.org/xml/export-0.11/}page". + # Using .endswith() avoids hard-coding any specific namespace URI + # while still being precise enough for our needs. + # ---------------------------------------------------------------- + if not elem.tag.endswith(_PAGE_TAG_SUFFIX): + continue + + # At this point *elem* is the fully-parsed subtree. + # Walk its children to locate the element. 
+ raw_text: str = "" + for child in elem.iter(): + if child.tag.endswith(_TEXT_TAG_SUFFIX): + if child.text: + raw_text = child.text + break + + # ---------------------------------------------------------------- + # CRITICAL memory management step: + # elem.clear() removes all child elements, text content, and + # attributes from this element object, dropping every reference + # that iterparse has accumulated in the parsed subtree so far. + # Without this call, ALL previously seen elements remain + # live in memory for the entire lifetime of the loop — causing + # O(N) memory growth and eventual OOM on large dumps. + # ---------------------------------------------------------------- + elem.clear() + + # Periodically request the cyclic-reference collector so that + # any cross-references within now-cleared subtrees are resolved + # promptly rather than accumulating until the next automatic GC. + # We amortise the GC overhead by triggering only every 1 000 pages. + articles_yielded += 1 + if articles_yielded % 1_000 == 0: + gc.collect() + logger.debug("Streamed %d articles so far …", articles_yielded) + + if not raw_text: + continue + + cleaned = clean_wikitext(raw_text) + if cleaned: + yield cleaned + + logger.info( + "Finished streaming '%s': %d articles yielded.", + path.name, + articles_yielded, + ) diff --git a/tests/test_merkle.py b/tests/test_merkle.py new file mode 100644 index 0000000..9adcc2b --- /dev/null +++ b/tests/test_merkle.py @@ -0,0 +1,364 @@ +""" +test_merkle.py +============== +pytest suite for IncrementalMerkleTree. + +The critical correctness property we verify +-------------------------------------------- +Given the same ordered sequence of N strings, a *batch* Merkle tree +(built by collecting all leaf hashes up-front, then reducing level by +level) and our *incremental* Merkle tree (appending one leaf at a time +via the Merkle Frontier) MUST produce byte-for-byte identical root hashes. 
+ +If that invariant ever breaks, the streaming pipeline cannot be trusted +as a drop-in replacement for the legacy in-memory approach. + +Run with: + pip install -e ".[dev]" + pytest tests/test_merkle.py -v +""" + +import hashlib +from typing import List, Optional + +import pytest + +from openverifiablellm.incremental_merkle import IncrementalMerkleTree + + +# =========================================================================== +# Reference implementation: a classic batch Merkle tree +# =========================================================================== + +def _sha256_bytes(data: bytes) -> bytes: + """Return raw 32-byte SHA-256 digest.""" + return hashlib.sha256(data).digest() + + +def _combine(left: bytes, right: bytes) -> bytes: + """Combine two 32-byte node hashes: parent = SHA-256(left || right).""" + return _sha256_bytes(left + right) + + +def batch_merkle_root(texts: List[str]) -> Optional[str]: + """ + Build a standard, batch Merkle tree from a list of strings and return + the root hash as a hex string. + + This is the canonical reference implementation used to verify the + incremental version. It stores ALL leaf hashes in memory and reduces + them level-by-level exactly as the legacy code does. + + Odd-number-of-nodes rule: when a level has an odd count, the last node + is duplicated so every parent has exactly two children. This matches + the rule used in ``IncrementalMerkleTree.get_root_hash()``. + + Parameters + ---------- + texts : list of str + Ordered list of article texts (or any strings). + + Returns + ------- + str | None + 64-char hex root hash, or ``None`` if *texts* is empty. 
+ """ + if not texts: + return None + + # Leaf level: hash each string + level: List[bytes] = [ + _sha256_bytes(t.encode("utf-8")) for t in texts + ] + + # Reduce level-by-level until only the root remains + while len(level) > 1: + next_level: List[bytes] = [] + for i in range(0, len(level), 2): + left = level[i] + # Duplicate the last node if the level has an odd count + right = level[i + 1] if i + 1 < len(level) else left + next_level.append(_combine(left, right)) + level = next_level + + return level[0].hex() + + +# =========================================================================== +# Fixtures +# =========================================================================== + +@pytest.fixture +def hundred_strings() -> List[str]: + """ + A deterministic list of 100 unique strings. + + Uses the f-string ``"article_{i:03d}"`` pattern so the content is + predictable and reproducible across test runs. + """ + return [f"article_{i:03d}: The quick brown fox jumps over the lazy dog." for i in range(100)] + + +# =========================================================================== +# Core correctness test (the PRIMARY deliverable) +# =========================================================================== + +class TestIncrementalVsBatch: + """ + Verify that IncrementalMerkleTree produces the same root as the + batch reference implementation for the same input sequence. + """ + + def test_root_hash_matches_batch_100_strings( + self, hundred_strings: List[str] + ) -> None: + """ + PRIMARY TEST: IncrementalMerkleTree root must exactly equal the + batch Merkle root for the same 100 strings. + + This is the definitive correctness gate for the streaming pipeline. 
+ """ + # Build batch reference root + expected_root = batch_merkle_root(hundred_strings) + assert expected_root is not None, "batch_merkle_root should not return None for non-empty input" + + # Build incremental root + tree = IncrementalMerkleTree() + for text in hundred_strings: + tree.append_leaf(text) + + actual_root = tree.get_root_hash() + assert actual_root is not None, "IncrementalMerkleTree.get_root_hash() must not return None" + + assert actual_root == expected_root, ( + f"Root hash mismatch!\n" + f" Batch root : {expected_root}\n" + f" Incremental root: {actual_root}\n" + "The streaming pipeline is NOT a safe replacement for the " + "legacy in-memory approach until this test passes." + ) + + def test_root_hash_matches_batch_single_string(self) -> None: + """Single leaf: root must equal SHA-256 of that leaf's text.""" + texts = ["only one article here"] + expected = batch_merkle_root(texts) + + tree = IncrementalMerkleTree() + tree.append_leaf(texts[0]) + + assert tree.get_root_hash() == expected + + def test_root_hash_matches_batch_two_strings(self) -> None: + """Two leaves: tests the first combine() call in both implementations.""" + texts = ["alpha", "beta"] + expected = batch_merkle_root(texts) + + tree = IncrementalMerkleTree() + for t in texts: + tree.append_leaf(t) + + assert tree.get_root_hash() == expected + + def test_root_hash_matches_batch_power_of_two(self) -> None: + """ + Power-of-two leaf count (8 leaves): the tree is perfectly balanced; + the frontier should collapse to a single node (the root itself). 
+ """ + texts = [f"leaf_{i}" for i in range(8)] + expected = batch_merkle_root(texts) + + tree = IncrementalMerkleTree() + for t in texts: + tree.append_leaf(t) + + # For exactly 2^k leaves, the frontier should have exactly 1 node + assert len(tree._frontier) == 1, ( + "For a power-of-two leaf count the frontier should collapse to 1 node" + ) + assert tree.get_root_hash() == expected + + def test_root_hash_matches_batch_odd_leaf_count(self) -> None: + """Odd leaf count (7): tests the 'duplicate last node' codepath.""" + texts = [f"article_{i}" for i in range(7)] + expected = batch_merkle_root(texts) + + tree = IncrementalMerkleTree() + for t in texts: + tree.append_leaf(t) + + assert tree.get_root_hash() == expected + + @pytest.mark.parametrize("n", [1, 2, 3, 4, 5, 6, 7, 8, 15, 16, 17, 31, 32, 33, 64, 100, 128, 255, 256]) + def test_root_hash_matches_batch_parametric(self, n: int) -> None: + """ + Parametric sweep over various leaf counts including edge cases, + powers of two, and powers-of-two ± 1. 
+ """ + texts = [f"string_{i:04d}" for i in range(n)] + expected = batch_merkle_root(texts) + + tree = IncrementalMerkleTree() + for t in texts: + tree.append_leaf(t) + + assert tree.get_root_hash() == expected, ( + f"Root hash mismatch for n={n}" + ) + + +# =========================================================================== +# Empty-tree behaviour +# =========================================================================== + +class TestEmptyTree: + def test_get_root_hash_returns_none_when_empty(self) -> None: + """An empty tree has no root — get_root_hash() must return None.""" + tree = IncrementalMerkleTree() + assert tree.get_root_hash() is None + + def test_leaf_count_zero_when_empty(self) -> None: + tree = IncrementalMerkleTree() + assert tree.leaf_count == 0 + + def test_frontier_empty_when_empty(self) -> None: + tree = IncrementalMerkleTree() + assert tree._frontier == {} + + +# =========================================================================== +# Leaf count tracking +# =========================================================================== + +class TestLeafCount: + def test_leaf_count_increments(self) -> None: + tree = IncrementalMerkleTree() + for i in range(50): + tree.append_leaf(f"leaf_{i}") + assert tree.leaf_count == i + 1 + + def test_leaf_count_matches_hundred(self, hundred_strings: List[str]) -> None: + tree = IncrementalMerkleTree() + for t in hundred_strings: + tree.append_leaf(t) + assert tree.leaf_count == 100 + + +# =========================================================================== +# Frontier size invariant +# =========================================================================== + +class TestFrontierInvariant: + """ + The frontier size equals the number of set bits in the binary + representation of the leaf count (popcount / Hamming weight). 
+ """ + + @pytest.mark.parametrize("n", [1, 2, 3, 4, 5, 7, 8, 9, 15, 16, 100, 128, 255, 256]) + def test_frontier_size_equals_popcount(self, n: int) -> None: + tree = IncrementalMerkleTree() + for i in range(n): + tree.append_leaf(f"x_{i}") + + expected_frontier_nodes = bin(n).count("1") + assert len(tree._frontier) == expected_frontier_nodes, ( + f"After {n} leaves (binary: {bin(n)}), frontier should have " + f"{expected_frontier_nodes} node(s), got {len(tree._frontier)}" + ) + + +# =========================================================================== +# Determinism / reproducibility +# =========================================================================== + +class TestDeterminism: + def test_same_input_same_root(self, hundred_strings: List[str]) -> None: + """Two trees built from the same input must produce identical roots.""" + tree1 = IncrementalMerkleTree() + tree2 = IncrementalMerkleTree() + for t in hundred_strings: + tree1.append_leaf(t) + tree2.append_leaf(t) + assert tree1.get_root_hash() == tree2.get_root_hash() + + def test_different_order_different_root(self) -> None: + """Order matters: reversed input must produce a different root.""" + texts = [f"item_{i}" for i in range(10)] + + tree_fwd = IncrementalMerkleTree() + tree_rev = IncrementalMerkleTree() + for t in texts: + tree_fwd.append_leaf(t) + for t in reversed(texts): + tree_rev.append_leaf(t) + + assert tree_fwd.get_root_hash() != tree_rev.get_root_hash() + + def test_extra_leaf_changes_root(self) -> None: + """Appending one more leaf must change the root hash.""" + texts = [f"article_{i}" for i in range(10)] + + tree_a = IncrementalMerkleTree() + for t in texts: + tree_a.append_leaf(t) + root_a = tree_a.get_root_hash() + + tree_b = IncrementalMerkleTree() + for t in texts: + tree_b.append_leaf(t) + tree_b.append_leaf("one_more_article") + root_b = tree_b.get_root_hash() + + assert root_a != root_b + + def test_get_root_hash_is_non_destructive(self) -> None: + """Calling 
get_root_hash() multiple times must return the same value.""" + tree = IncrementalMerkleTree() + for i in range(20): + tree.append_leaf(f"leaf_{i}") + + roots = {tree.get_root_hash() for _ in range(5)} + assert len(roots) == 1, "get_root_hash() must be idempotent" + + +# =========================================================================== +# Hash format sanity checks +# =========================================================================== + +class TestHashFormat: + def test_root_hash_is_64_char_hex(self) -> None: + """SHA-256 produces 32 bytes → 64 lowercase hex characters.""" + tree = IncrementalMerkleTree() + tree.append_leaf("hello world") + root = tree.get_root_hash() + assert root is not None + assert len(root) == 64 + assert root == root.lower() + # Must be valid hex + int(root, 16) + + def test_single_leaf_root_equals_sha256_of_text(self) -> None: + """ + For a single leaf the root hash must equal SHA-256(text.encode()). + There is no combining step — the leaf hash IS the root. 
+ """ + text = "wikipedia article about Python" + expected = hashlib.sha256(text.encode("utf-8")).hexdigest() + + tree = IncrementalMerkleTree() + tree.append_leaf(text) + + assert tree.get_root_hash() == expected + + +# =========================================================================== +# repr() smoke test +# =========================================================================== + +class TestRepr: + def test_repr_contains_leaf_count(self) -> None: + tree = IncrementalMerkleTree() + for i in range(7): + tree.append_leaf(f"t_{i}") + r = repr(tree) + assert "leaves=7" in r + assert "IncrementalMerkleTree" in r From d644c518b4fd442dc8a1741e0cff560a4ff57048 Mon Sep 17 00:00:00 2001 From: Muneerali199 Date: Mon, 9 Mar 2026 14:10:13 +0530 Subject: [PATCH 2/7] feat: add streaming support to utils.py and scripts/benchmark.py - Add four streaming functions to utils.py: stream_chunks (yields raw byte chunks from a file), stream_sha256 (yields rolling hash pairs), compute_merkle_root_streaming (builds Merkle root from any bytes iterable), and stream_extract_text_from_xml (O(1)-memory generator over Wikipedia XML dumps) - Add scripts/benchmark.py as a thin CLI launcher that delegates to openverifiablellm.benchmark.main, keeping the scripts/ directory as plain entry-points with no duplicated logic --- openverifiablellm/utils.py | 233 ++++++++++++++++++++++++++++++++++++- scripts/benchmark.py | 49 ++++++++ 2 files changed, 281 insertions(+), 1 deletion(-) create mode 100644 scripts/benchmark.py diff --git a/openverifiablellm/utils.py b/openverifiablellm/utils.py index 13ea81b..47f0654 100644 --- a/openverifiablellm/utils.py +++ b/openverifiablellm/utils.py @@ -7,7 +7,7 @@ import logging import json import platform -from typing import Union, Optional, Dict, Any, List, Tuple +from typing import Union, Optional, Dict, Any, List, Tuple, Generator, Iterable logger = logging.getLogger(__name__) MERKLE_CHUNK_SIZE_BYTES = 1024 * 1024 # 1MB @@ -374,6 +374,237 @@ def 
clean_wikitext(text: str) -> str: text = RE_WHITESPACE.sub(" ", text) return text.strip() +# --------------------------------------------------------------------------- +# Streaming / generator-based variants +# --------------------------------------------------------------------------- + +def stream_chunks( + file_path: Union[str, Path], + chunk_size: int = MERKLE_CHUNK_SIZE_BYTES, +) -> Generator[bytes, None, None]: + """ + Yield successive raw byte chunks from *file_path* without loading the + entire file into memory. + + This is the streaming analogue to the chunk-reading loop inside + ``compute_merkle_root``. Callers can process each chunk on-the-fly + (e.g. hash it, write it somewhere) without ever holding more than one + chunk in RAM at a time. + + Parameters + ---------- + file_path: + Path to any binary file (plain or bz2-compressed). + chunk_size: + Number of bytes per chunk. Must be a positive integer. + Defaults to ``MERKLE_CHUNK_SIZE_BYTES`` (1 MB). + + Yields + ------ + bytes + Raw byte chunk of at most *chunk_size* bytes. + The final chunk may be shorter if the file size is not a multiple + of *chunk_size*. + + Raises + ------ + ValueError + If *chunk_size* is not a positive integer. + FileNotFoundError + If *file_path* does not exist. + + Examples + -------- + >>> for chunk in stream_chunks("data/raw/simplewiki.xml.bz2"): + ... process(chunk) + """ + if chunk_size <= 0: + raise ValueError("chunk_size must be a positive integer") + + path = Path(file_path) + if not path.exists(): + raise FileNotFoundError(f"File not found: {path}") + + with path.open("rb") as f: + while True: + chunk = f.read(chunk_size) + if not chunk: + break + yield chunk + + +def stream_sha256( + file_path: Union[str, Path], + chunk_size: int = 8192, +) -> Generator[Tuple[bytes, str], None, None]: + """ + Stream SHA-256 hashes of successive chunks from *file_path*. 
+ + Unlike ``compute_sha256``, which returns a **single** hash over the + entire file, this generator yields ``(chunk_bytes, partial_hex)`` + pairs as each chunk is read. After the generator is exhausted the + last ``partial_hex`` value equals the SHA-256 of the whole file. + + Parameters + ---------- + file_path: + Path to the file to hash. + chunk_size: + Read buffer size in bytes (default 8 192). + + Yields + ------ + (chunk_bytes, running_hex) : Tuple[bytes, str] + *chunk_bytes* — the raw bytes just read. + *running_hex* — SHA-256 hex digest of **all bytes read so far** + (i.e. the rolling hash after absorbing *chunk_bytes*). + + Examples + -------- + >>> *_, (_, final_hash) = stream_sha256("data/raw/dump.xml.bz2") + >>> print(final_hash) # same value as compute_sha256(file_path=...) + """ + if chunk_size <= 0: + raise ValueError("chunk_size must be a positive integer") + + path = Path(file_path) + if not path.exists(): + raise FileNotFoundError(f"File not found: {path}") + + sha256 = hashlib.sha256() + with path.open("rb") as f: + while True: + chunk = f.read(chunk_size) + if not chunk: + break + sha256.update(chunk) + yield chunk, sha256.hexdigest() + + +def compute_merkle_root_streaming( + chunks: Iterable[bytes], +) -> str: + """ + Compute a Merkle root over an **arbitrary iterable of byte chunks** + without requiring a seekable file on disk. + + This is the streaming counterpart to ``compute_merkle_root``. It + accepts any iterable — a ``stream_chunks()`` generator, a network + socket, a list of pre-computed blobs, etc. — so callers are not + restricted to file-backed data. + + The leaf hashes and tree construction follow the exact same algorithm + as ``compute_merkle_root`` to ensure identical roots for identical + content. + + Parameters + ---------- + chunks: + Any iterable that yields ``bytes`` objects. Each object is + treated as one Merkle leaf. + + Returns + ------- + str + 64-character lowercase hex SHA-256 Merkle root. 
+ Returns ``compute_sha256(data=b"")`` for an empty iterable. + + Examples + -------- + >>> root = compute_merkle_root_streaming(stream_chunks("dump.xml.bz2")) + >>> assert root == compute_merkle_root("dump.xml.bz2") + """ + leaves: List[bytes] = [] + + for chunk in chunks: + leaf_hex = compute_sha256(data=chunk) + leaves.append(bytes.fromhex(leaf_hex)) + + if not leaves: + return compute_sha256(data=b"") + + while len(leaves) > 1: + next_level: List[bytes] = [] + for i in range(0, len(leaves), 2): + left = leaves[i] + right = leaves[i + 1] if i + 1 < len(leaves) else left + combined = left + right + parent_hex = compute_sha256(data=combined) + next_level.append(bytes.fromhex(parent_hex)) + leaves = next_level + + return leaves[0].hex() + + +def stream_extract_text_from_xml( + input_path: Union[str, Path], +) -> Generator[str, None, None]: + """ + Stream cleaned article texts from a Wikipedia XML dump without writing + any output file. + + This is the generator-based (streaming) counterpart to + ``extract_text_from_xml``. It yields one cleaned plain-text string + per Wikipedia article, keeping heap usage at O(1) regardless of dump + size. + + Supports both plain ``.xml`` and bzip2-compressed ``.xml.bz2`` inputs + by sniffing the first three bytes for the BZh magic header — exactly + the same auto-detection logic used in ``extract_text_from_xml``. + + Parameters + ---------- + input_path: + Path to a Wikipedia XML dump (compressed or uncompressed). + + Yields + ------ + str + Cleaned plain-text content of one Wikipedia article. + Articles that are empty after cleaning are silently skipped. + + Raises + ------ + FileNotFoundError + If *input_path* does not exist on disk. + ET.ParseError + If the XML stream is structurally malformed. 
+ + Examples + -------- + >>> total_chars = sum(len(t) for t in stream_extract_text_from_xml("dump.xml.bz2")) + >>> print(f"Total characters: {total_chars:,}") + """ + input_path = Path(input_path) + + if not input_path.exists(): + raise FileNotFoundError(f"Dump file not found: {input_path}") + + with open(input_path, "rb") as probe: + is_bz2 = probe.read(3) == b"BZh" + + open_func = bz2.open if is_bz2 else open + + with open_func(input_path, "rb") as f: + context = ET.iterparse(f, events=("end",)) + for _, elem in context: + if elem.tag.endswith("page"): + text_elem = elem.find(".//{*}text") + raw_text: str = "" + if text_elem is not None and text_elem.text: + raw_text = text_elem.text + elem.clear() + + if not raw_text: + continue + + cleaned = clean_wikitext(raw_text) + if cleaned: + yield cleaned + + logger.info("Finished streaming articles from '%s'.", input_path.name) + + if __name__ == "__main__": if len(sys.argv) < 2: print("Usage: python -m openverifiablellm.utils ") diff --git a/scripts/benchmark.py b/scripts/benchmark.py new file mode 100644 index 0000000..49ef4c1 --- /dev/null +++ b/scripts/benchmark.py @@ -0,0 +1,49 @@ +""" +benchmark.py (scripts/) +======================== +CLI entry-point for the before-vs-after benchmark. + +This thin wrapper delegates all logic to +``openverifiablellm.benchmark.main``, keeping the scripts/ directory as +a collection of plain launchers with no duplicated implementation. + +Usage +----- + # Download the dump first (≈350 MB): + python scripts/download_dump.py --wiki simplewiki --date 20260201 + + # Then run the benchmark: + python scripts/benchmark.py simplewiki-20260201-pages-articles-multistream.xml.bz2 + + # Or, equivalently, via the package module: + python -m openverifiablellm.benchmark + +What it measures +---------------- +* **Old Way** (in-memory): decompress everything, collect all article texts + in a list, build a batch Merkle tree — O(N) RAM. 
+* **New Way** (streaming): yield one article at a time with + ``stream_text_from_xml``, feed each into an ``IncrementalMerkleTree`` + — O(log N) RAM. + +The script prints: +1. A terminal box-table with wall-clock time and peak RAM side-by-side. +2. A GitHub-Flavored Markdown table you can paste straight into the PR. +""" + +import sys +from pathlib import Path + +# --------------------------------------------------------------------------- +# Make sure the project root is on sys.path when the script is run directly +# (i.e. ``python scripts/benchmark.py``), even without an editable install. +# --------------------------------------------------------------------------- +_SCRIPTS_DIR = Path(__file__).resolve().parent +_PROJECT_ROOT = _SCRIPTS_DIR.parent +if str(_PROJECT_ROOT) not in sys.path: + sys.path.insert(0, str(_PROJECT_ROOT)) + +from openverifiablellm.benchmark import main # noqa: E402 + +if __name__ == "__main__": + main() From edc24798dd98aeb2ca915c5dab8653a465e3f0e9 Mon Sep 17 00:00:00 2001 From: Muneerali199 Date: Mon, 9 Mar 2026 14:38:08 +0530 Subject: [PATCH 3/7] refactor(utils): replace compute_sha256(data=) internal calls with _sha256_hex helper; remove streaming_utils.py - All internal Merkle tree callers (compute_merkle_root, generate_merkle_proof, verify_merkle_proof) now use the private _sha256_hex() helper instead of the public compute_sha256() API, eliminating Union return-type checker errors. - Removed unused 'cast' from typing imports. - Removed stray comment left over from an earlier edit. - Deleted openverifiablellm/streaming_utils.py: its streaming logic is fully covered by extract_text_from_xml(stream=True) in utils.py, per mentor feedback to update existing functions rather than adding standalone streaming functions. 
--- openverifiablellm/streaming_utils.py | 175 ----------- openverifiablellm/utils.py | 449 +++++++++------------------ 2 files changed, 143 insertions(+), 481 deletions(-) delete mode 100644 openverifiablellm/streaming_utils.py diff --git a/openverifiablellm/streaming_utils.py b/openverifiablellm/streaming_utils.py deleted file mode 100644 index 09ec2e3..0000000 --- a/openverifiablellm/streaming_utils.py +++ /dev/null @@ -1,175 +0,0 @@ -""" -streaming_utils.py -================== -Memory-efficient, streaming text extractor for Wikipedia XML dumps. - -Key design decisions --------------------- -* Uses ``xml.etree.ElementTree.iterparse`` so the XML is parsed - **event-by-event** — no full DOM is ever built in RAM. -* After each ```` element is yielded it is immediately cleared - (``elem.clear()``), releasing both the element and all its children. - This keeps heap usage at O(1) regardless of dump size. -* Supports both plain ``.xml`` and bz2-compressed ``.xml.bz2`` inputs - by sniffing the first three bytes for the BZh magic header. -* The generator contract: callers receive one cleaned plain-text string - per Wikipedia article. An empty/redirect article yields nothing. - -Usage ------ - from openverifiablellm.streaming_utils import stream_text_from_xml - - for article_text in stream_text_from_xml("simplewiki-....xml.bz2"): - # process one article at a time — constant memory - do_something(article_text) -""" - -import bz2 -import gc -import logging -import xml.etree.ElementTree as ET -from pathlib import Path -from typing import Generator - -from openverifiablellm.utils import clean_wikitext - -logger = logging.getLogger(__name__) - -# --------------------------------------------------------------------------- -# Tag name suffixes we care about (we use .endswith() to be namespace-agnostic -# because MediaWiki dumps include a Clark-notation namespace prefix such as -# "{http://www.mediawiki.org/xml/export-0.11/}page"). 
-# --------------------------------------------------------------------------- -_PAGE_TAG_SUFFIX = "page" -_TEXT_TAG_SUFFIX = "text" - - -def _open_xml_source(file_path: Path): - """ - Return an open binary file-like object for ``file_path``. - - Sniffs the first 3 bytes for the BZh magic header used by bzip2. - Falls back to a plain binary open for uncompressed XML. - - Parameters - ---------- - file_path: - Resolved ``Path`` to the dump file. - - Returns - ------- - A context-manager-compatible binary IO object. - """ - with file_path.open("rb") as probe: - magic = probe.read(3) - - if magic == b"BZh": - logger.debug("Detected bzip2 stream: %s", file_path.name) - return bz2.open(file_path, "rb") - - logger.debug("Detected plain XML stream: %s", file_path.name) - return file_path.open("rb") - - -def stream_text_from_xml( - file_path: str, -) -> Generator[str, None, None]: - """ - Stream cleaned article texts from a Wikipedia XML (or XML.bz2) dump. - - This is a **generator** — it yields exactly one string per Wikipedia - article that contains non-empty wikitext. It never holds more than - a single ```` element tree in memory at any moment. - - Memory complexity : O(1) — independent of dump size. - Time complexity : O(N) — one linear scan of the byte stream. - - Parameters - ---------- - file_path: - Path to a Wikipedia XML dump. Both ``.xml`` and ``.xml.bz2`` - (bzip2-compressed) files are accepted. - - Yields - ------ - str - Cleaned plain-text content of one Wikipedia article. - Articles that are empty after cleaning are silently skipped. - - Raises - ------ - FileNotFoundError - If ``file_path`` does not exist on disk. - xml.etree.ElementTree.ParseError - If the XML stream is structurally malformed. - - Examples - -------- - >>> total = 0 - >>> for text in stream_text_from_xml("simplewiki-20260201.xml.bz2"): - ... 
total += len(text) - >>> print(f"Streamed {total:,} characters with O(1) memory") - """ - path = Path(file_path) - - if not path.exists(): - raise FileNotFoundError(f"Dump file not found: {path}") - - articles_yielded = 0 - - with _open_xml_source(path) as xml_stream: - # ``iterparse`` fires events as the SAX-like cursor advances. - # We only care about "end" events (element fully parsed). - context = ET.iterparse(xml_stream, events=("end",)) - - for _event, elem in context: - # ---------------------------------------------------------------- - # Tag matching: MediaWiki dumps include a Clark-notation namespace - # prefix, e.g. "{http://www.mediawiki.org/xml/export-0.11/}page". - # Using .endswith() avoids hard-coding any specific namespace URI - # while still being precise enough for our needs. - # ---------------------------------------------------------------- - if not elem.tag.endswith(_PAGE_TAG_SUFFIX): - continue - - # At this point *elem* is the fully-parsed subtree. - # Walk its children to locate the element. - raw_text: str = "" - for child in elem.iter(): - if child.tag.endswith(_TEXT_TAG_SUFFIX): - if child.text: - raw_text = child.text - break - - # ---------------------------------------------------------------- - # CRITICAL memory management step: - # elem.clear() removes all child elements, text content, and - # attributes from this element object, dropping every reference - # that iterparse has accumulated in the parsed subtree so far. - # Without this call, ALL previously seen elements remain - # live in memory for the entire lifetime of the loop — causing - # O(N) memory growth and eventual OOM on large dumps. - # ---------------------------------------------------------------- - elem.clear() - - # Periodically request the cyclic-reference collector so that - # any cross-references within now-cleared subtrees are resolved - # promptly rather than accumulating until the next automatic GC. 
- # We amortise the GC overhead by triggering only every 1 000 pages. - articles_yielded += 1 - if articles_yielded % 1_000 == 0: - gc.collect() - logger.debug("Streamed %d articles so far …", articles_yielded) - - if not raw_text: - continue - - cleaned = clean_wikitext(raw_text) - if cleaned: - yield cleaned - - logger.info( - "Finished streaming '%s': %d articles yielded.", - path.name, - articles_yielded, - ) diff --git a/openverifiablellm/utils.py b/openverifiablellm/utils.py index 59b0971..ff34df1 100644 --- a/openverifiablellm/utils.py +++ b/openverifiablellm/utils.py @@ -7,7 +7,7 @@ import logging import json import platform -from typing import Union, Optional, Dict, Any, List, Tuple, Generator, Iterable +from typing import Union, Optional, Dict, Any, List, Tuple, Generator, Iterable, Iterator from openverifiablellm.environment import generate_environment_fingerprint logger = logging.getLogger(__name__) @@ -21,60 +21,93 @@ RE_LINK = re.compile(r"\[\[(.*?)\]\]") RE_WHITESPACE = re.compile(r"\s+") +def _sha256_hex(data: bytes) -> str: + """Internal helper: return SHA-256 hex digest of raw bytes.""" + return hashlib.sha256(data).hexdigest() + # Merkle Tree Chunk-Level Hashing for Large Files -def compute_merkle_root(file_path: Union[str, Path], chunk_size: int = MERKLE_CHUNK_SIZE_BYTES) -> str: +def compute_merkle_root( + file_path: Union[str, Path, None] = None, + chunk_size: int = MERKLE_CHUNK_SIZE_BYTES, + *, + chunks: Optional[Iterable[bytes]] = None, +) -> str: + """ + Compute a Merkle root from a file path or an arbitrary iterable of byte chunks. + + Supports two modes: + - File mode (default): reads *file_path* in *chunk_size* chunks from disk. + - Streaming mode: pass ``chunks=`` to consume any byte iterable + (generator, network stream, list …) without touching the filesystem. + + Exactly one of *file_path* or *chunks* must be provided. 
+    """
+    if (file_path is None) == (chunks is None):
+        raise ValueError("Exactly one of 'file_path' or 'chunks' must be provided.")
+
     if chunk_size <= 0:
         raise ValueError("chunk_size must be a positive integer")
 
-    path = Path(file_path)
-    leaves = []
+    def _iter_chunks() -> Iterator[bytes]:
+        if chunks is not None:
+            yield from chunks
+        else:
+            path = Path(file_path)  # type: ignore[arg-type]
+            with path.open("rb") as f:
+                while chunk := f.read(chunk_size):
+                    yield chunk
 
-    with path.open("rb") as f:
-        while chunk := f.read(chunk_size):
-            # reuse compute_sha256
-            leaf_hex = compute_sha256(data=chunk)
-            leaves.append(bytes.fromhex(leaf_hex))
+    leaves: List[bytes] = []
+    for chunk in _iter_chunks():
+        leaves.append(bytes.fromhex(_sha256_hex(chunk)))
 
     if not leaves:
-        return compute_sha256(data=b"")
+        return _sha256_hex(b"")
 
     while len(leaves) > 1:
-        next_level = []
+        next_level: List[bytes] = []
         for i in range(0, len(leaves), 2):
             left = leaves[i]
             right = leaves[i + 1] if i + 1 < len(leaves) else left
-            combined = left + right
-            parent_hex = compute_sha256(data=combined)
-            next_level.append(bytes.fromhex(parent_hex))
-
+            next_level.append(bytes.fromhex(_sha256_hex(left + right)))
         leaves = next_level
 
     return leaves[0].hex()
 
 
 def generate_merkle_proof(
-    file_path: Union[str, Path],
-    chunk_index: int,
-    chunk_size: int = MERKLE_CHUNK_SIZE_BYTES
-):
+    file_path: Union[str, Path, None] = None,
+    chunk_index: int = 0,
+    chunk_size: int = MERKLE_CHUNK_SIZE_BYTES,
+    *,
+    chunks: Optional[Iterable[bytes]] = None,
+) -> List[Tuple[str, bool]]:
     """
     Generate Merkle proof for a specific chunk index.
 
+    Supports two modes:
+    - File mode (default): reads *file_path* in *chunk_size* chunks.
+    - Streaming mode: pass ``chunks=`` to consume any byte iterable.
+ Returns: List of tuples (sibling_hash_hex, is_left) """ - path = Path(file_path) + if (file_path is None) == (chunks is None): + raise ValueError("Exactly one of 'file_path' or 'chunks' must be provided.") if chunk_size <= 0: raise ValueError("chunk_size must be a positive integer") - leaves = [] + leaves: List[bytes] = [] - # Build leaf level - with path.open("rb") as f: - while chunk := f.read(chunk_size): - leaf_hex = compute_sha256(data=chunk) - leaves.append(bytes.fromhex(leaf_hex)) + if chunks is not None: + for chunk in chunks: + leaves.append(bytes.fromhex(_sha256_hex(chunk))) + else: + path = Path(file_path) # type: ignore[arg-type] + with path.open("rb") as f: + while chunk := f.read(chunk_size): + leaves.append(bytes.fromhex(_sha256_hex(chunk))) if not leaves: raise ValueError("Cannot generate proof for empty file") @@ -82,11 +115,10 @@ def generate_merkle_proof( if chunk_index < 0 or chunk_index >= len(leaves): raise IndexError("Chunk index out of range") - proof = [] + proof: List[Tuple[str, bool]] = [] index = chunk_index while len(leaves) > 1: - # If odd number of nodes, duplicate last if len(leaves) % 2 == 1: leaves.append(leaves[-1]) @@ -96,12 +128,10 @@ def generate_merkle_proof( is_left = sibling_index < index proof.append((sibling.hex(), is_left)) - # Build next level - next_level = [] + next_level: List[bytes] = [] for i in range(0, len(leaves), 2): combined = leaves[i] + leaves[i + 1] - parent_hex = compute_sha256(data=combined) - next_level.append(bytes.fromhex(parent_hex)) + next_level.append(bytes.fromhex(_sha256_hex(combined))) index //= 2 leaves = next_level @@ -117,7 +147,7 @@ def verify_merkle_proof( Verify a Merkle proof for given chunk bytes. 
""" try: - current_hash = bytes.fromhex(compute_sha256(data=chunk_bytes)) + current_hash = bytes.fromhex(_sha256_hex(chunk_bytes)) expected_root = bytes.fromhex(merkle_root) except (TypeError, ValueError): return False @@ -148,64 +178,89 @@ def verify_merkle_proof( else: combined = current_hash + sibling - parent_hex = compute_sha256(data=combined) + parent_hex = _sha256_hex(combined) current_hash = bytes.fromhex(parent_hex) return current_hash == expected_root # extract clean wikipage from actual wikipage -def extract_text_from_xml(input_path): +def extract_text_from_xml( + input_path: Union[str, Path], + stream: bool = False, +) -> Optional[Generator[str, None, None]]: """ Process a Wikipedia XML dump (compressed or uncompressed) into cleaned plain text. - Each element is parsed, its revision text is extracted, - cleaned using `clean_wikitext()`, and appended to a single - output text file. + Supports two modes controlled by the *stream* flag: + + - **Batch mode** (``stream=False``, default): writes all cleaned article + texts to ``data/processed/wiki_clean.txt`` and generates the manifest. + Returns ``None``. - The processed output is saved to: - data/processed/wiki_clean.txt + - **Streaming mode** (``stream=True``): returns a generator that yields + one cleaned plain-text string per Wikipedia article with O(1) memory + usage. No file is written and no manifest is generated. Parameters ---------- input_path : str or Path - Path to the Wikipedia XML dump file. + Path to the Wikipedia XML dump file (plain or bz2-compressed). + stream : bool + If True, return a generator instead of writing to disk. - Output - ------ - Creates: - data/processed/wiki_clean.txt + Returns + ------- + None or Generator[str, None, None] + ``None`` in batch mode; a text generator in streaming mode. 
""" input_path = Path(input_path) - # Fixed output path - project_root = Path.cwd() - output_dir = project_root / "data" / "processed" - output_dir.mkdir(parents=True, exist_ok=True) - - output_path = output_dir / "wiki_clean.txt" - - # Auto-detect file type using magic bytes separation with open(input_path, "rb") as test_f: is_bz2 = test_f.read(3) == b"BZh" open_func = bz2.open if is_bz2 else open + if stream: + def _generator() -> Generator[str, None, None]: + with open_func(input_path, "rb") as f: + context = ET.iterparse(f, events=("end",)) + for _, elem in context: + if elem.tag.endswith("page"): + text_elem = elem.find(".//{*}text") + raw_text: str = "" + if text_elem is not None and text_elem.text: + raw_text = text_elem.text + elem.clear() + if not raw_text: + continue + cleaned = clean_wikitext(raw_text) + if cleaned: + yield cleaned + logger.info("Finished streaming articles from '%s'.", input_path.name) + + return _generator() + + # Batch mode — write to file + project_root = Path.cwd() + output_dir = project_root / "data" / "processed" + output_dir.mkdir(parents=True, exist_ok=True) + output_path = output_dir / "wiki_clean.txt" + with open_func(input_path, "rb") as f: context = ET.iterparse(f, events=("end",)) - with open(output_path, "w", encoding="utf-8") as out: for _, elem in context: if elem.tag.endswith("page"): text_elem = elem.find(".//{*}text") - if text_elem is not None and text_elem.text: cleaned = clean_wikitext(text_elem.text) if cleaned: out.write(cleaned + "\n\n") - elem.clear() + logger.info("Preprocessing complete. 
Output saved to %s", output_path) - generate_manifest(input_path,output_path) + generate_manifest(input_path, output_path) + return None # generate data manifest def generate_manifest(raw_path, processed_path): @@ -286,9 +341,6 @@ def load_merkle_proof( with proof_path.open("r", encoding="utf-8") as f: return json.load(f) - -# Content before line 270 remains unchanged -# Entire function definition from lines 270-314 should be deleted def verify_merkle_proof_from_file( proof_file_path: Union[str, Path], chunk_data: bytes, @@ -321,33 +373,49 @@ def compute_sha256( *, data: Optional[Union[bytes, bytearray]] = None, file_path: Optional[Union[str, Path]] = None, -) -> str: + stream: bool = False, +) -> Union[str, Generator[Tuple[bytes, str], None, None]]: """ - Compute SHA256 hash of a file OR raw bytes. - - This is used for both raw and processed files to ensure integrity. - This provides a deterministic fingerprint of the dataset, - enabling reproducibility and verification. - - Exactly one of `data` or `file_path` must be provided. + Compute SHA256 hash of a file OR raw bytes, with optional streaming support. + + Modes + ----- + - **data** mode: hash raw bytes in memory, return hex string. + - **file_path** mode (``stream=False``, default): hash the entire file, + return hex string. + - **file_path** mode (``stream=True``): return a generator that yields + ``(chunk_bytes, running_hex)`` pairs. The final ``running_hex`` equals + the SHA-256 of the whole file — same value as ``stream=False``. + + Exactly one of ``data`` or ``file_path`` must be provided. + ``stream=True`` is only valid with ``file_path``. """ - if (data is None) == (file_path is None): - raise ValueError( - "Exactly one of 'data' or 'file_path' must be provided." 
- ) + raise ValueError("Exactly one of 'data' or 'file_path' must be provided.") - sha256 = hashlib.sha256() + if stream and data is not None: + raise ValueError("stream=True is only valid with file_path, not data.") if data is not None: + sha256 = hashlib.sha256() sha256.update(data) return sha256.hexdigest() - path = Path(file_path) + path = Path(file_path) # type: ignore[arg-type] + + if stream: + def _stream_gen() -> Generator[Tuple[bytes, str], None, None]: + _sha256 = hashlib.sha256() + with path.open("rb") as f: + while _chunk := f.read(8192): + _sha256.update(_chunk) + yield _chunk, _sha256.hexdigest() + return _stream_gen() + + sha256 = hashlib.sha256() with path.open("rb") as f: while chunk := f.read(8192): sha256.update(chunk) - return sha256.hexdigest() def extract_dump_date(filename: str): @@ -380,237 +448,6 @@ def clean_wikitext(text: str) -> str: text = RE_WHITESPACE.sub(" ", text) return text.strip() -# --------------------------------------------------------------------------- -# Streaming / generator-based variants -# --------------------------------------------------------------------------- - -def stream_chunks( - file_path: Union[str, Path], - chunk_size: int = MERKLE_CHUNK_SIZE_BYTES, -) -> Generator[bytes, None, None]: - """ - Yield successive raw byte chunks from *file_path* without loading the - entire file into memory. - - This is the streaming analogue to the chunk-reading loop inside - ``compute_merkle_root``. Callers can process each chunk on-the-fly - (e.g. hash it, write it somewhere) without ever holding more than one - chunk in RAM at a time. - - Parameters - ---------- - file_path: - Path to any binary file (plain or bz2-compressed). - chunk_size: - Number of bytes per chunk. Must be a positive integer. - Defaults to ``MERKLE_CHUNK_SIZE_BYTES`` (1 MB). - - Yields - ------ - bytes - Raw byte chunk of at most *chunk_size* bytes. - The final chunk may be shorter if the file size is not a multiple - of *chunk_size*. 
- - Raises - ------ - ValueError - If *chunk_size* is not a positive integer. - FileNotFoundError - If *file_path* does not exist. - - Examples - -------- - >>> for chunk in stream_chunks("data/raw/simplewiki.xml.bz2"): - ... process(chunk) - """ - if chunk_size <= 0: - raise ValueError("chunk_size must be a positive integer") - - path = Path(file_path) - if not path.exists(): - raise FileNotFoundError(f"File not found: {path}") - - with path.open("rb") as f: - while True: - chunk = f.read(chunk_size) - if not chunk: - break - yield chunk - - -def stream_sha256( - file_path: Union[str, Path], - chunk_size: int = 8192, -) -> Generator[Tuple[bytes, str], None, None]: - """ - Stream SHA-256 hashes of successive chunks from *file_path*. - - Unlike ``compute_sha256``, which returns a **single** hash over the - entire file, this generator yields ``(chunk_bytes, partial_hex)`` - pairs as each chunk is read. After the generator is exhausted the - last ``partial_hex`` value equals the SHA-256 of the whole file. - - Parameters - ---------- - file_path: - Path to the file to hash. - chunk_size: - Read buffer size in bytes (default 8 192). - - Yields - ------ - (chunk_bytes, running_hex) : Tuple[bytes, str] - *chunk_bytes* — the raw bytes just read. - *running_hex* — SHA-256 hex digest of **all bytes read so far** - (i.e. the rolling hash after absorbing *chunk_bytes*). - - Examples - -------- - >>> *_, (_, final_hash) = stream_sha256("data/raw/dump.xml.bz2") - >>> print(final_hash) # same value as compute_sha256(file_path=...) 
- """ - if chunk_size <= 0: - raise ValueError("chunk_size must be a positive integer") - - path = Path(file_path) - if not path.exists(): - raise FileNotFoundError(f"File not found: {path}") - - sha256 = hashlib.sha256() - with path.open("rb") as f: - while True: - chunk = f.read(chunk_size) - if not chunk: - break - sha256.update(chunk) - yield chunk, sha256.hexdigest() - - -def compute_merkle_root_streaming( - chunks: Iterable[bytes], -) -> str: - """ - Compute a Merkle root over an **arbitrary iterable of byte chunks** - without requiring a seekable file on disk. - - This is the streaming counterpart to ``compute_merkle_root``. It - accepts any iterable — a ``stream_chunks()`` generator, a network - socket, a list of pre-computed blobs, etc. — so callers are not - restricted to file-backed data. - - The leaf hashes and tree construction follow the exact same algorithm - as ``compute_merkle_root`` to ensure identical roots for identical - content. - - Parameters - ---------- - chunks: - Any iterable that yields ``bytes`` objects. Each object is - treated as one Merkle leaf. - - Returns - ------- - str - 64-character lowercase hex SHA-256 Merkle root. - Returns ``compute_sha256(data=b"")`` for an empty iterable. 
- - Examples - -------- - >>> root = compute_merkle_root_streaming(stream_chunks("dump.xml.bz2")) - >>> assert root == compute_merkle_root("dump.xml.bz2") - """ - leaves: List[bytes] = [] - - for chunk in chunks: - leaf_hex = compute_sha256(data=chunk) - leaves.append(bytes.fromhex(leaf_hex)) - - if not leaves: - return compute_sha256(data=b"") - - while len(leaves) > 1: - next_level: List[bytes] = [] - for i in range(0, len(leaves), 2): - left = leaves[i] - right = leaves[i + 1] if i + 1 < len(leaves) else left - combined = left + right - parent_hex = compute_sha256(data=combined) - next_level.append(bytes.fromhex(parent_hex)) - leaves = next_level - - return leaves[0].hex() - - -def stream_extract_text_from_xml( - input_path: Union[str, Path], -) -> Generator[str, None, None]: - """ - Stream cleaned article texts from a Wikipedia XML dump without writing - any output file. - - This is the generator-based (streaming) counterpart to - ``extract_text_from_xml``. It yields one cleaned plain-text string - per Wikipedia article, keeping heap usage at O(1) regardless of dump - size. - - Supports both plain ``.xml`` and bzip2-compressed ``.xml.bz2`` inputs - by sniffing the first three bytes for the BZh magic header — exactly - the same auto-detection logic used in ``extract_text_from_xml``. - - Parameters - ---------- - input_path: - Path to a Wikipedia XML dump (compressed or uncompressed). - - Yields - ------ - str - Cleaned plain-text content of one Wikipedia article. - Articles that are empty after cleaning are silently skipped. - - Raises - ------ - FileNotFoundError - If *input_path* does not exist on disk. - ET.ParseError - If the XML stream is structurally malformed. 
- - Examples - -------- - >>> total_chars = sum(len(t) for t in stream_extract_text_from_xml("dump.xml.bz2")) - >>> print(f"Total characters: {total_chars:,}") - """ - input_path = Path(input_path) - - if not input_path.exists(): - raise FileNotFoundError(f"Dump file not found: {input_path}") - - with open(input_path, "rb") as probe: - is_bz2 = probe.read(3) == b"BZh" - - open_func = bz2.open if is_bz2 else open - - with open_func(input_path, "rb") as f: - context = ET.iterparse(f, events=("end",)) - for _, elem in context: - if elem.tag.endswith("page"): - text_elem = elem.find(".//{*}text") - raw_text: str = "" - if text_elem is not None and text_elem.text: - raw_text = text_elem.text - elem.clear() - - if not raw_text: - continue - - cleaned = clean_wikitext(raw_text) - if cleaned: - yield cleaned - - logger.info("Finished streaming articles from '%s'.", input_path.name) - - if __name__ == "__main__": if len(sys.argv) < 2: print("Usage: python -m openverifiablellm.utils ") From 6bafa675280446584c6f641ede304128ee7cd7e1 Mon Sep 17 00:00:00 2001 From: Muneerali199 Date: Mon, 9 Mar 2026 15:35:03 +0530 Subject: [PATCH 4/7] fix(benchmark): use defusedxml and extract_text_from_xml(stream=True); add streaming XML integration tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - benchmark.py: replace unsafe xml.etree.ElementTree with defusedxml.ElementTree to prevent XML parser attacks (CodeRabbit critical finding). - benchmark.py: replace broken import of stream_text_from_xml (from the now-deleted streaming_utils.py) with extract_text_from_xml from openverifiablellm.utils, using stream=True to preserve identical streaming behaviour. - tests/test_merkle.py: add TestStreamingXmlIntegration class with two tests: (1) streaming root hash matches batch reference root for a tiny XML fixture, (2) streamed texts have wikitext markup stripped — exercises the full extract_text_from_xml(stream=True) + IncrementalMerkleTree pipeline. 
--- openverifiablellm/benchmark.py | 11 ++-- tests/test_merkle.py | 91 ++++++++++++++++++++++++++++++++++ 2 files changed, 97 insertions(+), 5 deletions(-) diff --git a/openverifiablellm/benchmark.py b/openverifiablellm/benchmark.py index ef4d6da..d51bb2b 100644 --- a/openverifiablellm/benchmark.py +++ b/openverifiablellm/benchmark.py @@ -36,13 +36,12 @@ import sys import time import tracemalloc -import xml.etree.ElementTree as ET +import defusedxml.ElementTree as ET from pathlib import Path from typing import List, Optional, Tuple from openverifiablellm.incremental_merkle import IncrementalMerkleTree -from openverifiablellm.streaming_utils import stream_text_from_xml -from openverifiablellm.utils import clean_wikitext +from openverifiablellm.utils import clean_wikitext, extract_text_from_xml logging.basicConfig(level=logging.WARNING) logger = logging.getLogger(__name__) @@ -149,7 +148,9 @@ def _run_new_way(file_path: Path) -> BenchmarkResult: tree = IncrementalMerkleTree() - for article_text in stream_text_from_xml(str(file_path)): + stream = extract_text_from_xml(file_path, stream=True) + assert stream is not None, "extract_text_from_xml must return a generator when stream=True" + for article_text in stream: tree.append_leaf(article_text) root_hex: str = tree.get_root_hash() or hashlib.sha256(b"").hexdigest() @@ -209,7 +210,7 @@ def _render_markdown_table( "- *Wall-clock time* is measured with `time.perf_counter` on a single run.", " For publication-quality numbers repeat 3× and report median ± std-dev.", "- The Old Way intentionally omits `elem.clear()` to reproduce the OOM behaviour.", - "- The New Way uses `stream_text_from_xml` + `IncrementalMerkleTree` from this PR.", + "- The New Way uses `extract_text_from_xml(..., stream=True)` + `IncrementalMerkleTree` from this PR.", "", ] return "\n".join(lines) diff --git a/tests/test_merkle.py b/tests/test_merkle.py index 9adcc2b..fbbdd86 100644 --- a/tests/test_merkle.py +++ b/tests/test_merkle.py @@ -19,11 +19,15 
@@ """ import hashlib +import io +import textwrap +from pathlib import Path from typing import List, Optional import pytest from openverifiablellm.incremental_merkle import IncrementalMerkleTree +from openverifiablellm.utils import extract_text_from_xml # =========================================================================== @@ -362,3 +366,90 @@ def test_repr_contains_leaf_count(self) -> None: r = repr(tree) assert "leaves=7" in r assert "IncrementalMerkleTree" in r + + +# =========================================================================== +# Integration test: extract_text_from_xml(stream=True) + IncrementalMerkleTree +# =========================================================================== + +# Minimal MediaWiki-style XML dump with 3 articles. +_WIKI_XML_FIXTURE = textwrap.dedent("""\ + + + + Alpha + Alpha article text about [[cats]] and {{tmpl}}. + + + Beta + Beta article text about ref dogs. + + + Gamma + Gamma article text: simple plain text only. + + +""") + + +class TestStreamingXmlIntegration: + """ + Exercises extract_text_from_xml(..., stream=True) end-to-end by feeding + its output into IncrementalMerkleTree and comparing with the batch + reference implementation. + + Uses a tiny in-memory XML fixture so the test is fast and requires no + external files. + """ + + def _write_xml_fixture(self, tmp_path) -> Path: + """Write the XML fixture to a temp file and return its path.""" + p = tmp_path / "wiki_test.xml" + p.write_text(_WIKI_XML_FIXTURE, encoding="utf-8") + return p + + def test_streaming_root_matches_batch_root(self, tmp_path) -> None: + """ + extract_text_from_xml(stream=True) must yield the same articles as + iterating manually, and IncrementalMerkleTree fed from the stream + must produce the same root as the batch reference. 
+ """ + xml_path = self._write_xml_fixture(tmp_path) + + # Collect streamed texts + gen = extract_text_from_xml(xml_path, stream=True) + assert gen is not None, "extract_text_from_xml must return a generator when stream=True" + streamed_texts = list(gen) + + assert len(streamed_texts) > 0, "Fixture must yield at least one article" + + # Batch reference root + expected_root = batch_merkle_root(streamed_texts) + assert expected_root is not None + + # Incremental root built by re-streaming + gen2 = extract_text_from_xml(xml_path, stream=True) + assert gen2 is not None + tree = IncrementalMerkleTree() + for text in gen2: + tree.append_leaf(text) + + assert tree.get_root_hash() == expected_root, ( + "IncrementalMerkleTree root does not match batch root " + "when fed from extract_text_from_xml(stream=True)" + ) + + def test_streaming_yields_cleaned_text(self, tmp_path) -> None: + """ + Streamed article texts must have wikitext markup removed + (no {{ }}, [[ ]], or tags). + """ + xml_path = self._write_xml_fixture(tmp_path) + gen = extract_text_from_xml(xml_path, stream=True) + assert gen is not None + texts = list(gen) + + for text in texts: + assert "{{" not in text, "Templates should be stripped" + assert "[[" not in text, "Links should be stripped" + assert "" not in text, "Ref tags should be stripped" From 9c947a2706e0e6e4d1f65610e4962e91b1d3a906 Mon Sep 17 00:00:00 2001 From: Muneerali199 Date: Mon, 9 Mar 2026 16:22:13 +0530 Subject: [PATCH 5/7] fix(benchmark): multi-trial subprocess isolation, empty-article sentinel, and fixture ref escaping - benchmark.py: run_benchmark() now spawns each measurement in an isolated subprocess via --_mode flag, alternating old/new order per trial to eliminate OS pagecache and allocator warm-up bias; reports median time and RAM across all trials (default 3). 
- benchmark.py: empty-article runs now return root=None (sentinel) instead of the spurious hashlib.sha256(b'').hexdigest(); callers abort with a clear error when article_count==0 so zero-article runs are never silently passed off as a valid benchmark result. - tests/test_merkle.py: fixed _WIKI_XML_FIXTURE Beta article to use XML-escaped markup (<ref>ref</ref>) so the literal '...' string reaches text_elem.text and actually exercises clean_wikitext's ref-stripping logic, rather than being silently consumed by the XML parser. --- openverifiablellm/benchmark.py | 196 ++++++++++++++++++++++++++++----- tests/test_merkle.py | 2 +- 2 files changed, 167 insertions(+), 31 deletions(-) diff --git a/openverifiablellm/benchmark.py b/openverifiablellm/benchmark.py index d51bb2b..1b777ff 100644 --- a/openverifiablellm/benchmark.py +++ b/openverifiablellm/benchmark.py @@ -50,10 +50,11 @@ # Type alias for a benchmark result row # --------------------------------------------------------------------------- BenchmarkResult = Tuple[ - str, # approach label - float, # wall-clock seconds - float, # peak RAM in MB - str, # root hash (hex) + str, # approach label + float, # wall-clock seconds + float, # peak RAM in MB + Optional[str], # root hash (hex), or None when no articles were found + int, # article count ] @@ -98,6 +99,8 @@ def _run_old_way(file_path: Path) -> BenchmarkResult: # old code that leaks every parsed element into memory. # ----- Step 2: build Merkle tree from the in-memory list ----- + article_count = len(all_texts) + # Hash each article text to a leaf leaves: List[bytes] = [ hashlib.sha256(t.encode("utf-8")).digest() for t in all_texts @@ -105,7 +108,9 @@ def _run_old_way(file_path: Path) -> BenchmarkResult: # Batch construction: classic bottom-up Merkle tree if not leaves: - root_hex = hashlib.sha256(b"").hexdigest() + # Surface zero-article runs explicitly rather than producing a + # spurious matching root hash. 
+ root_hex: Optional[str] = None else: current_level = leaves while len(current_level) > 1: @@ -126,6 +131,7 @@ def _run_old_way(file_path: Path) -> BenchmarkResult: round(t_end - t_start, 3), round(_bytes_to_mb(peak_bytes), 2), root_hex, + article_count, ) @@ -147,13 +153,16 @@ def _run_new_way(file_path: Path) -> BenchmarkResult: t_start = time.perf_counter() tree = IncrementalMerkleTree() + article_count = 0 stream = extract_text_from_xml(file_path, stream=True) assert stream is not None, "extract_text_from_xml must return a generator when stream=True" for article_text in stream: tree.append_leaf(article_text) + article_count += 1 - root_hex: str = tree.get_root_hash() or hashlib.sha256(b"").hexdigest() + # Surface zero-article runs as None rather than a spurious sha256(b"") hash. + root_hex: Optional[str] = tree.get_root_hash() if article_count > 0 else None t_end = time.perf_counter() _, peak_bytes = tracemalloc.get_traced_memory() @@ -164,6 +173,7 @@ def _run_new_way(file_path: Path) -> BenchmarkResult: round(t_end - t_start, 3), round(_bytes_to_mb(peak_bytes), 2), root_hex, + article_count, ) @@ -175,40 +185,48 @@ def _render_markdown_table( old: BenchmarkResult, new: BenchmarkResult, file_name: str, + trials: int, ) -> str: """ Render a GFM Markdown table suitable for direct use in a GitHub PR. Calculates speed-up and RAM reduction ratios and appends a legend. + Results are aggregated medians across multiple alternating-order trials. 
""" - label_old, time_old, ram_old, hash_old = old - label_new, time_new, ram_new, hash_new = new + label_old, time_old, ram_old, hash_old, count_old = old + label_new, time_new, ram_new, hash_new, count_new = new # Guard against division by zero on extremely fast runs time_ratio = (time_old / time_new) if time_new > 0 else float("inf") ram_ratio = (ram_old / ram_new) if ram_new > 0 else float("inf") - hashes_match = (hash_old == hash_new) - hash_verdict = "YES — identical root hash" if hashes_match else "NO — MISMATCH (investigate!)" + hash_old_str = hash_old if hash_old is not None else "N/A (0 articles)" + hash_new_str = hash_new if hash_new is not None else "N/A (0 articles)" + hashes_match = (hash_old == hash_new) and hash_old is not None + hash_verdict = "YES — identical root hash" if hashes_match else ( + "NO — MISMATCH (investigate!)" if hash_old is not None else "N/A — no articles processed" + ) lines = [ "", "## Benchmark Results", "", f"> **Input file:** `{file_name}` ", + f"> **Trials:** {trials} (alternating order, median reported) ", + f"> **Articles processed:** old={count_old}, new={count_new} ", f"> **Root hashes match:** {hash_verdict}", "", "| Metric | Old Way (in-memory) | New Way (streaming) | Improvement |", "|-------------------------------|--------------------:|--------------------:|--------------------|", f"| Wall-clock time (s) | `{time_old:>10.3f}` | `{time_new:>10.3f}` | **{time_ratio:,.1f}× faster** |", f"| Peak RAM usage (MB) | `{ram_old:>10.2f}` | `{ram_new:>10.2f}` | **{ram_ratio:,.1f}× less RAM** |", - f"| Root hash | `{hash_old[:16]}…` | `{hash_new[:16]}…` | {'Match' if hashes_match else 'MISMATCH'} |", + f"| Root hash | `{hash_old_str[:16]}…` | `{hash_new_str[:16]}…` | {'Match' if hashes_match else 'MISMATCH'} |", "", "### Notes", "- *Peak RAM* is measured with `tracemalloc` (Python heap only; does not include", " OS-level buffers or the bzip2 decompressor's internal state).", - "- *Wall-clock time* is measured with 
`time.perf_counter` on a single run.", - " For publication-quality numbers repeat 3× and report median ± std-dev.", + f"- *Wall-clock time* is the median of {trials} isolated subprocess trials with", + " alternating run order to minimise OS pagecache and warm-up bias.", "- The Old Way intentionally omits `elem.clear()` to reproduce the OOM behaviour.", "- The New Way uses `extract_text_from_xml(..., stream=True)` + `IncrementalMerkleTree` from this PR.", "", @@ -220,13 +238,14 @@ def _render_terminal_table( old: BenchmarkResult, new: BenchmarkResult, file_name: str, + trials: int, ) -> str: """Plain-text box table for terminal output (complements the Markdown table).""" - label_old, time_old, ram_old, hash_old = old - label_new, time_new, ram_new, hash_new = new + label_old, time_old, ram_old, hash_old, count_old = old + label_new, time_new, ram_new, hash_new, count_new = new time_ratio = (time_old / time_new) if time_new > 0 else float("inf") ram_ratio = (ram_old / ram_new) if ram_new > 0 else float("inf") - hashes_match = hash_old == hash_new + hashes_match = (hash_old == hash_new) and hash_old is not None w = 90 sep = "─" * w @@ -236,52 +255,156 @@ def row(col1: str, col2: str, col3: str, col4: str = "") -> str: lines = [ f"┌{sep}┐", - f"│{'BEFORE vs. AFTER — ' + file_name:^{w}}│", + f"│{'BEFORE vs. 
AFTER — ' + file_name + f' ({trials} trials, median)':^{w}}│", f"├{sep}┤", row("Metric", "Old Way", "New Way", "Improvement"), f"├{sep}┤", row("Wall-clock time (s)", f"{time_old:.3f} s", f"{time_new:.3f} s", f"{time_ratio:,.1f}x faster"), row("Peak RAM (MB)", f"{ram_old:.2f} MB", f"{ram_new:.2f} MB", f"{ram_ratio:,.1f}x less"), + row("Articles processed", str(count_old), str(count_new), ""), row("Root hashes match", "", "", "YES" if hashes_match else "NO — MISMATCH"), f"└{sep}┘", ] return "\n".join(lines) +# =========================================================================== +# Subprocess entry point (called by each isolated trial) +# =========================================================================== + +def _run_benchmark_mode(mode: str, file_path: str) -> None: + """ + Single-mode entry point invoked inside an isolated subprocess per trial. + + Prints a JSON line to stdout with keys: label, time, ram, root, articles. + The parent process parses these lines to aggregate trial results. + """ + import json + path = Path(file_path) + if mode == "old": + result = _run_old_way(path) + elif mode == "new": + result = _run_new_way(path) + else: + raise ValueError(f"Unknown mode: {mode!r}") + + label, elapsed, ram, root, articles = result + print(json.dumps({ + "label": label, + "time": elapsed, + "ram": ram, + "root": root, + "articles": articles, + })) + + # =========================================================================== # Main entry point # =========================================================================== -def run_benchmark(file_path: str) -> None: +def run_benchmark(file_path: str, trials: int = 3) -> None: """ - Execute both benchmarks sequentially and print the results. + Execute both benchmarks across multiple isolated trials with alternating + order to minimise OS pagecache and allocator warm-up bias. + + Each trial spawns a fresh Python subprocess so memory and file-cache state + are fully isolated between measurements. 
The order of old-vs-new is + reversed on odd-numbered trials. Median time and peak-RAM across all + trials are reported. Parameters ---------- file_path: Path to the Wikipedia ``.xml.bz2`` dump. + trials: + Number of measurement trials (default 3; must be odd for a clean median). """ + import json + import statistics + import subprocess + path = Path(file_path) if not path.exists(): print(f"[ERROR] File not found: {path}", file=sys.stderr) sys.exit(1) - print(f"\nRunning OLD WAY benchmark on: {path.name}") - print(" (This may take several minutes and use significant RAM …)\n") - old_result = _run_old_way(path) - print(f" Done. Time={old_result[1]:.3f}s Peak RAM={old_result[2]:.2f} MB") + old_times: List[float] = [] + old_rams: List[float] = [] + new_times: List[float] = [] + new_rams: List[float] = [] + old_root: Optional[str] = None + new_root: Optional[str] = None + old_articles = 0 + new_articles = 0 + + print(f"\nRunning {trials}-trial benchmark on: {path.name}") + print(" Each trial runs in an isolated subprocess to avoid pagecache bias.\n") + + for i in range(trials): + # Alternate order: even trials run old→new, odd trials run new→old + order = ["old", "new"] if i % 2 == 0 else ["new", "old"] + + for mode in order: + label = "OLD WAY" if mode == "old" else "NEW WAY" + print(f" Trial {i + 1}/{trials} — {label} …", end=" ", flush=True) + + proc = subprocess.run( + [ + sys.executable, "-m", "openverifiablellm.benchmark", + "--_mode", mode, file_path, + ], + capture_output=True, + text=True, + ) + if proc.returncode != 0: + print(f"\n[ERROR] Subprocess failed (mode={mode}):\n{proc.stderr}", file=sys.stderr) + sys.exit(1) + + data = json.loads(proc.stdout.strip()) + print(f"time={data['time']:.3f}s ram={data['ram']:.2f}MB") + + if mode == "old": + old_times.append(data["time"]) + old_rams.append(data["ram"]) + old_root = data["root"] + old_articles = data["articles"] + else: + new_times.append(data["time"]) + new_rams.append(data["ram"]) + new_root = data["root"] 
+ new_articles = data["articles"] + + # Abort if either run found zero articles — spurious matching roots. + if old_articles == 0 or new_articles == 0: + print( + f"\n[ERROR] Zero articles processed " + f"(old={old_articles}, new={new_articles}). " + "Cannot produce meaningful benchmark results.", + file=sys.stderr, + ) + sys.exit(1) - print(f"\nRunning NEW WAY benchmark on: {path.name}") - print(" (Streaming — should use constant, minimal RAM …)\n") - new_result = _run_new_way(path) - print(f" Done. Time={new_result[1]:.3f}s Peak RAM={new_result[2]:.2f} MB") + old_result: BenchmarkResult = ( + "Old Way (in-memory)", + round(statistics.median(old_times), 3), + round(statistics.median(old_rams), 2), + old_root, + old_articles, + ) + new_result: BenchmarkResult = ( + "New Way (streaming)", + round(statistics.median(new_times), 3), + round(statistics.median(new_rams), 2), + new_root, + new_articles, + ) # Print terminal table print() - print(_render_terminal_table(old_result, new_result, path.name)) + print(_render_terminal_table(old_result, new_result, path.name, trials)) # Print GitHub-Flavored Markdown table - md = _render_markdown_table(old_result, new_result, path.name) + md = _render_markdown_table(old_result, new_result, path.name, trials) print("\n" + "=" * 60) print("Copy the block below into your GitHub Pull Request:") print("=" * 60) @@ -299,8 +422,21 @@ def main(argv: Optional[List[str]] = None) -> None: "file_path", help="Path to the Wikipedia XML.bz2 dump file (e.g. simplewiki-20260201-....xml.bz2)", ) + parser.add_argument( + "--trials", + type=int, + default=3, + help="Number of isolated subprocess trials (default: 3)", + ) + # Internal flag used by the subprocess runner — not intended for direct use. + parser.add_argument("--_mode", choices=["old", "new"], help=argparse.SUPPRESS) args = parser.parse_args(argv) - run_benchmark(args.file_path) + + if args._mode: + # Running as a subprocess trial — just measure and print JSON. 
+ _run_benchmark_mode(args._mode, args.file_path) + else: + run_benchmark(args.file_path, trials=args.trials) if __name__ == "__main__": diff --git a/tests/test_merkle.py b/tests/test_merkle.py index fbbdd86..3d54b6f 100644 --- a/tests/test_merkle.py +++ b/tests/test_merkle.py @@ -382,7 +382,7 @@ def test_repr_contains_leaf_count(self) -> None: Beta - Beta article text about ref dogs. + Beta article text about <ref>ref</ref> dogs. Gamma From 900c92dcc55b64311d58099ced973a93735f4e3f Mon Sep 17 00:00:00 2001 From: Muneerali199 Date: Thu, 12 Mar 2026 15:21:40 +0530 Subject: [PATCH 6/7] fix: ruff-format all changed files and pass write_manifest=True in test fixture - Apply ruff format to utils.py (trailing whitespace / blank-line fixes) - Fix import ordering in benchmark.py and test_merkle.py (ruff I001) - Remove unused 'io' import from test_merkle.py (ruff F401) - Pass write_manifest=True in test_verify.run_preprocessing so the manifest is actually written and all 37 verify tests pass --- openverifiablellm/benchmark.py | 80 ++++++++++++++++++++++------------ openverifiablellm/utils.py | 2 + tests/test_merkle.py | 31 +++++++------ tests/test_verify.py | 2 +- 4 files changed, 73 insertions(+), 42 deletions(-) diff --git a/openverifiablellm/benchmark.py b/openverifiablellm/benchmark.py index 1b777ff..aefdecb 100644 --- a/openverifiablellm/benchmark.py +++ b/openverifiablellm/benchmark.py @@ -36,10 +36,11 @@ import sys import time import tracemalloc -import defusedxml.ElementTree as ET from pathlib import Path from typing import List, Optional, Tuple +import defusedxml.ElementTree as ET + from openverifiablellm.incremental_merkle import IncrementalMerkleTree from openverifiablellm.utils import clean_wikitext, extract_text_from_xml @@ -50,11 +51,11 @@ # Type alias for a benchmark result row # --------------------------------------------------------------------------- BenchmarkResult = Tuple[ - str, # approach label - float, # wall-clock seconds - float, # peak RAM in MB - 
Optional[str], # root hash (hex), or None when no articles were found - int, # article count + str, # approach label + float, # wall-clock seconds + float, # peak RAM in MB + Optional[str], # root hash (hex), or None when no articles were found + int, # article count ] @@ -69,6 +70,7 @@ def _bytes_to_mb(n_bytes: int) -> float: # APPROACH 1 — "Old Way" (in-memory) # =========================================================================== + def _run_old_way(file_path: Path) -> BenchmarkResult: """ Legacy approach: decompress the entire dump, collect ALL article texts @@ -102,9 +104,7 @@ def _run_old_way(file_path: Path) -> BenchmarkResult: article_count = len(all_texts) # Hash each article text to a leaf - leaves: List[bytes] = [ - hashlib.sha256(t.encode("utf-8")).digest() for t in all_texts - ] + leaves: List[bytes] = [hashlib.sha256(t.encode("utf-8")).digest() for t in all_texts] # Batch construction: classic bottom-up Merkle tree if not leaves: @@ -139,6 +139,7 @@ def _run_old_way(file_path: Path) -> BenchmarkResult: # APPROACH 2 — "New Way" (streaming) # =========================================================================== + def _run_new_way(file_path: Path) -> BenchmarkResult: """ New streaming approach: yield one article at a time from the generator @@ -181,6 +182,7 @@ def _run_new_way(file_path: Path) -> BenchmarkResult: # Reporting: GitHub-Flavored Markdown table # =========================================================================== + def _render_markdown_table( old: BenchmarkResult, new: BenchmarkResult, @@ -198,13 +200,19 @@ def _render_markdown_table( # Guard against division by zero on extremely fast runs time_ratio = (time_old / time_new) if time_new > 0 else float("inf") - ram_ratio = (ram_old / ram_new) if ram_new > 0 else float("inf") + ram_ratio = (ram_old / ram_new) if ram_new > 0 else float("inf") hash_old_str = hash_old if hash_old is not None else "N/A (0 articles)" hash_new_str = hash_new if hash_new is not None else "N/A (0 
articles)" hashes_match = (hash_old == hash_new) and hash_old is not None - hash_verdict = "YES — identical root hash" if hashes_match else ( - "NO — MISMATCH (investigate!)" if hash_old is not None else "N/A — no articles processed" + hash_verdict = ( + "YES — identical root hash" + if hashes_match + else ( + "NO — MISMATCH (investigate!)" + if hash_old is not None + else "N/A — no articles processed" + ) ) lines = [ @@ -244,7 +252,7 @@ def _render_terminal_table( label_old, time_old, ram_old, hash_old, count_old = old label_new, time_new, ram_new, hash_new, count_new = new time_ratio = (time_old / time_new) if time_new > 0 else float("inf") - ram_ratio = (ram_old / ram_new) if ram_new > 0 else float("inf") + ram_ratio = (ram_old / ram_new) if ram_new > 0 else float("inf") hashes_match = (hash_old == hash_new) and hash_old is not None w = 90 @@ -259,7 +267,12 @@ def row(col1: str, col2: str, col3: str, col4: str = "") -> str: f"├{sep}┤", row("Metric", "Old Way", "New Way", "Improvement"), f"├{sep}┤", - row("Wall-clock time (s)", f"{time_old:.3f} s", f"{time_new:.3f} s", f"{time_ratio:,.1f}x faster"), + row( + "Wall-clock time (s)", + f"{time_old:.3f} s", + f"{time_new:.3f} s", + f"{time_ratio:,.1f}x faster", + ), row("Peak RAM (MB)", f"{ram_old:.2f} MB", f"{ram_new:.2f} MB", f"{ram_ratio:,.1f}x less"), row("Articles processed", str(count_old), str(count_new), ""), row("Root hashes match", "", "", "YES" if hashes_match else "NO — MISMATCH"), @@ -272,6 +285,7 @@ def row(col1: str, col2: str, col3: str, col4: str = "") -> str: # Subprocess entry point (called by each isolated trial) # =========================================================================== + def _run_benchmark_mode(mode: str, file_path: str) -> None: """ Single-mode entry point invoked inside an isolated subprocess per trial. @@ -280,6 +294,7 @@ def _run_benchmark_mode(mode: str, file_path: str) -> None: The parent process parses these lines to aggregate trial results. 
""" import json + path = Path(file_path) if mode == "old": result = _run_old_way(path) @@ -289,19 +304,24 @@ def _run_benchmark_mode(mode: str, file_path: str) -> None: raise ValueError(f"Unknown mode: {mode!r}") label, elapsed, ram, root, articles = result - print(json.dumps({ - "label": label, - "time": elapsed, - "ram": ram, - "root": root, - "articles": articles, - })) + print( + json.dumps( + { + "label": label, + "time": elapsed, + "ram": ram, + "root": root, + "articles": articles, + } + ) + ) # =========================================================================== # Main entry point # =========================================================================== + def run_benchmark(file_path: str, trials: int = 3) -> None: """ Execute both benchmarks across multiple isolated trials with alternating @@ -329,13 +349,13 @@ def run_benchmark(file_path: str, trials: int = 3) -> None: sys.exit(1) old_times: List[float] = [] - old_rams: List[float] = [] + old_rams: List[float] = [] new_times: List[float] = [] - new_rams: List[float] = [] + new_rams: List[float] = [] old_root: Optional[str] = None new_root: Optional[str] = None old_articles = 0 - new_articles = 0 + new_articles = 0 print(f"\nRunning {trials}-trial benchmark on: {path.name}") print(" Each trial runs in an isolated subprocess to avoid pagecache bias.\n") @@ -350,8 +370,12 @@ def run_benchmark(file_path: str, trials: int = 3) -> None: proc = subprocess.run( [ - sys.executable, "-m", "openverifiablellm.benchmark", - "--_mode", mode, file_path, + sys.executable, + "-m", + "openverifiablellm.benchmark", + "--_mode", + mode, + file_path, ], capture_output=True, text=True, @@ -387,14 +411,14 @@ def run_benchmark(file_path: str, trials: int = 3) -> None: old_result: BenchmarkResult = ( "Old Way (in-memory)", round(statistics.median(old_times), 3), - round(statistics.median(old_rams), 2), + round(statistics.median(old_rams), 2), old_root, old_articles, ) new_result: BenchmarkResult = ( "New Way 
(streaming)", round(statistics.median(new_times), 3), - round(statistics.median(new_rams), 2), + round(statistics.median(new_rams), 2), new_root, new_articles, ) diff --git a/openverifiablellm/utils.py b/openverifiablellm/utils.py index 3eab242..5a4b52f 100644 --- a/openverifiablellm/utils.py +++ b/openverifiablellm/utils.py @@ -224,6 +224,7 @@ def extract_text_from_xml( open_func = bz2.open if is_bz2 else open if stream: + def _generator() -> Generator[str, None, None]: with open_func(input_path, "rb") as f: context = ET.iterparse(f, events=("end",)) @@ -405,6 +406,7 @@ def compute_sha256( path = Path(file_path) # type: ignore[arg-type] if stream: + def _stream_gen() -> Generator[Tuple[bytes, str], None, None]: _sha256 = hashlib.sha256() with path.open("rb") as f: diff --git a/tests/test_merkle.py b/tests/test_merkle.py index 3d54b6f..34064a6 100644 --- a/tests/test_merkle.py +++ b/tests/test_merkle.py @@ -19,7 +19,6 @@ """ import hashlib -import io import textwrap from pathlib import Path from typing import List, Optional @@ -29,11 +28,11 @@ from openverifiablellm.incremental_merkle import IncrementalMerkleTree from openverifiablellm.utils import extract_text_from_xml - # =========================================================================== # Reference implementation: a classic batch Merkle tree # =========================================================================== + def _sha256_bytes(data: bytes) -> bytes: """Return raw 32-byte SHA-256 digest.""" return hashlib.sha256(data).digest() @@ -71,9 +70,7 @@ def batch_merkle_root(texts: List[str]) -> Optional[str]: return None # Leaf level: hash each string - level: List[bytes] = [ - _sha256_bytes(t.encode("utf-8")) for t in texts - ] + level: List[bytes] = [_sha256_bytes(t.encode("utf-8")) for t in texts] # Reduce level-by-level until only the root remains while len(level) > 1: @@ -92,6 +89,7 @@ def batch_merkle_root(texts: List[str]) -> Optional[str]: # Fixtures # 
=========================================================================== + @pytest.fixture def hundred_strings() -> List[str]: """ @@ -107,15 +105,14 @@ def hundred_strings() -> List[str]: # Core correctness test (the PRIMARY deliverable) # =========================================================================== + class TestIncrementalVsBatch: """ Verify that IncrementalMerkleTree produces the same root as the batch reference implementation for the same input sequence. """ - def test_root_hash_matches_batch_100_strings( - self, hundred_strings: List[str] - ) -> None: + def test_root_hash_matches_batch_100_strings(self, hundred_strings: List[str]) -> None: """ PRIMARY TEST: IncrementalMerkleTree root must exactly equal the batch Merkle root for the same 100 strings. @@ -124,7 +121,9 @@ def test_root_hash_matches_batch_100_strings( """ # Build batch reference root expected_root = batch_merkle_root(hundred_strings) - assert expected_root is not None, "batch_merkle_root should not return None for non-empty input" + assert expected_root is not None, ( + "batch_merkle_root should not return None for non-empty input" + ) # Build incremental root tree = IncrementalMerkleTree() @@ -192,7 +191,9 @@ def test_root_hash_matches_batch_odd_leaf_count(self) -> None: assert tree.get_root_hash() == expected - @pytest.mark.parametrize("n", [1, 2, 3, 4, 5, 6, 7, 8, 15, 16, 17, 31, 32, 33, 64, 100, 128, 255, 256]) + @pytest.mark.parametrize( + "n", [1, 2, 3, 4, 5, 6, 7, 8, 15, 16, 17, 31, 32, 33, 64, 100, 128, 255, 256] + ) def test_root_hash_matches_batch_parametric(self, n: int) -> None: """ Parametric sweep over various leaf counts including edge cases, @@ -205,15 +206,14 @@ def test_root_hash_matches_batch_parametric(self, n: int) -> None: for t in texts: tree.append_leaf(t) - assert tree.get_root_hash() == expected, ( - f"Root hash mismatch for n={n}" - ) + assert tree.get_root_hash() == expected, f"Root hash mismatch for n={n}" # 
=========================================================================== # Empty-tree behaviour # =========================================================================== + class TestEmptyTree: def test_get_root_hash_returns_none_when_empty(self) -> None: """An empty tree has no root — get_root_hash() must return None.""" @@ -233,6 +233,7 @@ def test_frontier_empty_when_empty(self) -> None: # Leaf count tracking # =========================================================================== + class TestLeafCount: def test_leaf_count_increments(self) -> None: tree = IncrementalMerkleTree() @@ -251,6 +252,7 @@ def test_leaf_count_matches_hundred(self, hundred_strings: List[str]) -> None: # Frontier size invariant # =========================================================================== + class TestFrontierInvariant: """ The frontier size equals the number of set bits in the binary @@ -274,6 +276,7 @@ def test_frontier_size_equals_popcount(self, n: int) -> None: # Determinism / reproducibility # =========================================================================== + class TestDeterminism: def test_same_input_same_root(self, hundred_strings: List[str]) -> None: """Two trees built from the same input must produce identical roots.""" @@ -328,6 +331,7 @@ def test_get_root_hash_is_non_destructive(self) -> None: # Hash format sanity checks # =========================================================================== + class TestHashFormat: def test_root_hash_is_64_char_hex(self) -> None: """SHA-256 produces 32 bytes → 64 lowercase hex characters.""" @@ -358,6 +362,7 @@ def test_single_leaf_root_equals_sha256_of_text(self) -> None: # repr() smoke test # =========================================================================== + class TestRepr: def test_repr_contains_leaf_count(self) -> None: tree = IncrementalMerkleTree() diff --git a/tests/test_verify.py b/tests/test_verify.py index 77817e0..931a66e 100644 --- a/tests/test_verify.py +++ 
b/tests/test_verify.py @@ -57,7 +57,7 @@ def run_preprocessing(tmp_dir: Path, dump: Path) -> None: original = os.getcwd() os.chdir(tmp_dir) try: - utils.extract_text_from_xml(dump) + utils.extract_text_from_xml(dump, write_manifest=True) finally: os.chdir(original) From fae34133c356cfe5af77f905846dac1728701ed5 Mon Sep 17 00:00:00 2001 From: Muneerali199 Date: Thu, 12 Mar 2026 15:38:44 +0530 Subject: [PATCH 7/7] fix: address all CodeRabbit inline review comments benchmark.py: - _run_old_way: replace hardcoded bz2.open() with magic-byte detection (same 3-byte BZh probe used by extract_text_from_xml) so plain .xml files are opened correctly instead of crashing - run_benchmark: replace json.loads(proc.stdout.strip()) with a reverse- scan that finds the last non-empty line parseable as JSON, tolerating any stray log/warning lines that may appear before the JSON payload utils.py: - extract_text_from_xml (stream=True): wrap the ET.iterparse loop in try/finally and call context.close() in the finally block so the iterparse internal state is released when the caller abandons the generator mid-stream or an exception propagates incremental_merkle.py: - Add public read-only property frontier_size -> int so callers and tests can observe frontier occupancy without accessing _frontier tests/test_merkle.py: - Replace all tree._frontier direct accesses with tree.frontier_size (covers test_root_hash_matches_batch_power_of_two, test_frontier_empty_when_empty, test_frontier_size_equals_popcount) --- openverifiablellm/benchmark.py | 29 ++++++++++++++++++++-- openverifiablellm/incremental_merkle.py | 13 +++++++++- openverifiablellm/utils.py | 32 +++++++++++++++---------- tests/test_merkle.py | 10 ++++---- 4 files changed, 64 insertions(+), 20 deletions(-) diff --git a/openverifiablellm/benchmark.py b/openverifiablellm/benchmark.py index aefdecb..b7f6592 100644 --- a/openverifiablellm/benchmark.py +++ b/openverifiablellm/benchmark.py @@ -86,7 +86,13 @@ def _run_old_way(file_path: Path) 
-> BenchmarkResult: # ----- Step 1: load all texts into memory ----- all_texts: List[str] = [] - with bz2.open(file_path, "rb") as raw: + # Detect compression by inspecting the bz2 magic bytes (same logic as + # extract_text_from_xml) so plain .xml files are also handled correctly. + with open(file_path, "rb") as _probe: + _is_bz2 = _probe.read(3) == b"BZh" + _open_func = bz2.open if _is_bz2 else open + + with _open_func(file_path, "rb") as raw: context = ET.iterparse(raw, events=("end",)) for _event, elem in context: if elem.tag.endswith("page"): @@ -384,7 +390,26 @@ def run_benchmark(file_path: str, trials: int = 3) -> None: print(f"\n[ERROR] Subprocess failed (mode={mode}):\n{proc.stderr}", file=sys.stderr) sys.exit(1) - data = json.loads(proc.stdout.strip()) + # Extract the last non-empty line that parses as valid JSON. + # This tolerates any stray log/warning lines on stdout that may + # appear before or after the single JSON payload line. + _stdout_lines = proc.stdout.splitlines() + data = None + for _line in reversed(_stdout_lines): + _line = _line.strip() + if _line: + try: + data = json.loads(_line) + break + except json.JSONDecodeError: + continue + if data is None: + print( + f"\n[ERROR] Could not find valid JSON in subprocess output " + f"(mode={mode}):\n{proc.stdout}", + file=sys.stderr, + ) + sys.exit(1) print(f"time={data['time']:.3f}s ram={data['ram']:.2f}MB") if mode == "old": diff --git a/openverifiablellm/incremental_merkle.py b/openverifiablellm/incremental_merkle.py index c181974..e682185 100644 --- a/openverifiablellm/incremental_merkle.py +++ b/openverifiablellm/incremental_merkle.py @@ -70,11 +70,11 @@ import hashlib from typing import Dict, Optional - # --------------------------------------------------------------------------- # Module-level helpers # --------------------------------------------------------------------------- + def _sha256_bytes(data: bytes) -> bytes: """Return the raw 32-byte SHA-256 digest of *data*.""" return 
hashlib.sha256(data).digest() @@ -104,6 +104,7 @@ def _combine(left: bytes, right: bytes) -> bytes: # IncrementalMerkleTree # --------------------------------------------------------------------------- + class IncrementalMerkleTree: """ An append-only Merkle tree with O(log N) time and space per operation. @@ -262,6 +263,16 @@ def get_root_hash(self) -> Optional[str]: # Convenience / introspection # ------------------------------------------------------------------ + @property + def frontier_size(self) -> int: + """Number of nodes currently stored in the frontier. + + Equals ``bin(leaf_count).count('1')`` — one node per set bit in the + binary representation of the leaf count. This is the canonical public + way to observe frontier occupancy without accessing ``_frontier`` directly. + """ + return len(self._frontier) + @property def leaf_count(self) -> int: """Total number of leaves that have been appended.""" diff --git a/openverifiablellm/utils.py b/openverifiablellm/utils.py index 5a4b52f..bf41c47 100644 --- a/openverifiablellm/utils.py +++ b/openverifiablellm/utils.py @@ -228,18 +228,26 @@ def extract_text_from_xml( def _generator() -> Generator[str, None, None]: with open_func(input_path, "rb") as f: context = ET.iterparse(f, events=("end",)) - for _, elem in context: - if elem.tag.endswith("page"): - text_elem = elem.find(".//{*}text") - raw_text: str = "" - if text_elem is not None and text_elem.text: - raw_text = text_elem.text - elem.clear() - if not raw_text: - continue - cleaned = clean_wikitext(raw_text) - if cleaned: - yield cleaned + try: + for _, elem in context: + if elem.tag.endswith("page"): + text_elem = elem.find(".//{*}text") + raw_text: str = "" + if text_elem is not None and text_elem.text: + raw_text = text_elem.text + elem.clear() + if not raw_text: + continue + cleaned = clean_wikitext(raw_text) + if cleaned: + yield cleaned + finally: + # Release iterparse internal state even if the caller + # abandons the generator mid-stream or an 
exception occurs. + try: + context.close() + except AttributeError: + pass logger.info("Finished streaming articles from '%s'.", input_path.name) return _generator() diff --git a/tests/test_merkle.py b/tests/test_merkle.py index 34064a6..f18a126 100644 --- a/tests/test_merkle.py +++ b/tests/test_merkle.py @@ -174,8 +174,8 @@ def test_root_hash_matches_batch_power_of_two(self) -> None: for t in texts: tree.append_leaf(t) - # For exactly 2^k leaves, the frontier should have exactly 1 node - assert len(tree._frontier) == 1, ( + # For exactly 2^k leaves the frontier collapses to a single node + assert tree.frontier_size == 1, ( "For a power-of-two leaf count the frontier should collapse to 1 node" ) assert tree.get_root_hash() == expected @@ -226,7 +226,7 @@ def test_leaf_count_zero_when_empty(self) -> None: def test_frontier_empty_when_empty(self) -> None: tree = IncrementalMerkleTree() - assert tree._frontier == {} + assert tree.frontier_size == 0 # =========================================================================== @@ -266,9 +266,9 @@ def test_frontier_size_equals_popcount(self, n: int) -> None: tree.append_leaf(f"x_{i}") expected_frontier_nodes = bin(n).count("1") - assert len(tree._frontier) == expected_frontier_nodes, ( + assert tree.frontier_size == expected_frontier_nodes, ( f"After {n} leaves (binary: {bin(n)}), frontier should have " - f"{expected_frontier_nodes} node(s), got {len(tree._frontier)}" + f"{expected_frontier_nodes} node(s), got {tree.frontier_size}" )