"""
benchmark.py
============
Before-vs-After resource consumption analysis.

Compares the **legacy approach** (load everything into memory, hash at once)
against the **new streaming approach** (O(1) memory generator + Incremental
Merkle Tree) on the same Wikipedia dump.

Metrics captured
----------------
* Wall-clock execution time (seconds) via ``time.perf_counter``.
* Peak heap allocation (MB) via ``tracemalloc``.

Output
------
Prints a GitHub-flavored Markdown table to stdout so the result can be
copy-pasted directly into a Pull Request description.

Usage
-----
    # Download first (≈350 MB):
    # wget https://dumps.wikimedia.org/simplewiki/20260201/simplewiki-20260201-pages-articles-multistream.xml.bz2

    python -m openverifiablellm.benchmark simplewiki-20260201-pages-articles-multistream.xml.bz2

    # Or via the scripts/ helper:
    python scripts/benchmark.py
"""

import argparse
import bz2
import gc
import hashlib
import logging
import sys
import time
import tracemalloc
from pathlib import Path
from typing import List, Optional, Tuple

import defusedxml.ElementTree as ET

from openverifiablellm.incremental_merkle import IncrementalMerkleTree
from openverifiablellm.utils import clean_wikitext, extract_text_from_xml

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Type alias for a benchmark result row
# ---------------------------------------------------------------------------
BenchmarkResult = Tuple[
    str,            # approach label
    float,          # wall-clock seconds
    float,          # peak RAM in MB
    Optional[str],  # root hash (hex), or None when no articles were found
    int,            # article count
]


# ---------------------------------------------------------------------------
# Helper: convert tracemalloc peak bytes → MB
# ---------------------------------------------------------------------------
def _bytes_to_mb(n_bytes: int) -> float:
    """Convert a byte count to megabytes (MiB)."""
    return n_bytes / (1024 * 1024)


# ===========================================================================
# APPROACH 1 — "Old Way" (in-memory)
# ===========================================================================


def _run_old_way(file_path: Path) -> BenchmarkResult:
    """
    Legacy approach: decompress the entire dump, collect ALL article texts
    into a Python list, then build a standard batch Merkle tree from the list.

    Memory profile: O(N) — every article text lives in RAM simultaneously.
    Time profile  : O(N) for loading + O(N log N) for tree construction.

    Parameters
    ----------
    file_path:
        Path to the Wikipedia XML dump (plain or bz2-compressed).

    Returns
    -------
    BenchmarkResult
        (label, elapsed seconds, peak MB, root hex or None, article count).
    """
    gc.collect()
    tracemalloc.start()
    t_start = time.perf_counter()

    # ----- Step 1: load all texts into memory -----
    all_texts: List[str] = []

    # Detect compression by inspecting the bz2 magic bytes (same logic as
    # extract_text_from_xml) so plain .xml files are also handled correctly.
    with open(file_path, "rb") as _probe:
        _is_bz2 = _probe.read(3) == b"BZh"
    _open_func = bz2.open if _is_bz2 else open

    with _open_func(file_path, "rb") as raw:
        context = ET.iterparse(raw, events=("end",))
        for _event, elem in context:
            if elem.tag.endswith("page"):
                # Namespaced lookup first, plain-tag fallback second.
                text_elem = elem.find(".//{*}text")
                if text_elem is None:
                    text_elem = elem.find(".//text")
                if text_elem is not None and text_elem.text:
                    cleaned = clean_wikitext(text_elem.text)
                    if cleaned:
                        all_texts.append(cleaned)
                # NOTE: No elem.clear() — intentionally simulating the
                # old code that leaks every parsed element into memory.

    # ----- Step 2: build Merkle tree from the in-memory list -----
    article_count = len(all_texts)

    # Hash each article text to a leaf
    leaves: List[bytes] = [hashlib.sha256(t.encode("utf-8")).digest() for t in all_texts]

    # Batch construction: classic bottom-up Merkle tree
    if not leaves:
        # Surface zero-article runs explicitly rather than producing a
        # spurious matching root hash.
        root_hex: Optional[str] = None
    else:
        current_level = leaves
        while len(current_level) > 1:
            next_level: List[bytes] = []
            for i in range(0, len(current_level), 2):
                left = current_level[i]
                # Odd node at the right edge is paired with itself.
                right = current_level[i + 1] if i + 1 < len(current_level) else left
                next_level.append(hashlib.sha256(left + right).digest())
            current_level = next_level
        root_hex = current_level[0].hex()

    t_end = time.perf_counter()
    _, peak_bytes = tracemalloc.get_traced_memory()
    tracemalloc.stop()

    return (
        "Old Way (in-memory)",
        round(t_end - t_start, 3),
        round(_bytes_to_mb(peak_bytes), 2),
        root_hex,
        article_count,
    )


# ===========================================================================
# APPROACH 2 — "New Way" (streaming)
# ===========================================================================


def _run_new_way(file_path: Path) -> BenchmarkResult:
    """
    New streaming approach: yield one article at a time from the generator
    and feed it directly into the IncrementalMerkleTree.

    Memory profile: O(log N) — only the Merkle frontier is kept in RAM.
    Time profile  : O(N log N) — but with vastly lower constant factors
    because no large list allocation occurs.

    Parameters
    ----------
    file_path:
        Path to the Wikipedia XML dump (plain or bz2-compressed).

    Returns
    -------
    BenchmarkResult
        (label, elapsed seconds, peak MB, root hex or None, article count).

    Raises
    ------
    RuntimeError
        If ``extract_text_from_xml`` violates its streaming contract and
        returns ``None`` despite ``stream=True``.
    """
    gc.collect()
    tracemalloc.start()
    t_start = time.perf_counter()

    tree = IncrementalMerkleTree()
    article_count = 0

    stream = extract_text_from_xml(file_path, stream=True)
    # FIX: was a bare `assert`, which is silently stripped under `python -O`.
    # An explicit check keeps the contract enforced in all run modes.
    if stream is None:
        raise RuntimeError(
            "extract_text_from_xml must return a generator when stream=True"
        )
    for article_text in stream:
        tree.append_leaf(article_text)
        article_count += 1

    # Surface zero-article runs as None rather than a spurious sha256(b"") hash.
    root_hex: Optional[str] = tree.get_root_hash() if article_count > 0 else None

    t_end = time.perf_counter()
    _, peak_bytes = tracemalloc.get_traced_memory()
    tracemalloc.stop()

    return (
        "New Way (streaming)",
        round(t_end - t_start, 3),
        round(_bytes_to_mb(peak_bytes), 2),
        root_hex,
        article_count,
    )


