diff --git a/code_review_graph/cli.py b/code_review_graph/cli.py index 3fb4f93..88b4cce 100644 --- a/code_review_graph/cli.py +++ b/code_review_graph/cli.py @@ -59,12 +59,12 @@ def _print_banner() -> None: version = _get_version() # ANSI escape codes - c = "\033[36m" if color else "" # cyan — graph art - y = "\033[33m" if color else "" # yellow — center node - b = "\033[1m" if color else "" # bold - d = "\033[2m" if color else "" # dim - g = "\033[32m" if color else "" # green — commands - r = "\033[0m" if color else "" # reset + c = "\033[36m" if color else "" # cyan — graph art + y = "\033[33m" if color else "" # yellow — center node + b = "\033[1m" if color else "" # bold + d = "\033[2m" if color else "" # dim + g = "\033[32m" if color else "" # green — commands + r = "\033[0m" if color else "" # reset print(f""" {c} ●──●──●{r} @@ -151,74 +151,103 @@ def _handle_init(args: argparse.Namespace) -> None: print(" 2. Restart your AI coding tool to pick up the new config") +def _cli_post_process(store: object) -> None: + """Run post-build pipeline and print a summary line for each step.""" + from .postprocessing import run_post_processing + + pp = run_post_processing(store) # type: ignore[arg-type] + if pp.get("signatures_computed"): + print(f"Signatures: {pp['signatures_computed']} nodes") + if pp.get("fts_indexed"): + print(f"FTS indexed: {pp['fts_indexed']} nodes") + if pp.get("flows_detected") is not None: + print(f"Flows: {pp['flows_detected']}") + if pp.get("communities_detected") is not None: + print(f"Communities: {pp['communities_detected']}") + + def main() -> None: """Main CLI entry point.""" ap = argparse.ArgumentParser( prog="code-review-graph", description="Persistent incremental knowledge graph for code reviews", ) - ap.add_argument( - "-v", "--version", action="store_true", help="Show version and exit" - ) + ap.add_argument("-v", "--version", action="store_true", help="Show version and exit") sub = ap.add_subparsers(dest="command") # install (primary) + init (alias) - install_cmd = sub.add_parser( - "install", help="Register MCP server with AI coding platforms" - ) + install_cmd = sub.add_parser("install", help="Register MCP server with AI coding platforms") install_cmd.add_argument("--repo", default=None, help="Repository root (auto-detected)") install_cmd.add_argument( - "--dry-run", action="store_true", + "--dry-run", + action="store_true", help="Show what would be done without writing files", ) install_cmd.add_argument( - "--no-skills", action="store_true", + "--no-skills", + action="store_true", help="Skip generating Claude Code skill files", ) install_cmd.add_argument( - "--no-hooks", action="store_true", + "--no-hooks", + action="store_true", help="Skip installing Claude Code hooks", ) # Legacy flags (kept for backwards compat, now no-ops since all is default) install_cmd.add_argument("--skills", action="store_true", help=argparse.SUPPRESS) install_cmd.add_argument("--hooks", action="store_true", help=argparse.SUPPRESS) - install_cmd.add_argument("--all", action="store_true", dest="install_all", - help=argparse.SUPPRESS) + install_cmd.add_argument( + "--all", action="store_true", dest="install_all", help=argparse.SUPPRESS + ) install_cmd.add_argument( "--platform", choices=[ - "claude", "claude-code", "cursor", "windsurf", "zed", - "continue", "opencode", "antigravity", "all", + "claude", + "claude-code", + "cursor", + "windsurf", + "zed", + "continue", + "opencode", + "antigravity", + "all", ], default="all", help="Target platform for MCP config (default: all detected)", ) - init_cmd = sub.add_parser( - "init", help="Alias for install" - ) + init_cmd = sub.add_parser("init", help="Alias for install") init_cmd.add_argument("--repo", default=None, help="Repository root (auto-detected)") init_cmd.add_argument( - "--dry-run", action="store_true", + "--dry-run", + action="store_true", help="Show what would be done without writing files", ) init_cmd.add_argument( - "--no-skills", action="store_true", + "--no-skills", + action="store_true", help="Skip generating Claude Code skill files", ) init_cmd.add_argument( - "--no-hooks", action="store_true", + "--no-hooks", + action="store_true", help="Skip installing Claude Code hooks", ) init_cmd.add_argument("--skills", action="store_true", help=argparse.SUPPRESS) init_cmd.add_argument("--hooks", action="store_true", help=argparse.SUPPRESS) - init_cmd.add_argument("--all", action="store_true", dest="install_all", - help=argparse.SUPPRESS) + init_cmd.add_argument("--all", action="store_true", dest="install_all", help=argparse.SUPPRESS) init_cmd.add_argument( "--platform", choices=[ - "claude", "claude-code", "cursor", "windsurf", "zed", - "continue", "opencode", "antigravity", "all", + "claude", + "claude-code", + "cursor", + "windsurf", + "zed", + "continue", + "opencode", + "antigravity", + "all", ], default="all", help="Target platform for MCP config (default: all detected)", @@ -228,11 +257,13 @@ def main() -> None: build_cmd = sub.add_parser("build", help="Full graph build (re-parse all files)") build_cmd.add_argument("--repo", default=None, help="Repository root (auto-detected)") build_cmd.add_argument( - "--skip-flows", action="store_true", + "--skip-flows", + action="store_true", help="Skip flow/community detection (signatures + FTS only)", ) build_cmd.add_argument( - "--skip-postprocess", action="store_true", + "--skip-postprocess", + action="store_true", help="Skip all post-processing (raw parse only)", ) @@ -241,11 +272,13 @@ def main() -> None: update_cmd.add_argument("--base", default="HEAD~1", help="Git diff base (default: HEAD~1)") update_cmd.add_argument("--repo", default=None, help="Repository root (auto-detected)") update_cmd.add_argument( - "--skip-flows", action="store_true", + "--skip-flows", + action="store_true", help="Skip flow/community detection (signatures + FTS only)", ) update_cmd.add_argument( - "--skip-postprocess", action="store_true", + "--skip-postprocess", + action="store_true", help="Skip all post-processing (raw parse only)", ) @@ -277,7 +310,8 @@ def main() -> None: help="Rendering mode: auto (default), full, community, or file", ) vis_cmd.add_argument( - "--serve", action="store_true", + "--serve", + action="store_true", help="Start a local HTTP server to view the visualization (localhost:8765)", ) @@ -285,7 +319,8 @@ def main() -> None: wiki_cmd = sub.add_parser("wiki", help="Generate markdown wiki from community structure") wiki_cmd.add_argument("--repo", default=None, help="Repository root (auto-detected)") wiki_cmd.add_argument( - "--force", action="store_true", + "--force", + action="store_true", help="Regenerate all pages even if content unchanged", ) @@ -308,9 +343,10 @@ def main() -> None: # eval eval_cmd = sub.add_parser("eval", help="Run evaluation benchmarks") eval_cmd.add_argument( - "--benchmark", default=None, + "--benchmark", + default=None, help="Comma-separated benchmarks to run (token_efficiency, impact_accuracy, " - "flow_completeness, search_quality, build_performance)", + "flow_completeness, search_quality, build_performance)", ) eval_cmd.add_argument("--repo", default=None, help="Comma-separated repo config names") eval_cmd.add_argument("--all", action="store_true", dest="run_all", help="Run all benchmarks") @@ -319,12 +355,8 @@ def main() -> None: # detect-changes detect_cmd = sub.add_parser("detect-changes", help="Analyze change impact") - detect_cmd.add_argument( - "--base", default="HEAD~1", help="Git diff base (default: HEAD~1)" - ) - detect_cmd.add_argument( - "--brief", action="store_true", help="Show brief summary only" - ) + detect_cmd.add_argument("--base", default="HEAD~1", help="Git diff base (default: HEAD~1)") + detect_cmd.add_argument("--brief", action="store_true", help="Show brief summary only") detect_cmd.add_argument("--repo", default=None, help="Repository root (auto-detected)") # serve @@ -343,6 +375,7 @@ def main() -> None: if args.command == "serve": from .main import main as serve_main + serve_main(repo_root=args.repo) return @@ -351,9 +384,7 @@ def main() -> None: from .eval.runner import run_eval if getattr(args, "report", False): - output_dir = Path( - getattr(args, "output_dir", None) or "evaluate/results" - ) + output_dir = Path(getattr(args, "output_dir", None) or "evaluate/results") report = generate_full_report(output_dir) report_path = Path("evaluate/reports/summary.md") report_path.parent.mkdir(parents=True, exist_ok=True) @@ -365,9 +396,7 @@ def main() -> None: print(tables) else: repos = ( - [r.strip() for r in args.repo.split(",")] - if getattr(args, "repo", None) - else None + [r.strip() for r in args.repo.split(",")] if getattr(args, "repo", None) else None ) benchmarks = ( [b.strip() for b in args.benchmark.split(",")] @@ -439,6 +468,7 @@ def main() -> None: store = GraphStore(db_path) try: from .tools.build import run_postprocess + result = run_postprocess( flows=not getattr(args, "no_flows", False), communities=not getattr(args, "no_communities", False), @@ -475,32 +505,39 @@ def main() -> None: try: if args.command == "build": - pp = "none" if getattr(args, "skip_postprocess", False) else ( - "minimal" if getattr(args, "skip_flows", False) else "full" + pp = ( + "none" + if getattr(args, "skip_postprocess", False) + else ("minimal" if getattr(args, "skip_flows", False) else "full") ) from .tools.build import build_or_update_graph + result = build_or_update_graph( - full_rebuild=True, repo_root=str(repo_root), postprocess=pp, + full_rebuild=True, + repo_root=str(repo_root), + postprocess=pp, ) parsed = result.get("files_parsed", 0) nodes = result.get("total_nodes", 0) edges = result.get("total_edges", 0) - print( - f"Full build: {parsed} files, " - f"{nodes} nodes, {edges} edges" - f" (postprocess={pp})" - ) + print(f"Full build: {parsed} files, {nodes} nodes, {edges} edges (postprocess={pp})") if result.get("errors"): print(f"Errors: {len(result['errors'])}") + _cli_post_process(store) elif args.command == "update": - pp = "none" if getattr(args, "skip_postprocess", False) else ( - "minimal" if getattr(args, "skip_flows", False) else "full" + pp = ( + "none" + if getattr(args, "skip_postprocess", False) + else ("minimal" if getattr(args, "skip_flows", False) else "full") ) from .tools.build import build_or_update_graph + result = build_or_update_graph( - full_rebuild=False, repo_root=str(repo_root), - base=args.base, postprocess=pp, + full_rebuild=False, + repo_root=str(repo_root), + base=args.base, + postprocess=pp, ) updated = result.get("files_updated", 0) nodes = result.get("total_nodes", 0) @@ -510,6 +547,8 @@ def main() -> None: f"{nodes} nodes, {edges} edges" f" (postprocess={pp})" ) + if result.get("files_updated", 0) > 0: + _cli_post_process(store) elif args.command == "status": stats = store.get_stats() @@ -526,6 +565,7 @@ def main() -> None: if stored_sha: print(f"Built at commit: {stored_sha[:12]}") from .incremental import _git_branch_info + current_branch, current_sha = _git_branch_info(repo_root) if stored_branch and current_branch and stored_branch != current_branch: print( @@ -535,10 +575,13 @@ def main() -> None: ) elif args.command == "watch": - watch(repo_root, store) + from .postprocessing import run_post_processing + + watch(repo_root, store, on_files_updated=run_post_processing) elif args.command == "visualize": from .visualization import generate_html + html_path = repo_root / ".code-review-graph" / "graph.html" vis_mode = getattr(args, "mode", "auto") or "auto" generate_html(store, html_path, mode=vis_mode) @@ -565,6 +608,7 @@ def main() -> None: elif args.command == "wiki": from .wiki import generate_wiki + wiki_dir = repo_root / ".code-review-graph" / "wiki" result = generate_wiki(store, wiki_dir, force=args.force) total = result["pages_generated"] + result["pages_updated"] + result["pages_unchanged"] diff --git a/code_review_graph/incremental.py b/code_review_graph/incremental.py index 863211b..1b64e00 100644 --- a/code_review_graph/incremental.py +++ b/code_review_graph/incremental.py @@ -15,7 +15,7 @@ import subprocess import time from pathlib import Path -from typing import Optional +from typing import Callable, Optional from .graph import GraphStore from .parser import CodeParser @@ -144,8 +144,10 @@ def _git_branch_info(repo_root: Path) -> tuple[str, str]: try: result = subprocess.run( ["git", "rev-parse", "--abbrev-ref", "HEAD"], - capture_output=True, text=True, - cwd=str(repo_root), timeout=_GIT_TIMEOUT, + capture_output=True, + text=True, + cwd=str(repo_root), + timeout=_GIT_TIMEOUT, ) if result.returncode == 0: branch = result.stdout.strip() @@ -154,8 +156,10 @@ def _git_branch_info(repo_root: Path) -> tuple[str, str]: try: result = subprocess.run( ["git", "rev-parse", "HEAD"], - capture_output=True, text=True, - cwd=str(repo_root), timeout=_GIT_TIMEOUT, + capture_output=True, + text=True, + cwd=str(repo_root), + timeout=_GIT_TIMEOUT, ) if result.returncode == 0: sha = result.stdout.strip() @@ -163,6 +167,7 @@ def _git_branch_info(repo_root: Path) -> tuple[str, str]: pass return branch, sha + _SAFE_GIT_REF = re.compile(r"^[A-Za-z0-9_.~^/@{}\-]+$") @@ -244,11 +249,7 @@ def collect_all_files(repo_root: Path) -> list[str]: candidates = tracked else: # Fallback: walk directory - candidates = [ - str(p.relative_to(repo_root)) - for p in repo_root.rglob("*") - if p.is_file() - ] + candidates = [str(p.relative_to(repo_root)) for p in repo_root.rglob("*") if p.is_file()] for rel_path in candidates: if _should_ignore(rel_path, ignore_patterns): @@ -545,10 +546,22 @@ def incremental_update( _DEBOUNCE_SECONDS = 0.3 -def watch(repo_root: Path, store: GraphStore) -> None: +def watch( + repo_root: Path, + store: GraphStore, + on_files_updated: Optional[Callable] = None, +) -> None: """Watch for file changes and auto-update the graph. Uses a 300ms debounce to batch rapid-fire saves into a single update. + + Args: + repo_root: Repository root to watch. + store: Graph database to update. + on_files_updated: Optional callback invoked after each debounced + batch of file updates completes. Receives the store as its + only argument. Used by the CLI to run post-processing + (FTS, flows, communities) after watch updates. """ import threading @@ -612,9 +625,7 @@ def _schedule(self, abs_path: str): self._pending.add(abs_path) if self._timer is not None: self._timer.cancel() - self._timer = threading.Timer( - _DEBOUNCE_SECONDS, self._flush - ) + self._timer = threading.Timer(_DEBOUNCE_SECONDS, self._flush) self._timer.start() def _flush(self): @@ -624,33 +635,43 @@ def _flush(self): self._pending.clear() self._timer = None + updated = 0 for abs_path in paths: - self._update_file(abs_path) + if self._update_file(abs_path): + updated += 1 - def _update_file(self, abs_path: str): + if updated > 0 and on_files_updated is not None: + try: + on_files_updated(store) + except Exception as e: + logger.error("Post-update callback failed: %s", e) + + def _update_file(self, abs_path: str) -> bool: path = Path(abs_path) if not path.is_file(): - return + return False if path.is_symlink(): - return + return False if _is_binary(path): - return + return False try: source = path.read_bytes() fhash = hashlib.sha256(source).hexdigest() nodes, edges = parser.parse_bytes(path, source) store.store_file_nodes_edges(abs_path, nodes, edges, fhash) - store.set_metadata( - "last_updated", time.strftime("%Y-%m-%dT%H:%M:%S") - ) + store.set_metadata("last_updated", time.strftime("%Y-%m-%dT%H:%M:%S")) store.commit() rel = str(path.relative_to(repo_root)) logger.info( "Updated: %s (%d nodes, %d edges)", - rel, len(nodes), len(edges), + rel, + len(nodes), + len(edges), ) + return True except Exception as e: logger.error("Error updating %s: %s", abs_path, e) + return False handler = GraphUpdateHandler() observer = Observer() @@ -660,11 +681,10 @@ def _update_file(self, abs_path: str): logger.info("Watching %s for changes... (Ctrl+C to stop)", repo_root) try: import time as _time + while True: _time.sleep(1) except KeyboardInterrupt: observer.stop() observer.join() logger.info("Watch stopped.") - - diff --git a/code_review_graph/postprocessing.py b/code_review_graph/postprocessing.py new file mode 100644 index 0000000..c7dec59 --- /dev/null +++ b/code_review_graph/postprocessing.py @@ -0,0 +1,134 @@ +"""Shared post-build processing pipeline. + +After the core Tree-sitter parse (full_build or incremental_update), four +post-processing steps must run to populate derived tables: + +1. Compute node signatures +2. Rebuild FTS5 search index +3. Trace execution flows +4. Detect code communities + +This module extracts that pipeline so every entry point — MCP tool, CLI +commands, and watch mode — produces identical results. +""" + +from __future__ import annotations + +import logging +import sqlite3 +from typing import Any + +from .graph import GraphStore + +logger = logging.getLogger(__name__) + + +def run_post_processing(store: GraphStore) -> dict[str, Any]: + """Run all post-build steps on a populated graph. + + Each step is non-fatal: failures are logged and collected as warnings + so the primary build result is never lost. + + Args: + store: An open GraphStore with nodes and edges already populated. + + Returns: + Dict with keys for each step's result count and a ``warnings`` + list (only present when at least one step failed). + """ + result: dict[str, Any] = {} + warnings: list[str] = [] + + _compute_signatures(store, result, warnings) + _rebuild_fts_index(store, result, warnings) + _trace_flows(store, result, warnings) + _detect_communities(store, result, warnings) + + if warnings: + result["warnings"] = warnings + return result + + +# -- Individual steps (private) ------------------------------------------ + + +def _compute_signatures( + store: GraphStore, + result: dict[str, Any], + warnings: list[str], +) -> None: + """Compute human-readable signatures for nodes that lack one.""" + try: + rows = store.get_nodes_without_signature() + for row in rows: + node_id, name, kind, params, ret = ( + row[0], + row[1], + row[2], + row[3], + row[4], + ) + if kind in ("Function", "Test"): + sig = f"def {name}({params or ''})" + if ret: + sig += f" -> {ret}" + elif kind == "Class": + sig = f"class {name}" + else: + sig = name + store.update_node_signature(node_id, sig[:512]) + store.commit() + result["signatures_computed"] = len(rows) + except (sqlite3.OperationalError, TypeError, KeyError) as e: + logger.warning("Signature computation failed: %s", e) + warnings.append(f"Signature computation failed: {type(e).__name__}: {e}") + + +def _rebuild_fts_index( + store: GraphStore, + result: dict[str, Any], + warnings: list[str], +) -> None: + """Rebuild the FTS5 full-text search index.""" + try: + from .search import rebuild_fts_index + + fts_count = rebuild_fts_index(store) + result["fts_indexed"] = fts_count + except (sqlite3.OperationalError, ImportError) as e: + logger.warning("FTS index rebuild failed: %s", e) + warnings.append(f"FTS index rebuild failed: {type(e).__name__}: {e}") + + +def _trace_flows( + store: GraphStore, + result: dict[str, Any], + warnings: list[str], +) -> None: + """Trace execution flows from entry points.""" + try: + from .flows import store_flows, trace_flows + + flows = trace_flows(store) + count = store_flows(store, flows) + result["flows_detected"] = count + except (sqlite3.OperationalError, ImportError) as e: + logger.warning("Flow detection failed: %s", e) + warnings.append(f"Flow detection failed: {type(e).__name__}: {e}") + + +def _detect_communities( + store: GraphStore, + result: dict[str, Any], + warnings: list[str], +) -> None: + """Detect code communities via Leiden algorithm or file grouping.""" + try: + from .communities import detect_communities, store_communities + + comms = detect_communities(store) + count = store_communities(store, comms) + result["communities_detected"] = count + except (sqlite3.OperationalError, ImportError) as e: + logger.warning("Community detection failed: %s", e) + warnings.append(f"Community detection failed: {type(e).__name__}: {e}") diff --git a/code_review_graph/tools/build.py b/code_review_graph/tools/build.py index a8a7bd0..e24ae5d 100644 --- a/code_review_graph/tools/build.py +++ b/code_review_graph/tools/build.py @@ -8,10 +8,9 @@ from typing import Any from ..incremental import full_build, incremental_update +from ..postprocessing import run_post_processing from ._common import _get_store -logger = logging.getLogger(__name__) - def _run_postprocess( store: Any, @@ -38,7 +37,11 @@ def _run_postprocess( rows = store.get_nodes_without_signature() for row in rows: node_id, name, kind, params, ret = ( - row[0], row[1], row[2], row[3], row[4], + row[0], + row[1], + row[2], + row[3], + row[4], ) if kind in ("Function", "Test"): sig = f"def {name}({params or ''})" @@ -118,7 +121,8 @@ def _run_postprocess( warnings.append(f"Summary computation failed: {type(e).__name__}: {e}") store.set_metadata( - "last_postprocessed_at", time.strftime("%Y-%m-%dT%H:%M:%S"), + "last_postprocessed_at", + time.strftime("%Y-%m-%dT%H:%M:%S"), ) store.set_metadata("postprocess_level", postprocess) @@ -134,9 +138,7 @@ def _compute_summaries(store: Any) -> None: # -- community_summaries -- try: conn.execute("DELETE FROM community_summaries") - rows = conn.execute( - "SELECT id, name, size, dominant_language FROM communities" - ).fetchall() + rows = conn.execute("SELECT id, name, size, dominant_language FROM communities").fetchall() for r in rows: cid, cname, csize, clang = r[0], r[1], r[2], r[3] # Top 5 symbols by in+out edge count @@ -159,6 +161,7 @@ def _compute_summaries(store: Any) -> None: purpose = "" if paths: from os.path import commonprefix + prefix = commonprefix(paths) if "/" in prefix: purpose = prefix.rsplit("/", 1)[0].split("/")[-1] if "/" in prefix else "" @@ -187,7 +190,8 @@ def _compute_summaries(store: Any) -> None: fcount = r[5] # Get entry point name ep_row = conn.execute( - "SELECT qualified_name FROM nodes WHERE id = ?", (ep_id,), + "SELECT qualified_name FROM nodes WHERE id = ?", + (ep_id,), ).fetchone() ep_name = ep_row[0] if ep_row else str(ep_id) # Compress path to entry + top 3 intermediate + exit @@ -199,7 +203,8 @@ def _compute_summaries(store: Any) -> None: # Pick up to 3 intermediate nodes for nid in path_ids[1:4]: nr = conn.execute( - "SELECT name FROM nodes WHERE id = ?", (nid,), + "SELECT name FROM nodes WHERE id = ?", + (nid,), ).fetchone() if nr: critical_path.append(nr[0]) @@ -214,8 +219,7 @@ def _compute_summaries(store: Any) -> None: "INSERT OR REPLACE INTO flow_snapshots " "(flow_id, name, entry_point, critical_path, criticality, " "node_count, file_count) VALUES (?, ?, ?, ?, ?, ?, ?)", - (fid, fname, ep_name, _json.dumps(critical_path), - crit, ncount, fcount), + (fid, fname, ep_name, _json.dumps(critical_path), crit, ncount, fcount), ) except sqlite3.OperationalError: pass @@ -225,24 +229,32 @@ def _compute_summaries(store: Any) -> None: conn.execute("DELETE FROM risk_index") # Per-node risk: caller_count, test coverage, security keywords nodes = conn.execute( - "SELECT id, qualified_name, name FROM nodes " - "WHERE kind IN ('Function', 'Class', 'Test')" + "SELECT id, qualified_name, name FROM nodes WHERE kind IN ('Function', 'Class', 'Test')" ).fetchall() security_kw = { - "auth", "login", "password", "token", "session", "crypt", - "secret", "credential", "permission", "sql", "execute", + "auth", + "login", + "password", + "token", + "session", + "crypt", + "secret", + "credential", + "permission", + "sql", + "execute", } for n in nodes: nid, qn, name = n[0], n[1], n[2] # Count callers caller_count = conn.execute( - "SELECT COUNT(*) FROM edges WHERE target_qualified = ? " - "AND kind = 'CALLS'", (qn,), + "SELECT COUNT(*) FROM edges WHERE target_qualified = ? AND kind = 'CALLS'", + (qn,), ).fetchone()[0] # Test coverage tested = conn.execute( - "SELECT COUNT(*) FROM edges WHERE source_qualified = ? " - "AND kind = 'TESTED_BY'", (qn,), + "SELECT COUNT(*) FROM edges WHERE source_qualified = ? AND kind = 'TESTED_BY'", + (qn,), ).fetchone()[0] coverage = "tested" if tested > 0 else "untested" # Security relevance @@ -333,8 +345,11 @@ def build_or_update_graph( # Pass changed_files for incremental flow/community detection changed = result.get("changed_files") if not full_rebuild else None warnings = _run_postprocess( - store, build_result, postprocess, - full_rebuild=full_rebuild, changed_files=changed, + store, + build_result, + postprocess, + full_rebuild=full_rebuild, + changed_files=changed, ) if warnings: build_result["warnings"] = warnings @@ -369,12 +384,15 @@ def run_postprocess( warnings: list[str] = [] try: - # Signatures are always fast — run them try: rows = store.get_nodes_without_signature() for row in rows: node_id, name, kind, params, ret = ( - row[0], row[1], row[2], row[3], row[4], + row[0], + row[1], + row[2], + row[3], + row[4], ) if kind in ("Function", "Test"): sig = f"def {name}({params or ''})" @@ -430,7 +448,8 @@ def run_postprocess( warnings.append(f"Community detection failed: {type(e).__name__}: {e}") store.set_metadata( - "last_postprocessed_at", time.strftime("%Y-%m-%dT%H:%M:%S"), + "last_postprocessed_at", + time.strftime("%Y-%m-%dT%H:%M:%S"), ) result["summary"] = "Post-processing complete." if warnings: diff --git a/tests/test_postprocessing.py b/tests/test_postprocessing.py new file mode 100644 index 0000000..f9b0f94 --- /dev/null +++ b/tests/test_postprocessing.py @@ -0,0 +1,327 @@ +"""Tests for the shared post-processing pipeline.""" + +import tempfile +from pathlib import Path +from unittest.mock import MagicMock, patch + +from code_review_graph.graph import GraphStore +from code_review_graph.incremental import full_build +from code_review_graph.parser import EdgeInfo, NodeInfo +from code_review_graph.postprocessing import run_post_processing + + +def _get_signature(store, qualified_name): + row = store._conn.execute( + "SELECT signature FROM nodes WHERE qualified_name = ?", + (qualified_name,), + ).fetchone() + return row["signature"] if row else None + + +class TestRunPostProcessing: + def setup_method(self): + self.tmp = tempfile.NamedTemporaryFile(suffix=".db", delete=False) + self.store = GraphStore(self.tmp.name) + self._seed_data() + + def teardown_method(self): + self.store.close() + Path(self.tmp.name).unlink(missing_ok=True) + + def _seed_data(self): + self.store.upsert_node( + NodeInfo( + kind="File", + name="/repo/app.py", + file_path="/repo/app.py", + line_start=1, + line_end=50, + language="python", + ) + ) + self.store.upsert_node( + NodeInfo( + kind="Class", + name="Service", + file_path="/repo/app.py", + line_start=5, + line_end=40, + language="python", + ) + ) + self.store.upsert_node( + NodeInfo( + kind="Function", + name="handle", + file_path="/repo/app.py", + line_start=10, + line_end=20, + language="python", + parent_name="Service", + params="request", + return_type="Response", + ) + ) + self.store.upsert_node( + NodeInfo( + kind="Function", + name="process", + file_path="/repo/app.py", + line_start=25, + line_end=35, + language="python", + ) + ) + self.store.upsert_node( + NodeInfo( + kind="Test", + name="test_handle", + file_path="/repo/test_app.py", + line_start=1, + line_end=10, + language="python", + is_test=True, + ) + ) + + self.store.upsert_edge( + EdgeInfo( + kind="CONTAINS", + source="/repo/app.py", + target="/repo/app.py::Service", + file_path="/repo/app.py", + ) + ) + self.store.upsert_edge( + EdgeInfo( + kind="CONTAINS", + source="/repo/app.py::Service", + target="/repo/app.py::Service.handle", + file_path="/repo/app.py", + ) + ) + self.store.upsert_edge( + EdgeInfo( + kind="CALLS", + source="/repo/app.py::Service.handle", + target="/repo/app.py::process", + file_path="/repo/app.py", + line=15, + ) + ) + self.store.commit() + + def test_computes_signatures(self): + unsigned = self.store.get_nodes_without_signature() + assert len(unsigned) > 0 + + result = run_post_processing(self.store) + + assert result["signatures_computed"] > 0 + remaining = self.store.get_nodes_without_signature() + assert len(remaining) == 0 + + def test_function_signature_format(self): + run_post_processing(self.store) + + sig = _get_signature(self.store, "/repo/app.py::Service.handle") + assert sig == "def handle(request) -> Response" + + def test_class_signature_format(self): + run_post_processing(self.store) + + sig = _get_signature(self.store, "/repo/app.py::Service") + assert sig == "class Service" + + def test_test_signature_format(self): + run_post_processing(self.store) + + sig = _get_signature(self.store, "/repo/test_app.py::test_handle") + assert sig is not None + assert sig.startswith("def test_handle(") + + def test_rebuilds_fts_index(self): + result = run_post_processing(self.store) + + assert "fts_indexed" in result + assert result["fts_indexed"] > 0 + + def test_fts_search_works_after_post_processing(self): + run_post_processing(self.store) + + from code_review_graph.search import hybrid_search + + hits = hybrid_search(self.store, "handle") + names = {h["name"] for h in hits} + assert "handle" in names + + def test_detects_flows(self): + result = run_post_processing(self.store) + + assert "flows_detected" in result + assert result["flows_detected"] >= 0 + + def test_detects_communities(self): + result = run_post_processing(self.store) + + assert "communities_detected" in result + assert result["communities_detected"] >= 0 + + def test_no_warnings_on_healthy_store(self): + result = run_post_processing(self.store) + + assert "warnings" not in result + + def test_empty_store_no_crash(self): + empty_tmp = tempfile.NamedTemporaryFile(suffix=".db", delete=False) + empty_store = GraphStore(empty_tmp.name) + try: + result = run_post_processing(empty_store) + assert result["signatures_computed"] == 0 + assert result["fts_indexed"] == 0 + finally: + empty_store.close() + Path(empty_tmp.name).unlink(missing_ok=True) + + def test_idempotent(self): + first = run_post_processing(self.store) + second = run_post_processing(self.store) + + assert second["fts_indexed"] == first["fts_indexed"] + assert second["signatures_computed"] == 0 + + def test_signature_truncated_at_512(self): + self.store.upsert_node( + NodeInfo( + kind="Function", + name="f", + file_path="/repo/big.py", + line_start=1, + line_end=2, + language="python", + params="a" * 600, + ) + ) + self.store.commit() + + run_post_processing(self.store) + sig = _get_signature(self.store, "/repo/big.py::f") + assert sig is not None + assert len(sig) <= 512 + + +class TestPostProcessingStepIsolation: + def setup_method(self): + self.tmp = tempfile.NamedTemporaryFile(suffix=".db", delete=False) + self.store = GraphStore(self.tmp.name) + self.store.upsert_node( + NodeInfo( + kind="Function", + name="fn", + file_path="/repo/a.py", + line_start=1, + line_end=5, + language="python", + ) + ) + self.store.commit() + + def teardown_method(self): + self.store.close() + Path(self.tmp.name).unlink(missing_ok=True) + + def test_fts_failure_does_not_block_flows(self): + with patch( + "code_review_graph.search.rebuild_fts_index", + side_effect=ImportError("fts boom"), + ): + result = run_post_processing(self.store) + + assert "flows_detected" in result + assert "communities_detected" in result + assert "warnings" in result + assert any("FTS" in w for w in result["warnings"]) + + def test_flow_failure_does_not_block_communities(self): + with patch( + "code_review_graph.flows.trace_flows", + side_effect=ImportError("flow boom"), + ): + result = run_post_processing(self.store) + + assert "communities_detected" in result + assert "warnings" in result + assert any("Flow" in w for w in result["warnings"]) + + def test_community_failure_still_has_signatures(self): + with patch( + "code_review_graph.communities.detect_communities", + side_effect=ImportError("comm boom"), + ): + result = run_post_processing(self.store) + + assert result["signatures_computed"] > 0 + assert "warnings" in result + assert any("Community" in w for w in result["warnings"]) + + +class TestToolBuildUsesSharedPipeline: + def test_build_tool_runs_post_processing(self, tmp_path): + py_file = tmp_path / "sample.py" + py_file.write_text("def hello():\n pass\n") + (tmp_path / ".git").mkdir() + (tmp_path / ".code-review-graph").mkdir() + + db_path = tmp_path / ".code-review-graph" / "graph.db" + store = GraphStore(db_path) + try: + mock_target = "code_review_graph.incremental.get_all_tracked_files" + with patch(mock_target, return_value=["sample.py"]): + full_build(tmp_path, store) + + unsigned_before_pp = store.get_nodes_without_signature() + run_post_processing(store) + unsigned_after_pp = store.get_nodes_without_signature() + + assert len(unsigned_before_pp) > 0 + assert len(unsigned_after_pp) == 0 + finally: + store.close() + + +class TestWatchCallbackIntegration: + def test_watch_accepts_callback_parameter(self): + import inspect + + from code_review_graph.incremental import watch + + sig = inspect.signature(watch) + assert "on_files_updated" in sig.parameters + + def test_watch_callback_not_called_without_updates(self, tmp_path): + import threading + + from code_review_graph.incremental import watch + + (tmp_path / ".git").mkdir() + db_path = tmp_path / "test.db" + store = GraphStore(db_path) + callback = MagicMock() + + try: + + def run_watch(): + try: + watch(tmp_path, store, on_files_updated=callback) + except KeyboardInterrupt: + pass + + t = threading.Thread(target=run_watch, daemon=True) + t.start() + + import time + + time.sleep(0.5) + callback.assert_not_called() + finally: + store.close()