# ===========================================================================
# Reporting: GitHub-Flavored Markdown table
# ===========================================================================


def _render_markdown_table(
    old: BenchmarkResult,
    new: BenchmarkResult,
    file_name: str,
    trials: int,
) -> str:
    """
    Render a GFM Markdown table suitable for direct use in a GitHub PR.

    Calculates speed-up and RAM reduction ratios and appends a legend.
    Results are aggregated medians across multiple alternating-order trials.
    """
    label_old, time_old, ram_old, hash_old, count_old = old
    label_new, time_new, ram_new, hash_new, count_new = new

    # Guard against division by zero on extremely fast runs
    time_ratio = (time_old / time_new) if time_new > 0 else float("inf")
    ram_ratio = (ram_old / ram_new) if ram_new > 0 else float("inf")

    hash_old_str = hash_old if hash_old is not None else "N/A (0 articles)"
    hash_new_str = hash_new if hash_new is not None else "N/A (0 articles)"
    hashes_match = (hash_old == hash_new) and hash_old is not None
    hash_verdict = (
        "YES — identical root hash"
        if hashes_match
        else (
            "NO — MISMATCH (investigate!)"
            if hash_old is not None
            else "N/A — no articles processed"
        )
    )

    lines = [
        "",
        "## Benchmark Results",
        "",
        f"> **Input file:** `{file_name}` ",
        f"> **Trials:** {trials} (alternating order, median reported) ",
        f"> **Articles processed:** old={count_old}, new={count_new} ",
        f"> **Root hashes match:** {hash_verdict}",
        "",
        "| Metric | Old Way (in-memory) | New Way (streaming) | Improvement |",
        "|-------------------------------|--------------------:|--------------------:|--------------------|",
        f"| Wall-clock time (s) | `{time_old:>10.3f}` | `{time_new:>10.3f}` | **{time_ratio:,.1f}× faster** |",
        f"| Peak RAM usage (MB) | `{ram_old:>10.2f}` | `{ram_new:>10.2f}` | **{ram_ratio:,.1f}× less RAM** |",
        f"| Root hash | `{hash_old_str[:16]}…` | `{hash_new_str[:16]}…` | {'Match' if hashes_match else 'MISMATCH'} |",
        "",
        "### Notes",
        "- *Peak RAM* is measured with `tracemalloc` (Python heap only; does not include",
        "  OS-level buffers or the bzip2 decompressor's internal state).",
        f"- *Wall-clock time* is the median of {trials} isolated subprocess trials with",
        "  alternating run order to minimise OS pagecache and warm-up bias.",
        "- The Old Way intentionally omits `elem.clear()` to reproduce the OOM behaviour.",
        "- The New Way uses `extract_text_from_xml(..., stream=True)` + `IncrementalMerkleTree` from this PR.",
        "",
    ]
    return "\n".join(lines)


def _render_terminal_table(
    old: BenchmarkResult,
    new: BenchmarkResult,
    file_name: str,
    trials: int,
) -> str:
    """Plain-text box table for terminal output (complements the Markdown table)."""
    label_old, time_old, ram_old, hash_old, count_old = old
    label_new, time_new, ram_new, hash_new, count_new = new
    time_ratio = (time_old / time_new) if time_new > 0 else float("inf")
    ram_ratio = (ram_old / ram_new) if ram_new > 0 else float("inf")
    hashes_match = (hash_old == hash_new) and hash_old is not None

    w = 90
    sep = "─" * w

    def row(col1: str, col2: str, col3: str, col4: str = "") -> str:
        return f"│ {col1:<28} │ {col2:>18} │ {col3:>18} │ {col4:<14} │"

    lines = [
        f"┌{sep}┐",
        f"│{'BEFORE vs. AFTER — ' + file_name + f' ({trials} trials, median)':^{w}}│",
        f"├{sep}┤",
        row("Metric", "Old Way", "New Way", "Improvement"),
        f"├{sep}┤",
        row(
            "Wall-clock time (s)",
            f"{time_old:.3f} s",
            f"{time_new:.3f} s",
            f"{time_ratio:,.1f}x faster",
        ),
        row("Peak RAM (MB)", f"{ram_old:.2f} MB", f"{ram_new:.2f} MB", f"{ram_ratio:,.1f}x less"),
        row("Articles processed", str(count_old), str(count_new), ""),
        row("Root hashes match", "", "", "YES" if hashes_match else "NO — MISMATCH"),
        f"└{sep}┘",
    ]
    return "\n".join(lines)


# ===========================================================================
# Subprocess entry point (called by each isolated trial)
# ===========================================================================


def _run_benchmark_mode(mode: str, file_path: str) -> None:
    """
    Single-mode entry point invoked inside an isolated subprocess per trial.

    Prints a JSON line to stdout with keys: label, time, ram, root, articles.
    The parent process parses these lines to aggregate trial results.

    Raises
    ------
    ValueError
        If *mode* is neither ``"old"`` nor ``"new"``.
    """
    import json

    path = Path(file_path)
    if mode == "old":
        result = _run_old_way(path)
    elif mode == "new":
        result = _run_new_way(path)
    else:
        raise ValueError(f"Unknown mode: {mode!r}")

    label, elapsed, ram, root, articles = result
    print(
        json.dumps(
            {
                "label": label,
                "time": elapsed,
                "ram": ram,
                "root": root,
                "articles": articles,
            }
        )
    )


# ===========================================================================
# Main entry point
# ===========================================================================


def run_benchmark(file_path: str, trials: int = 3) -> None:
    """
    Execute both benchmarks across multiple isolated trials with alternating
    order to minimise OS pagecache and allocator warm-up bias.

    Each trial spawns a fresh Python subprocess so memory and file-cache state
    are fully isolated between measurements. The order of old-vs-new is
    reversed on odd-numbered trials. Median time and peak-RAM across all
    trials are reported.

    Parameters
    ----------
    file_path:
        Path to the Wikipedia ``.xml.bz2`` dump.
    trials:
        Number of measurement trials (default 3; must be odd for a clean median).
    """
    import json
    import statistics
    import subprocess

    path = Path(file_path)
    if not path.exists():
        print(f"[ERROR] File not found: {path}", file=sys.stderr)
        sys.exit(1)

    old_times: List[float] = []
    old_rams: List[float] = []
    new_times: List[float] = []
    new_rams: List[float] = []
    old_root: Optional[str] = None
    new_root: Optional[str] = None
    old_articles = 0
    new_articles = 0

    print(f"\nRunning {trials}-trial benchmark on: {path.name}")
    print("  Each trial runs in an isolated subprocess to avoid pagecache bias.\n")

    for i in range(trials):
        # Alternate order: even trials run old→new, odd trials run new→old
        order = ["old", "new"] if i % 2 == 0 else ["new", "old"]

        for mode in order:
            label = "OLD WAY" if mode == "old" else "NEW WAY"
            print(f"  Trial {i + 1}/{trials} — {label} …", end=" ", flush=True)

            proc = subprocess.run(
                [
                    sys.executable,
                    "-m",
                    "openverifiablellm.benchmark",
                    "--_mode",
                    mode,
                    file_path,
                ],
                capture_output=True,
                text=True,
            )
            if proc.returncode != 0:
                print(f"\n[ERROR] Subprocess failed (mode={mode}):\n{proc.stderr}", file=sys.stderr)
                sys.exit(1)

            # Extract the last non-empty line that parses as valid JSON.
            # This tolerates any stray log/warning lines on stdout that may
            # appear before or after the single JSON payload line.
            _stdout_lines = proc.stdout.splitlines()
            data = None
            for _line in reversed(_stdout_lines):
                _line = _line.strip()
                if _line:
                    try:
                        data = json.loads(_line)
                        break
                    except json.JSONDecodeError:
                        continue
            if data is None:
                print(
                    f"\n[ERROR] Could not find valid JSON in subprocess output "
                    f"(mode={mode}):\n{proc.stdout}",
                    file=sys.stderr,
                )
                sys.exit(1)
            print(f"time={data['time']:.3f}s ram={data['ram']:.2f}MB")

            if mode == "old":
                old_times.append(data["time"])
                old_rams.append(data["ram"])
                old_root = data["root"]
                old_articles = data["articles"]
            else:
                new_times.append(data["time"])
                new_rams.append(data["ram"])
                new_root = data["root"]
                new_articles = data["articles"]

    # Abort if either run found zero articles — spurious matching roots.
    if old_articles == 0 or new_articles == 0:
        print(
            f"\n[ERROR] Zero articles processed "
            f"(old={old_articles}, new={new_articles}). "
            "Cannot produce meaningful benchmark results.",
            file=sys.stderr,
        )
        sys.exit(1)

    old_result: BenchmarkResult = (
        "Old Way (in-memory)",
        round(statistics.median(old_times), 3),
        round(statistics.median(old_rams), 2),
        old_root,
        old_articles,
    )
    new_result: BenchmarkResult = (
        "New Way (streaming)",
        round(statistics.median(new_times), 3),
        round(statistics.median(new_rams), 2),
        new_root,
        new_articles,
    )

    # Print terminal table
    print()
    print(_render_terminal_table(old_result, new_result, path.name, trials))

    # Print GitHub-Flavored Markdown table
    md = _render_markdown_table(old_result, new_result, path.name, trials)
    print("\n" + "=" * 60)
    print("Copy the block below into your GitHub Pull Request:")
    print("=" * 60)
    print(md)


def main(argv: Optional[List[str]] = None) -> None:
    """Parse CLI arguments and dispatch to the benchmark driver or trial mode."""
    parser = argparse.ArgumentParser(
        description=(
            "Before-vs-After benchmark: in-memory vs streaming Merkle tree "
            "on a Wikipedia XML.bz2 dump."
        )
    )
    parser.add_argument(
        "file_path",
        help="Path to the Wikipedia XML.bz2 dump file (e.g. simplewiki-20260201-....xml.bz2)",
    )
    parser.add_argument(
        "--trials",
        type=int,
        default=3,
        help="Number of isolated subprocess trials (default: 3)",
    )
    # Internal flag used by the subprocess runner — not intended for direct use.
    parser.add_argument("--_mode", choices=["old", "new"], help=argparse.SUPPRESS)
    args = parser.parse_args(argv)

    if args._mode:
        # Running as a subprocess trial — just measure and print JSON.
        _run_benchmark_mode(args._mode, args.file_path)
    else:
        run_benchmark(args.file_path, trials=args.trials)


if __name__ == "__main__":
    main()
"""
incremental_merkle.py
=====================
An append-only, O(log N) space Merkle Tree using the "Merkle Frontier".

Instead of retaining every leaf hash the way a classical batch Merkle
builder does (O(N) memory), this module keeps only the *frontier*: the
rightmost unpaired subtree root at each depth. After N appends the
frontier holds exactly one node per set bit in the binary representation
of N — at most ⌈log₂ N⌉ hashes.

Appending a leaf behaves like incrementing a binary counter: the new
leaf hash "carries" upward, merging with any stored node at each depth,
until it lands in an empty slot.

The root is obtained by folding the frontier from the lowest depth to
the highest. A partial right-edge subtree is promoted to the next depth
by combining it with itself — the same "duplicate the odd node" rule the
classic batch builder applies — and is then merged as the RIGHT child of
the larger complete subtree stored on its left.

Usage
-----
    from openverifiablellm.incremental_merkle import IncrementalMerkleTree

    tree = IncrementalMerkleTree()
    for text in stream_text_from_xml("dump.xml.bz2"):
        tree.append_leaf(text)

    root = tree.get_root_hash()
    print(f"Merkle root: {root}")
"""

import hashlib
from typing import Dict, Optional

# ---------------------------------------------------------------------------
# Module-level helpers
# ---------------------------------------------------------------------------


def _sha256_bytes(data: bytes) -> bytes:
    """Hash *data* with SHA-256 and return the raw 32-byte digest."""
    return hashlib.sha256(data).digest()


def _combine(left: bytes, right: bytes) -> bytes:
    """
    Derive a parent node hash from two child digests.

    Uses the canonical MediaWiki / Bitcoin-style rule
    ``parent = SHA256(left || right)``.

    Parameters
    ----------
    left, right:
        Raw 32-byte digests (NOT hex strings).

    Returns
    -------
    bytes
        The 32-byte parent digest.
    """
    return _sha256_bytes(left + right)


# ---------------------------------------------------------------------------
# IncrementalMerkleTree
# ---------------------------------------------------------------------------


class IncrementalMerkleTree:
    """
    An append-only Merkle tree with O(log N) time and space per operation.

    State
    -----
    ``_frontier`` maps depth → 32-byte hash (depth 0 is the leaf level).
    Each stored node is the root of a *complete* subtree covering ``2**d``
    leaves. ``_leaf_count`` tracks how many leaves were appended.

    Invariant
    ---------
    After N appends the frontier holds exactly one node per set bit of N
    (e.g. N = 6 = 0b110 → nodes at depths 1 and 2).
    """

    def __init__(self) -> None:
        # depth → 32-byte subtree-root hash; only complete subtrees live here.
        self._frontier: Dict[int, bytes] = {}
        self._leaf_count: int = 0

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def append_leaf(self, text_chunk: str) -> None:
        """
        Hash *text_chunk* (UTF-8) and insert it as the next leaf.

        Works like a binary-counter increment: starting at depth 0, the
        new hash absorbs any node already parked at its depth (stored
        node on the left, new node on the right), carrying the combined
        hash one level up, until an empty depth slot is found.

        Parameters
        ----------
        text_chunk:
            Arbitrary Unicode string; encoded to UTF-8 before hashing.
        """
        node = _sha256_bytes(text_chunk.encode("utf-8"))

        level = 0
        while level in self._frontier:
            # Existing subtree is the LEFT sibling; carry the parent upward.
            node = _combine(self._frontier.pop(level), node)
            level += 1

        # First empty slot — park the (possibly merged) node here.
        self._frontier[level] = node
        self._leaf_count += 1

    def get_root_hash(self) -> Optional[str]:
        """
        Return the current Merkle root as a hex string (non-destructive).

        Folds the frontier nodes in ascending depth order. Before each
        merge the accumulator — the right-edge partial subtree — is
        self-combined once per missing depth, mirroring the batch
        builder's odd-node duplication, then merged as the RIGHT child
        of the complete subtree stored at the target depth.

        Edge cases
        ----------
        * Zero leaves      → ``None``.
        * Single leaf      → SHA-256 of that leaf's text.
        * Power-of-two N   → the single frontier node, returned directly.

        Returns
        -------
        str or None
            64-character lowercase hex digest, or ``None`` if empty.
        """
        if not self._frontier:
            return None

        # Ascending: rightmost partial subtree first, leftmost complete last.
        depths = sorted(self._frontier)

        acc = self._frontier[depths[0]]
        height = depths[0]

        for d in depths[1:]:
            # Promote the accumulator up to depth d via self-combination
            # (the batch tree's "duplicate the odd node" rule).
            while height < d:
                acc = _combine(acc, acc)
                height += 1

            # The complete subtree at depth d sits on the LEFT.
            acc = _combine(self._frontier[d], acc)
            height += 1  # the merge itself ascends one level

        return acc.hex()

    # ------------------------------------------------------------------
    # Convenience / introspection
    # ------------------------------------------------------------------

    @property
    def frontier_size(self) -> int:
        """Number of nodes currently stored in the frontier.

        Equals ``bin(leaf_count).count('1')`` — one node per set bit of
        the leaf count. Use this instead of touching ``_frontier``.
        """
        return len(self._frontier)

    @property
    def leaf_count(self) -> int:
        """Total number of leaves appended so far."""
        return self._leaf_count

    @property
    def frontier_depth(self) -> int:
        """Maximum depth currently present in the frontier (0 if empty)."""
        return max(self._frontier, default=0)

    def __repr__(self) -> str:
        return (
            f"IncrementalMerkleTree("
            f"leaves={self._leaf_count}, "
            f"frontier_nodes={len(self._frontier)}, "
            f"max_depth={self.frontier_depth}"
            f")"
        )
+ """ + if (file_path is None) == (chunks is None): + raise ValueError("Exactly one of 'file_path' or 'chunks' must be provided.") + if chunk_size <= 0: raise ValueError("chunk_size must be a positive integer") - path = Path(file_path) - leaves = [] + def _iter_chunks() -> Iterator[bytes]: + if chunks is not None: + yield from chunks + else: + path = Path(file_path) # type: ignore[arg-type] + with path.open("rb") as f: + while chunk := f.read(chunk_size): + yield chunk + + + leaves: List[bytes] = [] + for chunk in _iter_chunks(): + leaves.append(bytes.fromhex(_sha256_hex(chunk))) with path.open("rb") as f: while chunk := f.read(chunk_size): @@ -71,16 +104,19 @@ def compute_merkle_root( leaf_bytes = compute_sha256_bytes(data=chunk) leaves.append(leaf_bytes) + if not leaves: - return compute_sha256(data=b"") + return _sha256_hex(b"") while len(leaves) > 1: - next_level = [] + next_level: List[bytes] = [] for i in range(0, len(leaves), 2): left = leaves[i] right = leaves[i + 1] if i + 1 < len(leaves) else left - combined = left + right + + next_level.append(bytes.fromhex(_sha256_hex(combined))) + parent_bytes = compute_sha256_bytes(data=combined) next_level.append(parent_bytes) @@ -90,26 +126,46 @@ def compute_merkle_root( def generate_merkle_proof( - file_path: Union[str, Path], chunk_index: int, chunk_size: int = MERKLE_CHUNK_SIZE_BYTES -): + file_path: Union[str, Path, None] = None, + chunk_index: int = 0, + chunk_size: int = MERKLE_CHUNK_SIZE_BYTES, + *, + chunks: Optional[Iterable[bytes]] = None, +) -> List[Tuple[str, bool]]: """ Generate Merkle proof for a specific chunk index. + Supports two modes: + - File mode (default): reads *file_path* in *chunk_size* chunks. + - Streaming mode: pass ``chunks=`` to consume any byte iterable. 
+ Returns: List of tuples (sibling_hash_hex, is_left) """ - path = Path(file_path) + if (file_path is None) == (chunks is None): + raise ValueError("Exactly one of 'file_path' or 'chunks' must be provided.") if chunk_size <= 0: raise ValueError("chunk_size must be a positive integer") - leaves = [] + leaves: List[bytes] = [] + + + if chunks is not None: + for chunk in chunks: + leaves.append(bytes.fromhex(_sha256_hex(chunk))) + else: + path = Path(file_path) # type: ignore[arg-type] + with path.open("rb") as f: + while chunk := f.read(chunk_size): + leaves.append(bytes.fromhex(_sha256_hex(chunk))) # Build leaf level with path.open("rb") as f: while chunk := f.read(chunk_size): leaf_bytes = compute_sha256_bytes(data=chunk) leaves.append(leaf_bytes) + if not leaves: raise ValueError("Cannot generate proof for empty file") @@ -117,11 +173,10 @@ def generate_merkle_proof( if chunk_index < 0 or chunk_index >= len(leaves): raise IndexError("Chunk index out of range") - proof = [] + proof: List[Tuple[str, bool]] = [] index = chunk_index while len(leaves) > 1: - # If odd number of nodes, duplicate last if len(leaves) % 2 == 1: leaves.append(leaves[-1]) @@ -131,12 +186,15 @@ def generate_merkle_proof( is_left = sibling_index < index proof.append((sibling.hex(), is_left)) - # Build next level - next_level = [] + next_level: List[bytes] = [] for i in range(0, len(leaves), 2): combined = leaves[i] + leaves[i + 1] + + next_level.append(bytes.fromhex(_sha256_hex(combined))) + parent_bytes = compute_sha256_bytes(data=combined) next_level.append(parent_bytes) + index //= 2 leaves = next_level @@ -149,7 +207,11 @@ def verify_merkle_proof(chunk_bytes: bytes, proof, merkle_root: str) -> bool: Verify a Merkle proof for given chunk bytes. 
""" try: + + current_hash = bytes.fromhex(_sha256_hex(chunk_bytes)) + current_hash = compute_sha256_bytes(data=chunk_bytes) + expected_root = bytes.fromhex(merkle_root) except (TypeError, ValueError): return False @@ -180,65 +242,104 @@ def verify_merkle_proof(chunk_bytes: bytes, proof, merkle_root: str) -> bool: else: combined = current_hash + sibling + + parent_hex = _sha256_hex(combined) + current_hash = bytes.fromhex(parent_hex) + current_hash = compute_sha256_bytes(data=combined) + return current_hash == expected_root # extract clean wikipage from actual wikipage -def extract_text_from_xml(input_path, *, write_manifest: bool = False): +def extract_text_from_xml( + input_path: Union[str, Path], + stream: bool = False, + *, + write_manifest: bool = False, +) -> Optional[Generator[str, None, None]]: """ Process a Wikipedia XML dump (compressed or uncompressed) into cleaned plain text. - Each element is parsed, its revision text is extracted, - cleaned using `clean_wikitext()`, and appended to a single - output text file. + Supports two modes controlled by the *stream* flag: + + - **Batch mode** (``stream=False``, default): writes all cleaned article + texts to ``data/processed/wiki_clean.txt``. Pass ``write_manifest=True`` + to also generate the dataset manifest. Returns ``None``. - The processed output is saved to: - data/processed/wiki_clean.txt + - **Streaming mode** (``stream=True``): returns a generator that yields + one cleaned plain-text string per Wikipedia article with O(1) memory + usage. No file is written and no manifest is generated. Parameters ---------- input_path : str or Path - Path to the Wikipedia XML dump file. - - Output - ------ - Creates: - data/processed/wiki_clean.txt + Path to the Wikipedia XML dump file (plain or bz2-compressed). + stream : bool + If True, return a generator instead of writing to disk. + write_manifest : bool + If True (batch mode only), generate ``data/dataset_manifest.json`` + after writing the processed file. 
""" input_path = Path(input_path) - # Fixed output path - project_root = Path.cwd() - output_dir = project_root / "data" / "processed" - output_dir.mkdir(parents=True, exist_ok=True) - - output_path = output_dir / "wiki_clean.txt" - - # Auto-detect file type using magic bytes separation with open(input_path, "rb") as test_f: is_bz2 = test_f.read(3) == b"BZh" open_func = bz2.open if is_bz2 else open + if stream: + + def _generator() -> Generator[str, None, None]: + with open_func(input_path, "rb") as f: + context = ET.iterparse(f, events=("end",)) + try: + for _, elem in context: + if elem.tag.endswith("page"): + text_elem = elem.find(".//{*}text") + raw_text: str = "" + if text_elem is not None and text_elem.text: + raw_text = text_elem.text + elem.clear() + if not raw_text: + continue + cleaned = clean_wikitext(raw_text) + if cleaned: + yield cleaned + finally: + # Release iterparse internal state even if the caller + # abandons the generator mid-stream or an exception occurs. + try: + context.close() + except AttributeError: + pass + logger.info("Finished streaming articles from '%s'.", input_path.name) + + return _generator() + + # Batch mode — write to file + project_root = Path.cwd() + output_dir = project_root / "data" / "processed" + output_dir.mkdir(parents=True, exist_ok=True) + output_path = output_dir / "wiki_clean.txt" + with open_func(input_path, "rb") as f: context = ET.iterparse(f, events=("end",)) - with open(output_path, "w", encoding="utf-8") as out: for _, elem in context: if elem.tag.endswith("page"): text_elem = elem.find(".//{*}text") - if text_elem is not None and text_elem.text: cleaned = clean_wikitext(text_elem.text) if cleaned: out.write(cleaned + "\n\n") - elem.clear() + logger.info("Preprocessing complete. 
Output saved to %s", output_path) if write_manifest: generate_manifest(input_path, output_path) + return None # generate data manifest @@ -343,36 +444,56 @@ def verify_merkle_proof_from_file( return verify_merkle_proof(chunk_data, proof, expected_root) -# helpers:Update compute_sha256() to support bytes input directly. +# helpers: compute_sha256() supports bytes input directly and optional streaming. def compute_sha256( *, data: Optional[Union[bytes, bytearray]] = None, file_path: Optional[Union[str, Path]] = None, -) -> str: + stream: bool = False, +) -> Union[str, Generator[Tuple[bytes, str], None, None]]: """ - Compute SHA256 hash of a file OR raw bytes. - - This is used for both raw and processed files to ensure integrity. - This provides a deterministic fingerprint of the dataset, - enabling reproducibility and verification. - - Exactly one of `data` or `file_path` must be provided. + Compute SHA256 hash of a file OR raw bytes, with optional streaming support. + + Modes + ----- + - **data** mode: hash raw bytes in memory, return hex string. + - **file_path** mode (``stream=False``, default): hash the entire file, + return hex string. + - **file_path** mode (``stream=True``): return a generator that yields + ``(chunk_bytes, running_hex)`` pairs. The final ``running_hex`` equals + the SHA-256 of the whole file — same value as ``stream=False``. + + Exactly one of ``data`` or ``file_path`` must be provided. + ``stream=True`` is only valid with ``file_path``. 
""" - if (data is None) == (file_path is None): raise ValueError("Exactly one of 'data' or 'file_path' must be provided.") - sha256 = hashlib.sha256() + if stream and data is not None: + raise ValueError("stream=True is only valid with file_path, not data.") if data is not None: + sha256 = hashlib.sha256() sha256.update(data) return sha256.hexdigest() - path = Path(file_path) + path = Path(file_path) # type: ignore[arg-type] + + if stream: + + def _stream_gen() -> Generator[Tuple[bytes, str], None, None]: + _sha256 = hashlib.sha256() + with path.open("rb") as f: + while _chunk := f.read(8192): + _sha256.update(_chunk) + yield _chunk, _sha256.hexdigest() + + return _stream_gen() + + sha256 = hashlib.sha256() with path.open("rb") as f: while chunk := f.read(8192): sha256.update(chunk) - return sha256.hexdigest() diff --git a/scripts/benchmark.py b/scripts/benchmark.py new file mode 100644 index 0000000..49ef4c1 --- /dev/null +++ b/scripts/benchmark.py @@ -0,0 +1,49 @@ +""" +benchmark.py (scripts/) +======================== +CLI entry-point for the before-vs-after benchmark. + +This thin wrapper delegates all logic to +``openverifiablellm.benchmark.main``, keeping the scripts/ directory as +a collection of plain launchers with no duplicated implementation. + +Usage +----- + # Download the dump first (≈350 MB): + python scripts/download_dump.py --wiki simplewiki --date 20260201 + + # Then run the benchmark: + python scripts/benchmark.py simplewiki-20260201-pages-articles-multistream.xml.bz2 + + # Or, equivalently, via the package module: + python -m openverifiablellm.benchmark + +What it measures +---------------- +* **Old Way** (in-memory): decompress everything, collect all article texts + in a list, build a batch Merkle tree — O(N) RAM. +* **New Way** (streaming): yield one article at a time with + ``stream_text_from_xml``, feed each into an ``IncrementalMerkleTree`` + — O(log N) RAM. + +The script prints: +1. 
A terminal box-table with wall-clock time and peak RAM side-by-side. +2. A GitHub-Flavored Markdown table you can paste straight into the PR. +""" + +import sys +from pathlib import Path + +# --------------------------------------------------------------------------- +# Make sure the project root is on sys.path when the script is run directly +# (i.e. ``python scripts/benchmark.py``), even without an editable install. +# --------------------------------------------------------------------------- +_SCRIPTS_DIR = Path(__file__).resolve().parent +_PROJECT_ROOT = _SCRIPTS_DIR.parent +if str(_PROJECT_ROOT) not in sys.path: + sys.path.insert(0, str(_PROJECT_ROOT)) + +from openverifiablellm.benchmark import main # noqa: E402 + +if __name__ == "__main__": + main() diff --git a/tests/test_merkle.py b/tests/test_merkle.py new file mode 100644 index 0000000..f18a126 --- /dev/null +++ b/tests/test_merkle.py @@ -0,0 +1,460 @@ +""" +test_merkle.py +============== +pytest suite for IncrementalMerkleTree. + +The critical correctness property we verify +-------------------------------------------- +Given the same ordered sequence of N strings, a *batch* Merkle tree +(built by collecting all leaf hashes up-front, then reducing level by +level) and our *incremental* Merkle tree (appending one leaf at a time +via the Merkle Frontier) MUST produce byte-for-byte identical root hashes. + +If that invariant ever breaks, the streaming pipeline cannot be trusted +as a drop-in replacement for the legacy in-memory approach. 
+ +Run with: + pip install -e ".[dev]" + pytest tests/test_merkle.py -v +""" + +import hashlib +import textwrap +from pathlib import Path +from typing import List, Optional + +import pytest + +from openverifiablellm.incremental_merkle import IncrementalMerkleTree +from openverifiablellm.utils import extract_text_from_xml + +# =========================================================================== +# Reference implementation: a classic batch Merkle tree +# =========================================================================== + + +def _sha256_bytes(data: bytes) -> bytes: + """Return raw 32-byte SHA-256 digest.""" + return hashlib.sha256(data).digest() + + +def _combine(left: bytes, right: bytes) -> bytes: + """Combine two 32-byte node hashes: parent = SHA-256(left || right).""" + return _sha256_bytes(left + right) + + +def batch_merkle_root(texts: List[str]) -> Optional[str]: + """ + Build a standard, batch Merkle tree from a list of strings and return + the root hash as a hex string. + + This is the canonical reference implementation used to verify the + incremental version. It stores ALL leaf hashes in memory and reduces + them level-by-level exactly as the legacy code does. + + Odd-number-of-nodes rule: when a level has an odd count, the last node + is duplicated so every parent has exactly two children. This matches + the rule used in ``IncrementalMerkleTree.get_root_hash()``. + + Parameters + ---------- + texts : list of str + Ordered list of article texts (or any strings). + + Returns + ------- + str | None + 64-char hex root hash, or ``None`` if *texts* is empty. 
+ """ + if not texts: + return None + + # Leaf level: hash each string + level: List[bytes] = [_sha256_bytes(t.encode("utf-8")) for t in texts] + + # Reduce level-by-level until only the root remains + while len(level) > 1: + next_level: List[bytes] = [] + for i in range(0, len(level), 2): + left = level[i] + # Duplicate the last node if the level has an odd count + right = level[i + 1] if i + 1 < len(level) else left + next_level.append(_combine(left, right)) + level = next_level + + return level[0].hex() + + +# =========================================================================== +# Fixtures +# =========================================================================== + + +@pytest.fixture +def hundred_strings() -> List[str]: + """ + A deterministic list of 100 unique strings. + + Uses the f-string ``"article_{i:03d}"`` pattern so the content is + predictable and reproducible across test runs. + """ + return [f"article_{i:03d}: The quick brown fox jumps over the lazy dog." for i in range(100)] + + +# =========================================================================== +# Core correctness test (the PRIMARY deliverable) +# =========================================================================== + + +class TestIncrementalVsBatch: + """ + Verify that IncrementalMerkleTree produces the same root as the + batch reference implementation for the same input sequence. + """ + + def test_root_hash_matches_batch_100_strings(self, hundred_strings: List[str]) -> None: + """ + PRIMARY TEST: IncrementalMerkleTree root must exactly equal the + batch Merkle root for the same 100 strings. + + This is the definitive correctness gate for the streaming pipeline. 
+ """ + # Build batch reference root + expected_root = batch_merkle_root(hundred_strings) + assert expected_root is not None, ( + "batch_merkle_root should not return None for non-empty input" + ) + + # Build incremental root + tree = IncrementalMerkleTree() + for text in hundred_strings: + tree.append_leaf(text) + + actual_root = tree.get_root_hash() + assert actual_root is not None, "IncrementalMerkleTree.get_root_hash() must not return None" + + assert actual_root == expected_root, ( + f"Root hash mismatch!\n" + f" Batch root : {expected_root}\n" + f" Incremental root: {actual_root}\n" + "The streaming pipeline is NOT a safe replacement for the " + "legacy in-memory approach until this test passes." + ) + + def test_root_hash_matches_batch_single_string(self) -> None: + """Single leaf: root must equal SHA-256 of that leaf's text.""" + texts = ["only one article here"] + expected = batch_merkle_root(texts) + + tree = IncrementalMerkleTree() + tree.append_leaf(texts[0]) + + assert tree.get_root_hash() == expected + + def test_root_hash_matches_batch_two_strings(self) -> None: + """Two leaves: tests the first combine() call in both implementations.""" + texts = ["alpha", "beta"] + expected = batch_merkle_root(texts) + + tree = IncrementalMerkleTree() + for t in texts: + tree.append_leaf(t) + + assert tree.get_root_hash() == expected + + def test_root_hash_matches_batch_power_of_two(self) -> None: + """ + Power-of-two leaf count (8 leaves): the tree is perfectly balanced; + the frontier should collapse to a single node (the root itself). 
+ """ + texts = [f"leaf_{i}" for i in range(8)] + expected = batch_merkle_root(texts) + + tree = IncrementalMerkleTree() + for t in texts: + tree.append_leaf(t) + + # For exactly 2^k leaves the frontier collapses to a single node + assert tree.frontier_size == 1, ( + "For a power-of-two leaf count the frontier should collapse to 1 node" + ) + assert tree.get_root_hash() == expected + + def test_root_hash_matches_batch_odd_leaf_count(self) -> None: + """Odd leaf count (7): tests the 'duplicate last node' codepath.""" + texts = [f"article_{i}" for i in range(7)] + expected = batch_merkle_root(texts) + + tree = IncrementalMerkleTree() + for t in texts: + tree.append_leaf(t) + + assert tree.get_root_hash() == expected + + @pytest.mark.parametrize( + "n", [1, 2, 3, 4, 5, 6, 7, 8, 15, 16, 17, 31, 32, 33, 64, 100, 128, 255, 256] + ) + def test_root_hash_matches_batch_parametric(self, n: int) -> None: + """ + Parametric sweep over various leaf counts including edge cases, + powers of two, and powers-of-two ± 1. 
+ """ + texts = [f"string_{i:04d}" for i in range(n)] + expected = batch_merkle_root(texts) + + tree = IncrementalMerkleTree() + for t in texts: + tree.append_leaf(t) + + assert tree.get_root_hash() == expected, f"Root hash mismatch for n={n}" + + +# =========================================================================== +# Empty-tree behaviour +# =========================================================================== + + +class TestEmptyTree: + def test_get_root_hash_returns_none_when_empty(self) -> None: + """An empty tree has no root — get_root_hash() must return None.""" + tree = IncrementalMerkleTree() + assert tree.get_root_hash() is None + + def test_leaf_count_zero_when_empty(self) -> None: + tree = IncrementalMerkleTree() + assert tree.leaf_count == 0 + + def test_frontier_empty_when_empty(self) -> None: + tree = IncrementalMerkleTree() + assert tree.frontier_size == 0 + + +# =========================================================================== +# Leaf count tracking +# =========================================================================== + + +class TestLeafCount: + def test_leaf_count_increments(self) -> None: + tree = IncrementalMerkleTree() + for i in range(50): + tree.append_leaf(f"leaf_{i}") + assert tree.leaf_count == i + 1 + + def test_leaf_count_matches_hundred(self, hundred_strings: List[str]) -> None: + tree = IncrementalMerkleTree() + for t in hundred_strings: + tree.append_leaf(t) + assert tree.leaf_count == 100 + + +# =========================================================================== +# Frontier size invariant +# =========================================================================== + + +class TestFrontierInvariant: + """ + The frontier size equals the number of set bits in the binary + representation of the leaf count (popcount / Hamming weight). 
+ """ + + @pytest.mark.parametrize("n", [1, 2, 3, 4, 5, 7, 8, 9, 15, 16, 100, 128, 255, 256]) + def test_frontier_size_equals_popcount(self, n: int) -> None: + tree = IncrementalMerkleTree() + for i in range(n): + tree.append_leaf(f"x_{i}") + + expected_frontier_nodes = bin(n).count("1") + assert tree.frontier_size == expected_frontier_nodes, ( + f"After {n} leaves (binary: {bin(n)}), frontier should have " + f"{expected_frontier_nodes} node(s), got {tree.frontier_size}" + ) + + +# =========================================================================== +# Determinism / reproducibility +# =========================================================================== + + +class TestDeterminism: + def test_same_input_same_root(self, hundred_strings: List[str]) -> None: + """Two trees built from the same input must produce identical roots.""" + tree1 = IncrementalMerkleTree() + tree2 = IncrementalMerkleTree() + for t in hundred_strings: + tree1.append_leaf(t) + tree2.append_leaf(t) + assert tree1.get_root_hash() == tree2.get_root_hash() + + def test_different_order_different_root(self) -> None: + """Order matters: reversed input must produce a different root.""" + texts = [f"item_{i}" for i in range(10)] + + tree_fwd = IncrementalMerkleTree() + tree_rev = IncrementalMerkleTree() + for t in texts: + tree_fwd.append_leaf(t) + for t in reversed(texts): + tree_rev.append_leaf(t) + + assert tree_fwd.get_root_hash() != tree_rev.get_root_hash() + + def test_extra_leaf_changes_root(self) -> None: + """Appending one more leaf must change the root hash.""" + texts = [f"article_{i}" for i in range(10)] + + tree_a = IncrementalMerkleTree() + for t in texts: + tree_a.append_leaf(t) + root_a = tree_a.get_root_hash() + + tree_b = IncrementalMerkleTree() + for t in texts: + tree_b.append_leaf(t) + tree_b.append_leaf("one_more_article") + root_b = tree_b.get_root_hash() + + assert root_a != root_b + + def test_get_root_hash_is_non_destructive(self) -> None: + """Calling 
get_root_hash() multiple times must return the same value.""" + tree = IncrementalMerkleTree() + for i in range(20): + tree.append_leaf(f"leaf_{i}") + + roots = {tree.get_root_hash() for _ in range(5)} + assert len(roots) == 1, "get_root_hash() must be idempotent" + + +# =========================================================================== +# Hash format sanity checks +# =========================================================================== + + +class TestHashFormat: + def test_root_hash_is_64_char_hex(self) -> None: + """SHA-256 produces 32 bytes → 64 lowercase hex characters.""" + tree = IncrementalMerkleTree() + tree.append_leaf("hello world") + root = tree.get_root_hash() + assert root is not None + assert len(root) == 64 + assert root == root.lower() + # Must be valid hex + int(root, 16) + + def test_single_leaf_root_equals_sha256_of_text(self) -> None: + """ + For a single leaf the root hash must equal SHA-256(text.encode()). + There is no combining step — the leaf hash IS the root. + """ + text = "wikipedia article about Python" + expected = hashlib.sha256(text.encode("utf-8")).hexdigest() + + tree = IncrementalMerkleTree() + tree.append_leaf(text) + + assert tree.get_root_hash() == expected + + +# =========================================================================== +# repr() smoke test +# =========================================================================== + + +class TestRepr: + def test_repr_contains_leaf_count(self) -> None: + tree = IncrementalMerkleTree() + for i in range(7): + tree.append_leaf(f"t_{i}") + r = repr(tree) + assert "leaves=7" in r + assert "IncrementalMerkleTree" in r + + +# =========================================================================== +# Integration test: extract_text_from_xml(stream=True) + IncrementalMerkleTree +# =========================================================================== + +# Minimal MediaWiki-style XML dump with 3 articles. 
+_WIKI_XML_FIXTURE = textwrap.dedent("""\
+<mediawiki>
+  <page>
+    <title>Alpha</title>
+    <text>Alpha article text about [[cats]] and {{tmpl}}.</text>
+  </page>
+  <page>
+    <title>Beta</title>
+    <text>Beta article text about &lt;ref&gt;ref&lt;/ref&gt; dogs.</text>
+  </page>
+  <page>
+    <title>Gamma</title>
+    <text>Gamma article text: simple plain text only.</text>
+  </page>
+</mediawiki>
+""")
+
+
+class TestStreamingXmlIntegration:
+    """
+    Exercises extract_text_from_xml(..., stream=True) end-to-end by feeding
+    its output into IncrementalMerkleTree and comparing with the batch
+    reference implementation.
+
+    Uses a tiny in-memory XML fixture so the test is fast and requires no
+    external files.
+    """
+
+    def _write_xml_fixture(self, tmp_path) -> Path:
+        """Write the XML fixture to a temp file and return its path."""
+        p = tmp_path / "wiki_test.xml"
+        p.write_text(_WIKI_XML_FIXTURE, encoding="utf-8")
+        return p
+
+    def test_streaming_root_matches_batch_root(self, tmp_path) -> None:
+        """
+        extract_text_from_xml(stream=True) must yield the same articles as
+        iterating manually, and IncrementalMerkleTree fed from the stream
+        must produce the same root as the batch reference.
+ """ + xml_path = self._write_xml_fixture(tmp_path) + + # Collect streamed texts + gen = extract_text_from_xml(xml_path, stream=True) + assert gen is not None, "extract_text_from_xml must return a generator when stream=True" + streamed_texts = list(gen) + + assert len(streamed_texts) > 0, "Fixture must yield at least one article" + + # Batch reference root + expected_root = batch_merkle_root(streamed_texts) + assert expected_root is not None + + # Incremental root built by re-streaming + gen2 = extract_text_from_xml(xml_path, stream=True) + assert gen2 is not None + tree = IncrementalMerkleTree() + for text in gen2: + tree.append_leaf(text) + + assert tree.get_root_hash() == expected_root, ( + "IncrementalMerkleTree root does not match batch root " + "when fed from extract_text_from_xml(stream=True)" + ) + + def test_streaming_yields_cleaned_text(self, tmp_path) -> None: + """ + Streamed article texts must have wikitext markup removed + (no {{ }}, [[ ]], or tags). + """ + xml_path = self._write_xml_fixture(tmp_path) + gen = extract_text_from_xml(xml_path, stream=True) + assert gen is not None + texts = list(gen) + + for text in texts: + assert "{{" not in text, "Templates should be stripped" + assert "[[" not in text, "Links should be stripped" + assert "" not in text, "Ref tags should be stripped"