diff --git a/.env.template b/.env.template index b5047c95..4d63e8ee 100644 --- a/.env.template +++ b/.env.template @@ -2,6 +2,15 @@ FALKORDB_HOST=localhost FALKORDB_PORT=6379 +# Database backend. Use "lite" to run against an embedded FalkorDBLite +# instance instead of an external FalkorDB host/port. +CODE_GRAPH_DB_BACKEND=falkordb +FALKORDB_LITE_PATH=~/.cache/code-graph/falkordblite.rdb +# Optional: expose FalkorDBLite on localhost for host/port-only integrations +# such as GraphRAG chat. Structural CodeGraph/MCP tools do not need this. +# FALKORDB_LITE_HOST=127.0.0.1 +# FALKORDB_LITE_PORT=6379 + # Optional FalkorDB authentication FALKORDB_USERNAME= FALKORDB_PASSWORD= diff --git a/README.md b/README.md index e2e54bb0..a16b668b 100644 --- a/README.md +++ b/README.md @@ -56,7 +56,7 @@ code-graph/ - Python `>=3.12,<3.14` - Node.js 20+ - [`uv`](https://docs.astral.sh/uv/) -- A FalkorDB instance (local or cloud) +- A FalkorDB instance (local/cloud) or the optional FalkorDBLite backend ### 1. Start FalkorDB @@ -68,6 +68,16 @@ code-graph/ docker run -p 6379:6379 -it --rm falkordb/falkordb ``` +**Option C:** Use embedded FalkorDBLite: + +```bash +uv sync --extra light +export CODE_GRAPH_DB_BACKEND=lite +export FALKORDB_LITE_PATH=~/.cache/code-graph/falkordblite.rdb +``` + +FalkorDBLite runs a local embedded server over a private Unix socket by default. Set `FALKORDB_LITE_PORT` only when a host/port-only integration, such as GraphRAG chat, must connect to the embedded database. + ### 2. Configure environment variables Copy the template and adjust it for your setup: @@ -78,10 +88,14 @@ cp .env.template .env | Variable | Description | Required | Default | |----------|-------------|----------|---------| +| `CODE_GRAPH_DB_BACKEND` | Database backend: `falkordb` or `lite` | No | `falkordb` | | `FALKORDB_HOST` | FalkorDB hostname | No | `localhost` | | `FALKORDB_PORT` | FalkorDB port | No | `6379` | | `FALKORDB_USERNAME` | Optional FalkorDB username | No | empty | | `FALKORDB_PASSWORD` | Optional FalkorDB password | No | empty | +| `FALKORDB_LITE_PATH` | FalkorDBLite database file path | No | `~/.cache/code-graph/falkordblite.rdb` | +| `FALKORDB_LITE_HOST` | Host used when exposing FalkorDBLite over TCP | No | `127.0.0.1` | +| `FALKORDB_LITE_PORT` | Optional TCP port for FalkorDBLite host/port clients | No | empty | | `SECRET_TOKEN` | Token checked by protected endpoints | No | empty | | `CODE_GRAPH_PUBLIC` | Set `1` to skip auth on read-only endpoints | No | `0` | | `ALLOWED_ANALYSIS_DIR` | Root path allowed for `/api/analyze_folder` | No | repository root | diff --git a/api/cli.py b/api/cli.py index bf6d9d46..f47d09b8 100644 --- a/api/cli.py +++ b/api/cli.py @@ -62,6 +62,18 @@ def _check_connection(host: str, port: int) -> bool: def ensure_db() -> None: """Ensure FalkorDB is running, auto-starting a Docker container if needed.""" + from .db import create_falkordb, is_lite_backend + + if is_lite_backend(): + try: + db = create_falkordb() + db.connection.ping() + except Exception as e: + _json_error(f"Failed to initialize FalkorDBLite: {e}") + _stderr("FalkorDBLite embedded backend is ready") + _json_out({"status": "ok", "backend": "lite"}) + return + host = os.getenv("FALKORDB_HOST", "localhost") try: port = int(os.getenv("FALKORDB_PORT", "6379")) diff --git a/api/db.py b/api/db.py new file mode 100644 index 00000000..8729b097 --- /dev/null +++ b/api/db.py @@ -0,0 +1,157 @@ +"""Database backend selection for CodeGraph. + +The default backend is a regular FalkorDB server addressed by host/port. +Set ``CODE_GRAPH_DB_BACKEND=lite`` to use FalkorDBLite's embedded server. +""" + +from __future__ import annotations + +import os +from pathlib import Path +from typing import Any + + +LITE_BACKENDS = {"lite", "falkordblite"} +DEFAULT_LITE_PATH = "~/.cache/code-graph/falkordblite.rdb" + + +def is_lite_backend() -> bool: + """Return whether CodeGraph should use the embedded FalkorDBLite backend.""" + backend = os.getenv("CODE_GRAPH_DB_BACKEND", "falkordb").strip().lower() + return backend in LITE_BACKENDS + + +def _lite_db_path() -> str: + path = Path(os.getenv("FALKORDB_LITE_PATH", DEFAULT_LITE_PATH)).expanduser() + path.parent.mkdir(parents=True, exist_ok=True) + return str(path) + + +def _lite_serverconfig() -> dict[str, str]: + """Return FalkorDBLite server config from env. + + FalkorDBLite defaults to a private Unix socket. Supplying + ``FALKORDB_LITE_PORT`` additionally exposes a local TCP port, which is + useful for libraries that only accept host/port connection settings. + """ + port = os.getenv("FALKORDB_LITE_PORT") + if not port: + return {} + return { + "bind": os.getenv("FALKORDB_LITE_HOST", "127.0.0.1"), + "port": port, + } + + +def create_falkordb() -> Any: + """Create a sync FalkorDB client for the configured backend.""" + if is_lite_backend(): + try: + from redislite.falkordb_client import FalkorDB as LiteFalkorDB + except ImportError as e: + raise RuntimeError( + "CODE_GRAPH_DB_BACKEND=lite requires the optional " + "`falkordblite` dependency. Install with " + "`uv sync --extra light` or `pip install 'falkordb-code-graph[light]'`." + ) from e + + return LiteFalkorDB(_lite_db_path(), serverconfig=_lite_serverconfig()) + + from falkordb import FalkorDB + + return FalkorDB( + host=os.getenv("FALKORDB_HOST", "localhost"), + port=os.getenv("FALKORDB_PORT", 6379), + username=os.getenv("FALKORDB_USERNAME", None), + password=os.getenv("FALKORDB_PASSWORD", None), + ) + + +def create_async_falkordb() -> Any: + """Create an async FalkorDB client for the configured backend.""" + if is_lite_backend(): + try: + from redislite.async_falkordb_client import AsyncFalkorDB as LiteAsyncFalkorDB + except ImportError as e: + raise RuntimeError( + "CODE_GRAPH_DB_BACKEND=lite requires the optional " + "`falkordblite` dependency. Install with " + "`uv sync --extra light` or `pip install 'falkordb-code-graph[light]'`." + ) from e + + client = LiteAsyncFalkorDB(_lite_db_path(), serverconfig=_lite_serverconfig()) + if not hasattr(client, "aclose"): + client.aclose = client.close + return client + + from falkordb.asyncio import FalkorDB as AsyncFalkorDB + + return AsyncFalkorDB( + host=os.getenv("FALKORDB_HOST", "localhost"), + port=int(os.getenv("FALKORDB_PORT", 6379)), + username=os.getenv("FALKORDB_USERNAME", None), + password=os.getenv("FALKORDB_PASSWORD", None), + ) + + +def create_redis_connection() -> Any: + """Create a sync Redis-compatible connection for metadata operations.""" + if is_lite_backend(): + return create_falkordb().connection + + import redis + + return redis.Redis( + host=os.getenv("FALKORDB_HOST", "localhost"), + port=int(os.getenv("FALKORDB_PORT", "6379")), + username=os.getenv("FALKORDB_USERNAME"), + password=os.getenv("FALKORDB_PASSWORD"), + decode_responses=True, + ) + + +def create_async_redis_connection() -> Any: + """Create an async Redis-compatible connection for metadata operations.""" + if is_lite_backend(): + return create_async_falkordb().connection + + import redis.asyncio as aioredis + + return aioredis.Redis( + host=os.getenv("FALKORDB_HOST", "localhost"), + port=int(os.getenv("FALKORDB_PORT", "6379")), + username=os.getenv("FALKORDB_USERNAME"), + password=os.getenv("FALKORDB_PASSWORD"), + decode_responses=True, + ) + + +def graphrag_connection_kwargs() -> dict[str, Any]: + """Return host/port kwargs for GraphRAG SDK constructors. + + GraphRAG SDK accepts host/port only. FalkorDBLite's private Unix socket is + therefore usable for structural tools but not for GraphRAG unless a local + TCP port is explicitly enabled with ``FALKORDB_LITE_PORT``. + """ + if is_lite_backend(): + port = os.getenv("FALKORDB_LITE_PORT") + if not port: + raise RuntimeError( + "GraphRAG requires host/port access. When using " + "CODE_GRAPH_DB_BACKEND=lite, set FALKORDB_LITE_PORT to expose " + "the embedded FalkorDBLite instance on localhost." + ) + create_falkordb() + return { + "host": os.getenv("FALKORDB_LITE_HOST", "127.0.0.1"), + "port": int(port), + "username": None, + "password": None, + } + + return { + "host": os.getenv("FALKORDB_HOST", "localhost"), + "port": int(os.getenv("FALKORDB_PORT", 6379)), + "username": os.getenv("FALKORDB_USERNAME", None), + "password": os.getenv("FALKORDB_PASSWORD", None), + } diff --git a/api/git_utils/git_graph.py b/api/git_utils/git_graph.py index 52de8da9..ec41eec0 100644 --- a/api/git_utils/git_graph.py +++ b/api/git_utils/git_graph.py @@ -1,9 +1,8 @@ -import os import logging -from falkordb import FalkorDB, Node -from falkordb.asyncio import FalkorDB as AsyncFalkorDB +from falkordb import Node from typing import List, Optional +from api.db import create_async_falkordb, create_falkordb from pygit2 import Commit # Configure logging @@ -19,10 +18,7 @@ class GitGraph(): def __init__(self, name: str): - self.db = FalkorDB(host=os.getenv('FALKORDB_HOST', 'localhost'), - port=os.getenv('FALKORDB_PORT', 6379), - username=os.getenv('FALKORDB_USERNAME', None), - password=os.getenv('FALKORDB_PASSWORD', None)) + self.db = create_falkordb() self.g = self.db.select_graph(name) @@ -182,12 +178,7 @@ class AsyncGitGraph: """Async read-only git graph for endpoint use.""" def __init__(self, name: str): - self.db = AsyncFalkorDB( - host=os.getenv('FALKORDB_HOST', 'localhost'), - port=int(os.getenv('FALKORDB_PORT', 6379)), - username=os.getenv('FALKORDB_USERNAME', None), - password=os.getenv('FALKORDB_PASSWORD', None), - ) + self.db = create_async_falkordb() self.g = self.db.select_graph(name) def _commit_from_node(self, node: Node) -> dict: @@ -205,4 +196,3 @@ async def list_commits(self) -> List[dict]: async def close(self) -> None: await self.db.aclose() - diff --git a/api/graph.py b/api/graph.py index 7d14f957..ffa208fb 100644 --- a/api/graph.py +++ b/api/graph.py @@ -1,10 +1,9 @@ -import os import re import time -from .entities import * from typing import Optional -from falkordb import FalkorDB, Path, Node, QueryResult -from falkordb.asyncio import FalkorDB as AsyncFalkorDB +from falkordb import Path, Node, QueryResult +from .db import create_async_falkordb, create_falkordb +from .entities import File, encode_edge, encode_node # Configure the logger import logging @@ -62,10 +61,7 @@ def parse_graph_name(graph_name: str) -> Optional[tuple[str, str]]: def graph_exists(name: str): - db = FalkorDB(host=os.getenv('FALKORDB_HOST', 'localhost'), - port=os.getenv('FALKORDB_PORT', 6379), - username=os.getenv('FALKORDB_USERNAME', None), - password=os.getenv('FALKORDB_PASSWORD', None)) + db = create_falkordb() return name in db.list_graphs() @@ -86,10 +82,7 @@ def get_repos() -> list[dict]: single graph until the migration is run. """ - db = FalkorDB(host=os.getenv('FALKORDB_HOST', 'localhost'), - port=os.getenv('FALKORDB_PORT', 6379), - username=os.getenv('FALKORDB_USERNAME', None), - password=os.getenv('FALKORDB_PASSWORD', None)) + db = create_falkordb() repos = [] for g in db.list_graphs(): @@ -140,10 +133,7 @@ def __init__(self, name: str, branch: Optional[str] = None) -> None: self.branch = branch or DEFAULT_BRANCH self.name = compose_graph_name(self.project, self.branch) - self.db = FalkorDB(host=os.getenv('FALKORDB_HOST', 'localhost'), - port=os.getenv('FALKORDB_PORT', 6379), - username=os.getenv('FALKORDB_USERNAME', None), - password=os.getenv('FALKORDB_PASSWORD', None)) + self.db = create_falkordb() self.g = self.db.select_graph(self.name) # Initialize the backlog as disabled by default @@ -180,10 +170,7 @@ def from_raw_name(cls, raw_name: str) -> "Graph": obj.branch = DEFAULT_BRANCH else: obj.project, obj.branch = parsed - obj.db = FalkorDB(host=os.getenv('FALKORDB_HOST', 'localhost'), - port=os.getenv('FALKORDB_PORT', 6379), - username=os.getenv('FALKORDB_USERNAME', None), - password=os.getenv('FALKORDB_PASSWORD', None)) + obj.db = create_falkordb() obj.g = obj.db.select_graph(raw_name) obj.backlog = None return obj @@ -297,7 +284,7 @@ def _query(self, q: str, params: Optional[dict] = None) -> QueryResult: return result_set - def get_sub_graph(self, l: int) -> dict: + def get_sub_graph(self, limit: int) -> dict: q = """MATCH (src) OPTIONAL MATCH (src)-[e]->(dest) @@ -306,7 +293,7 @@ def get_sub_graph(self, l: int) -> dict: sub_graph = {'nodes': [], 'edges': [] } - result_set = self._query(q, {'limit': l}).result_set + result_set = self._query(q, {'limit': limit}).result_set for row in result_set: src = row[0] e = row[1] @@ -592,7 +579,7 @@ def set_file_coverage(self, path: str, name: str, ext: str, coverage: float) -> params = {'path': path, 'name': name, 'ext': ext, 'coverage': coverage} - res = self._query(q, params) + self._query(q, params) def connect_entities(self, relation: str, src_id: int, dest_id: int, properties: dict = {}) -> None: """ @@ -789,14 +776,9 @@ def unreachable_entities(self, lbl: Optional[str], rel: Optional[str]) -> list[d # Async helpers and read-only async graph wrapper # --------------------------------------------------------------------------- -def _async_db() -> AsyncFalkorDB: +def _async_db(): """Create an async FalkorDB connection using environment config.""" - return AsyncFalkorDB( - host=os.getenv('FALKORDB_HOST', 'localhost'), - port=int(os.getenv('FALKORDB_PORT', 6379)), - username=os.getenv('FALKORDB_USERNAME', None), - password=os.getenv('FALKORDB_PASSWORD', None), - ) + return create_async_falkordb() async def async_graph_exists(name: str) -> bool: @@ -952,4 +934,3 @@ async def stats(self) -> dict: async def close(self) -> None: await self.db.aclose() - diff --git a/api/info.py b/api/info.py index f2173fda..75e5ef28 100644 --- a/api/info.py +++ b/api/info.py @@ -1,9 +1,9 @@ -import os import redis import redis.asyncio as aioredis import logging from typing import Optional, Dict +from .db import create_async_redis_connection, create_redis_connection from .graph import DEFAULT_BRANCH # Configure logging @@ -39,13 +39,7 @@ def get_redis_connection() -> redis.Redis: redis.Redis: A Redis connection object. """ try: - return redis.Redis( - host = os.getenv('FALKORDB_HOST', "localhost"), - port = int(os.getenv('FALKORDB_PORT', "6379")), - username = os.getenv('FALKORDB_USERNAME'), - password = os.getenv('FALKORDB_PASSWORD'), - decode_responses = True # To ensure string responses - ) + return create_redis_connection() except Exception as e: logging.error(f"Error connecting to Redis: {e}") raise @@ -160,13 +154,7 @@ def get_repo_info(repo_name: str, branch: Optional[str] = None) -> Optional[Dict # --------------------------------------------------------------------------- async def async_get_redis_connection() -> aioredis.Redis: - return aioredis.Redis( - host=os.getenv('FALKORDB_HOST', "localhost"), - port=int(os.getenv('FALKORDB_PORT', "6379")), - username=os.getenv('FALKORDB_USERNAME'), - password=os.getenv('FALKORDB_PASSWORD'), - decode_responses=True, - ) + return create_async_redis_connection() async def async_get_repo_info(repo_name: str, branch: Optional[str] = None) -> Optional[Dict[str, str]]: @@ -187,4 +175,3 @@ async def async_get_repo_info(repo_name: str, branch: Optional[str] = None) -> O except Exception as e: logging.error(f"Error retrieving repo info for '{repo_name}': {e}") raise - diff --git a/api/llm.py b/api/llm.py index 25bb5d3d..0e64188a 100644 --- a/api/llm.py +++ b/api/llm.py @@ -3,6 +3,7 @@ import logging from typing import Optional +from .db import graphrag_connection_kwargs from graphrag_sdk.models.litellm import LiteModel from graphrag_sdk import ( Ontology, @@ -255,10 +256,7 @@ def _create_kg_agent(repo_name: str, branch: Optional[str] = None): name=graph_name, ontology=ontology, model_config=KnowledgeGraphModelConfig.with_model(model), - host=os.getenv('FALKORDB_HOST', 'localhost'), - port=os.getenv('FALKORDB_PORT', 6379), - username=os.getenv('FALKORDB_USERNAME', None), - password=os.getenv('FALKORDB_PASSWORD', None), + **graphrag_connection_kwargs(), cypher_system_instruction=CYPHER_GEN_SYSTEM, qa_system_instruction=GRAPH_QA_SYSTEM, cypher_gen_prompt=CYPHER_GEN_PROMPT, diff --git a/api/mcp/tools/__init__.py b/api/mcp/tools/__init__.py index 87b8b3a0..2148c100 100644 --- a/api/mcp/tools/__init__.py +++ b/api/mcp/tools/__init__.py @@ -4,4 +4,11 @@ ``api.mcp.server``. Import this package to register all tools. """ +# NOTE: the GraphRAG-backed ``ask`` tool is intentionally NOT part of the MCP +# surface. In benchmarking it failed on every call ("Missing GOOGLE_API_KEY/ +# GEMINI_API_KEY") because the spawned MCP server env carries only FalkorDB +# coordinates, and even when keyed it returned File-level fuzzy matches rather +# than the structural answers the nav workflow needs. Burning an LLM round-trip +# per call for no signal, so we expose only the deterministic structural tools. +# GraphRAG ask remains available on the HTTP /api/chat path. from . import structural # noqa: F401 (registers tools on import) diff --git a/api/mcp/tools/structural.py b/api/mcp/tools/structural.py index 2a7a7ade..757d1cf1 100644 --- a/api/mcp/tools/structural.py +++ b/api/mcp/tools/structural.py @@ -19,8 +19,10 @@ import asyncio import logging +import math import os import re +from collections import Counter, defaultdict from pathlib import Path from typing import Any, Optional @@ -30,6 +32,284 @@ logger = logging.getLogger(__name__) +# --------------------------------------------------------------------------- +# Hybrid ranking (search_code Step 3) — tunable weights +# --------------------------------------------------------------------------- +# Repo-wide file relevance = weighted sum of normalized component scores minus +# a path penalty. Each component is min-max normalized across files so the +# weights are directly comparable. Tune these after the smoke test. +_HYBRID_W_NAME = 2.0 # symbol name == a query identifier (exact) +_HYBRID_W_PATH = 1.5 # query tokens present in the file's path +_HYBRID_W_BM25 = 1.5 # BM25 over file body (symbol names + docstrings + path) +_HYBRID_W_CENT = 0.5 # log(1 + cross-file in-degree): structural centrality +_HYBRID_W_PEN = 1.0 # penalty for test/legacy/vendored/etc. paths +_HYBRID_BODY_TOKEN_CAP = 4000 +_HYBRID_MIN_IDENT_LEN = 4 +_HYBRID_BM25_K1 = 1.5 +_HYBRID_BM25_B = 0.75 + +_WORD_RE = re.compile(r"[A-Za-z_][A-Za-z0-9_]*") +_CAMEL_RE = re.compile(r"[A-Z]+(?=[A-Z][a-z])|[A-Z]?[a-z0-9]+|[A-Z]+") +_PENALTY_RE = re.compile( + r"(^|/)(tests?|testing|conftest|migrations?|legacy|vendor|vendored|" + r"third_party|__pycache__|examples?|docs?|benchmarks?)(/|$)|_test\.py$|(^|/)test_", + re.IGNORECASE, +) + + +def _subtokens(ident: str) -> list[str]: + """Split an identifier into snake/camel sub-tokens (lowercased) plus itself.""" + out: list[str] = [] + for part in ident.split("_"): + if not part: + continue + out.extend(m.group(0).lower() for m in _CAMEL_RE.finditer(part)) + out.append(ident.lower()) + return [t for t in out if t] + + +def _tokenize(text: str) -> list[str]: + toks: list[str] = [] + for m in _WORD_RE.finditer(text or ""): + toks.extend(_subtokens(m.group(0))) + return toks + + +def _issue_identifiers(text: str) -> set[str]: + """Candidate symbol names mentioned in the query (backticked or code-shaped).""" + cands: set[str] = set() + for m in re.finditer(r"`([^`]+)`", text or ""): + for ident in _WORD_RE.findall(m.group(1)): + if len(ident) >= _HYBRID_MIN_IDENT_LEN: + cands.add(ident.lower()) + for ident in _WORD_RE.findall(text or ""): + if len(ident) >= _HYBRID_MIN_IDENT_LEN and ( + "_" in ident or re.search(r"[a-z][A-Z]", ident) or ident[0].isupper() + ): + cands.add(ident.lower()) + return cands + + +def _minmax(d: dict[str, float]) -> dict[str, float]: + if not d: + return {} + vals = list(d.values()) + lo, hi = min(vals), max(vals) + if hi - lo < 1e-12: + return {k: 0.0 for k in d} + return {k: (v - lo) / (hi - lo) for k, v in d.items()} + + +def _rep_key(d: dict[str, Any]) -> tuple: + """Stable ordering key for a file's representative-symbol candidate. + + Lower is better: exact query-id match first, then lowest ``src_start``, + then ``name`` then ``src_end`` as deterministic tie-breakers so the chosen + representative never depends on FalkorDB row order (even when ``src_start`` + ties or is missing). + """ + return ( + 0 if d.get("exact") else 1, + d["src_start"] if d.get("src_start") is not None else math.inf, + d.get("name") or "", + d["src_end"] if d.get("src_end") is not None else math.inf, + ) + + +def _bm25(query_tokens: set[str], files: list[str], + tokmap: dict[str, list[str]]) -> dict[str, float]: + docs = [tokmap.get(f, []) for f in files] + n = len(docs) + if n == 0: + return {} + df: Counter = Counter() + for d in docs: + for t in set(d): + df[t] += 1 + avgdl = (sum(len(d) for d in docs) / n) or 1.0 + idf = {t: math.log(1 + (n - k + 0.5) / (k + 0.5)) for t, k in df.items()} + k1, b = _HYBRID_BM25_K1, _HYBRID_BM25_B + out: dict[str, float] = {} + for f, d in zip(files, docs): + if not d: + out[f] = 0.0 + continue + tf = Counter(d) + dl = len(d) + s = 0.0 + for t in query_tokens: + freq = tf.get(t) + if not freq: + continue + s += idf.get(t, 0.0) * (freq * (k1 + 1)) / ( + freq + k1 * (1 - b + b * dl / avgdl) + ) + out[f] = s + return out + + +async def _hybrid_rank(g, query: str, project: Optional[str]) -> list[dict[str, Any]]: + """Rank every indexed file by relevance to a free-text ``query``. + + Runs three aggregate Cypher reads (files, symbols, cross-file edge degree), + scores files with the weighted hybrid above, and returns an ordered list of + ``{abs_path, file, score, name, src_start, src_end}`` (best representative + symbol per file, used for the snippet). Pure read; no graph mutation. + """ + files, comps, rep, abs_of, file_id_of = await _hybrid_components(g, query, project) + scored = _hybrid_score(files, comps, rep, abs_of, file_id_of) + # Relevance floor: a query with no lexical overlap (e.g. a nonsense token) + # would otherwise be ranked purely by query-independent centrality and + # return noise. Drop files with no lexical signal so such queries yield []. + return [r for r in scored if r.get("lex", 0.0) > 0] + + +async def _hybrid_components(g, query: str, project: Optional[str]): + """Fetch graph data and build per-file, weight-independent components. + + Returns ``(files, comps, rep, abs_of, file_id_of)`` where ``comps[f]`` holds + the min-max-normalized ``name``/``path``/``bm25``/``cent`` scores plus the + ``pen`` penalty, and ``file_id_of[f]`` is the File node id (handle the agent + feeds to ``get_file_neighbors``). Separated from weighting so weight sweeps + reuse the exact same normalized inputs as the live ranker. + """ + def rel(p: Optional[str]) -> str: + return _relativize(p, project) if p else "" + + qtok = set(_tokenize(query)) + qids = _issue_identifiers(query) + + pathtok: dict[str, list[str]] = {} + bodytok: dict[str, list[str]] = defaultdict(list) + abs_of: dict[str, str] = {} + file_id_of: dict[str, Any] = {} + name_exact: dict[str, float] = defaultdict(float) + rep: dict[str, dict[str, Any]] = {} + + files_res = await g._query("MATCH (f:File) RETURN f.path, ID(f)") + for row in files_res.result_set: + ap = row[0] + if not ap: + continue + rp = rel(ap) + abs_of[rp] = ap + file_id_of[rp] = row[1] + pt = _tokenize(rp.replace("/", " ")) + pathtok[rp] = pt + bodytok[rp].extend(pt) + + sym_res = await g._query( + "MATCH (n) WHERE n:Function OR n:Class " + "RETURN n.name, n.path, n.doc, n.src_start, n.src_end" + ) + body_used: Counter = Counter() + for name, path, doc, start, end in sym_res.result_set: + if not path: + continue + rp = rel(path) + # Rank only over real ``File`` nodes; symbols whose containing file was + # not emitted as a File node would otherwise inflate the BM25 corpus and + # skew min-max normalization. + if rp not in abs_of: + continue + if name: + bodytok[rp].extend(_subtokens(name)) + is_exact = name.lower() in qids + if is_exact: + name_exact[rp] += 1.0 + # Representative symbol for the file's snippet: prefer one whose name + # exactly matches a query identifier, otherwise the lowest-``src_start`` + # symbol. Fully deterministic regardless of result-set order via a + # stable sort key (exact first, then src_start, then name, then + # src_end) so ties / missing ``src_start`` never depend on row order. + cand = {"name": name, "src_start": start, "src_end": end, + "exact": is_exact} + cur = rep.get(rp) + if cur is None or _rep_key(cand) < _rep_key(cur): + rep[rp] = cand + if doc and body_used[rp] < _HYBRID_BODY_TOKEN_CAP: + toks = _tokenize(doc)[: _HYBRID_BODY_TOKEN_CAP - body_used[rp]] + bodytok[rp].extend(toks) + body_used[rp] += len(toks) + + deg_res = await g._query( + "MATCH (a)-[:CALLS|IMPORTS|EXTENDS|OVERRIDES]->(b) " + "WHERE a.path IS NOT NULL AND b.path IS NOT NULL AND a.path <> b.path " + "RETURN b.path AS p, count(*) AS deg" + ) + centrality: dict[str, float] = defaultdict(float) + for bpath, deg in deg_res.result_set: + centrality[rel(bpath)] += math.log1p(int(deg or 0)) + + files = sorted(abs_of) + if not files: + return [], {}, {}, {}, {} + + path_overlap = {f: float(len(qtok & set(pathtok.get(f, [])))) for f in files} + raw_bm25 = _bm25(qtok, files, bodytok) + n_name = _minmax(name_exact if name_exact else {f: 0.0 for f in files}) + n_path = _minmax(path_overlap) + n_bm25 = _minmax(raw_bm25) + n_cent = _minmax({f: centrality.get(f, 0.0) for f in files}) + + comps: dict[str, dict[str, float]] = {} + for f in files: + comps[f] = { + "name": n_name.get(f, 0.0), + "path": n_path.get(f, 0.0), + "bm25": n_bm25.get(f, 0.0), + "cent": n_cent.get(f, 0.0), + "pen": _HYBRID_W_PEN if _PENALTY_RE.search(f) else 0.0, + # Raw (un-normalized) query-dependent signal. A file with zero + # lexical overlap (name/path/body) is not relevant to the query — + # only query-independent centrality could rank it — so search_code + # drops it rather than returning noise for an unmatched query. + "lex": (name_exact.get(f, 0.0) + + path_overlap.get(f, 0.0) + + raw_bm25.get(f, 0.0)), + } + + return files, comps, rep, abs_of, file_id_of + + +def _hybrid_score( + files: list[str], + comps: dict[str, dict[str, float]], + rep: dict[str, dict[str, Any]], + abs_of: dict[str, str], + file_id_of: Optional[dict[str, Any]] = None, +) -> list[dict[str, Any]]: + """Apply the hybrid weights to pre-normalized per-file components. + + Split out from ``_hybrid_rank`` so the weighting can be exercised + independently of the (expensive) graph reads and normalization. + """ + file_id_of = file_id_of or {} + scored: list[dict[str, Any]] = [] + for f in files: + c = comps[f] + score = ( + _HYBRID_W_NAME * c["name"] + + _HYBRID_W_PATH * c["path"] + + _HYBRID_W_BM25 * c["bm25"] + + _HYBRID_W_CENT * c["cent"] + - c["pen"] + ) + r = rep.get(f, {}) + scored.append({ + "abs_path": abs_of[f], + "file": f, + "file_id": file_id_of.get(f), + "score": round(score, 4), + "name": r.get("name"), + "src_start": r.get("src_start"), + "src_end": r.get("src_end"), + "lex": c.get("lex", 0.0), + }) + scored.sort(key=lambda d: -d["score"]) + return scored + + # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- @@ -223,13 +503,119 @@ def _project_arg(project: str, branch: Optional[str]): return AsyncGraphQuery(project, branch=branch) -def _node_summary(n: Any) -> dict[str, Any]: +def _relativize(path: Optional[str], rel_to: Optional[str]) -> Optional[str]: + """Strip the indexing-root prefix so file paths are repo-relative. + + Stored paths are absolute and include the worktree root whose final + segment equals the ``project`` identifier + (e.g. ``///pkg/mod.py``). We strip up to and including the + FIRST ``//`` occurrence so the result is the repo-relative path. + ``find`` (first match) is deliberate: a nested directory may legitimately + repeat the project name and must be preserved in the relative path. + + Absolute worktree prefixes are ~150 chars/row of pure noise that the agent + re-reads on every cached turn; relativizing is the cheapest token win. + """ + if not path or not rel_to: + return path + marker = f"/{rel_to}/" + idx = path.find(marker) + if idx == -1: + return path + return path[idx + len(marker):] + + +def _raw_name(n: Any) -> Optional[str]: + """Extract the ``name`` property from a Node or already-encoded dict.""" + if hasattr(n, "properties"): + return (n.properties or {}).get("name") + return (dict(n).get("properties") or {}).get("name") + + +# Snippet windows (lines) attached to results so the agent can judge relevance +# without a follow-up full-file ``view`` — the dominant token cost. Neighbor +# tools return small, high-relevance result sets, so they carry a real window; +# search_code can return many name matches, so it carries only a signature. +NEIGHBOR_SNIPPET_LINES = 12 +SEARCH_SNIPPET_LINES = 2 +_SNIPPET_LINE_CHARS = 200 +"""Per-line char cap so a single minified line cannot blow up the payload.""" + + +def _read_snippet( + abs_path: Optional[str], + src_start: Any, + src_end: Any, + max_lines: int, +) -> Optional[str]: + """Read up to ``max_lines`` leading source lines for an entity from disk. + + Paths and line numbers are stored on nodes but the source text is not, so + we read the file lazily. Returns the entity's leading lines (signature plus + the start of the body) joined by newlines, or ``None`` when the file/line + info is missing or unreadable. Each line is capped at + ``_SNIPPET_LINE_CHARS``. Carrying a snippet in the result lets a single + tool call replace a follow-up ``view`` of a large file. + """ + if not abs_path or max_lines <= 0: + return None + try: + start = int(src_start) + except (TypeError, ValueError): + return None + if start < 1: + return None + try: + end = int(src_end) + except (TypeError, ValueError): + end = start + # Read a slightly larger window than ``max_lines`` (bounded by the entity's + # own extent) so leading blank / decorator-only lines can be skipped without + # starving the substantive output. + hard_end = end if end >= start else start + read_until = min(hard_end, start + max_lines + 5) + raw_lines: list[str] = [] + try: + with open(abs_path, "r", encoding="utf-8", errors="replace") as fh: + for i, raw in enumerate(fh, start=1): + if i < start: + continue + if i > read_until: + break + raw_lines.append(raw.rstrip("\n")) + except OSError: + return None + # Skip leading blank / decorator-only lines so the limited window lands on + # the substantive signature + body rather than being wasted on a blank line + # or a bare ``@decorator``. + while raw_lines and ( + not raw_lines[0].strip() or raw_lines[0].lstrip().startswith("@") + ): + raw_lines.pop(0) + lines = [ln[:_SNIPPET_LINE_CHARS] for ln in raw_lines[:max_lines]] + if not lines: + return None + return "\n".join(lines) + + +def _node_summary( + n: Any, + rel_to: Optional[str] = None, + snippet_lines: int = 0, + with_label: bool = True, +) -> dict[str, Any]: """Normalize a FalkorDB Node (or already-encoded dict) to a flat payload. ``encode_node`` returns ``{id, labels, properties: {...}}`` because Node - properties live on a nested attribute. Agents want a flat record, and - they also want a single ``label`` (the meaningful one — File, Class, - Function — not the fulltext-index marker ``Searchable``). + properties live on a nested attribute. Agents want a flat record. We keep + the single meaningful ``label`` (File, Class, Function — not the fulltext + marker ``Searchable``) for ``search_code`` (File-vs-Function disambiguation) + and for ``find_path`` (a path is bare nodes with no per-hop relation, so the + label is the only type signal). Single-hop neighbor results omit it via + ``with_label`` since the relation already implies the node type. + + When ``rel_to`` is given (the project/worktree identifier), ``file`` is + relativized to drop the absolute worktree prefix. """ if hasattr(n, "properties"): props = dict(n.properties or {}) @@ -242,13 +628,24 @@ def _node_summary(n: Any) -> dict[str, Any]: node_id = d.get("id") label = next((lbl for lbl in labels if lbl != "Searchable"), None) - return { + summary: dict[str, Any] = { "id": node_id, "name": props.get("name"), - "label": label, - "file": props.get("path"), - "line": props.get("src_start"), } + if with_label: + summary["label"] = label + summary["file"] = _relativize(props.get("path"), rel_to) + summary["line"] = props.get("src_start") + if snippet_lines > 0: + snip = _read_snippet( + props.get("path"), + props.get("src_start"), + props.get("src_end"), + snippet_lines, + ) + if snip: + summary["snippet"] = snip + return summary # Relationship-type names are graph labels (SCREAMING_SNAKE_CASE, e.g. CALLS, @@ -324,7 +721,12 @@ async def _neighbors_payload( res = await g._query(q, {"sid": node_id, "limit": int(limit)}) out: list[dict[str, Any]] = [] for row in res.result_set: - entry = _node_summary(row[0]) + entry = _node_summary( + row[0], + rel_to=project, + snippet_lines=NEIGHBOR_SNIPPET_LINES, + with_label=False, + ) entry["relation"] = row[1] entry["direction"] = direction out.append(entry) @@ -333,29 +735,183 @@ async def _neighbors_payload( await g.close() +FIND_SYMBOL_SNIPPET_LINES = 4 +FIND_SYMBOL_DB_LIMIT = 200 + + +def _clamp_find_symbol_limit(limit: Any) -> int: + """Coerce ``limit`` to ``1..FIND_SYMBOL_DB_LIMIT``. + + Agents may hand back a stringified or out-of-range ``limit``; a negative + value would otherwise flip ``rows[:limit]`` into a surprising tail slice and + a non-integer would fail deep in the slice with an opaque error. Strings of + digits are accepted; anything else is rejected up front. + """ + if isinstance(limit, bool): + raise ValueError(f"limit must be an integer, got bool: {limit!r}") + if isinstance(limit, str) and limit.lstrip("-").isdigit(): + limit = int(limit) + if not isinstance(limit, int): + raise ValueError(f"limit must be an integer, got: {limit!r}") + if limit < 1: + return 1 + if limit > FIND_SYMBOL_DB_LIMIT: + return FIND_SYMBOL_DB_LIMIT + return limit + + @app.tool( - name="get_callers", + name="find_symbol", description=( - "Return functions that call the given symbol (incoming CALLS edges). " - "`symbol_id` is the integer node id returned by `search_code` or " - "other tools." + "Resolve a Function/Class NAME to its integer symbol node id — the " + "bridge from a human-readable name to the `symbol_id` that " + "`get_neighbors`, `impact_analysis` and `find_path` require (those tools " + "take an id, NOT a name; `search_code` returns FILES, not symbol ids, so " + "start here for relationship questions). Pass the simple name " + "(e.g. `normalize_cartesian_coordinates`); a dotted qualname " + "(`Grid.normalize_cartesian_coordinates`) is accepted and its last " + "segment is used. Optionally pass `file` (repo-relative path or a " + "substring of it) to disambiguate same-named symbols — matches in that " + "file are flagged `file_match: true` and listed first. Returns " + "[{symbol_id, name, label, file, line, file_match, snippet}] ordered " + "best-scoped first; when more than one remains, disambiguate by `file` " + "and `snippet`. Feed the chosen `symbol_id` to the relationship tools." ), ) -async def get_callers( - symbol_id: int | str, +async def find_symbol( + name: str, project: str, + file: Optional[str] = None, branch: Optional[str] = None, - limit: int = 50, + limit: int = 20, ) -> list[dict[str, Any]]: - return await _neighbors_payload(project, branch, symbol_id, "CALLS", "IN", limit) + """Look up Function/Class nodes by their simple name. + + Symbol nodes store only the SIMPLE name (``foo``), never the qualname + (``Bar.foo``), so a dotted input is reduced to its last segment. Every + candidate carries a ``file_match`` flag (always ``False`` when no ``file`` + filter is requested); when ``file`` is given we do not silently widen the + search — the in-file ones sort first, so a wrong/empty file filter degrades + visibly (the agent still sees the global matches but knows none were in the + requested file) rather than routing a relationship query to an arbitrary + same-named symbol. + """ + simple = str(name).strip().split(".")[-1].strip() + if not simple: + raise ValueError( + f"name must be a non-empty symbol name, got: {name!r}" + ) + eff_limit = _clamp_find_symbol_limit(limit) + g = _project_arg(project, branch) + try: + res = await g._query( + "MATCH (n) WHERE (n:Function OR n:Class) AND n.name = $name " + "RETURN n LIMIT $limit", + {"name": simple, "limit": FIND_SYMBOL_DB_LIMIT}, + ) + rows = [ + _node_summary( + r[0], + rel_to=project, + snippet_lines=FIND_SYMBOL_SNIPPET_LINES, + ) + for r in res.result_set + ] + finally: + await g.close() + + for r in rows: + r["symbol_id"] = r.pop("id") + + needle = str(file).strip().lstrip("/") if file is not None else "" + + def _matches(r: dict[str, Any]) -> bool: + fp = r.get("file") or "" + return bool(needle) and (fp == needle or fp.endswith(needle) or needle in fp) + + # ``file_match`` is part of the documented response shape, so set it on + # every row (False when no ``file`` filter is requested) rather than only + # when ``file`` is given. Sort deterministically — in-file matches first, + # then by file/line/name/id — so ordering never depends on FalkorDB row + # order between runs. + for r in rows: + r["file_match"] = _matches(r) + rows.sort(key=lambda r: ( + not r["file_match"], + r.get("file") or "", + r["line"] if r.get("line") is not None else math.inf, + r.get("name") or "", + r["symbol_id"], + )) + + return rows[:eff_limit] @app.tool( - name="get_callees", + name="get_neighbors", description=( - "Return functions that the given symbol calls (outgoing CALLS edges)." + "Adjacent symbols of a SYMBOL node (Function/Class) via graph edges, by " + "its integer node id. Pick relation+direction: " + "`CALLS`+`IN`=callers (upstream co-change sites); `CALLS`+`OUT`=callees; " + "`IMPORTS`+`IN`=importers of a File; `OVERRIDES`+`BOTH`=polymorphic " + "dispatch that plain CALLS miss; `[IMPORTS,CALLS,DEFINES]`+`OUT`=" + "dependencies. `relation` is one name or a list; `direction` is IN, OUT, " + "or BOTH. Each result carries a code `snippet`, so you rarely need to " + "`view` the file. To navigate from a `search_code` hit (a FILE), use " + "`get_file_neighbors` with its `file_id` instead — this tool expects a " + "symbol id." ), ) +async def get_neighbors( + symbol_id: Any, + project: str, + relation: Any = "CALLS", + direction: str = "OUT", + branch: Optional[str] = None, + limit: int = 50, +) -> list[dict[str, Any]]: + """Unified single-hop neighbor traversal (replaces the per-edge tools). + + ``relation`` accepts a single edge type or a list of them; results are + aggregated across relations and deduped. ``direction`` is ``IN`` + (incoming edges), ``OUT`` (outgoing), or ``BOTH`` (union of IN+OUT). + """ + rels = [relation] if isinstance(relation, str) else list(relation) + direction = (direction or "OUT").upper() + if direction == "BOTH": + dirs = ["OUT", "IN"] + elif direction in ("IN", "OUT"): + dirs = [direction] + else: + raise ValueError( + f"direction must be 'IN', 'OUT', or 'BOTH', got: {direction!r}" + ) + + seen: set[Any] = set() + out: list[dict[str, Any]] = [] + for d in dirs: + for rel in rels: + rows = await _neighbors_payload(project, branch, symbol_id, rel, d, limit) + for row in rows: + key = (row.get("id"), row.get("relation"), row.get("direction")) + if key in seen: + continue + seen.add(key) + out.append(row) + if len(out) >= limit: + return out + return out + + +async def get_callers( + symbol_id: int | str, + project: str, + branch: Optional[str] = None, + limit: int = 50, +) -> list[dict[str, Any]]: + return await _neighbors_payload(project, branch, symbol_id, "CALLS", "IN", limit) + + async def get_callees( symbol_id: int | str, project: str, @@ -365,14 +921,6 @@ async def get_callees( return await _neighbors_payload(project, branch, symbol_id, "CALLS", "OUT", limit) -@app.tool( - name="get_dependencies", - description=( - "Return outgoing neighbors of the given symbol across any of the " - "specified relation types (default: IMPORTS, CALLS, DEFINES). " - "Useful for 'what does this depend on' queries." - ), -) async def get_dependencies( symbol_id: int | str, project: str, @@ -405,6 +953,55 @@ async def get_dependencies( return out +# --------------------------------------------------------------------------- +# Spike 1a — get_importers (incoming IMPORTS) / get_overrides (OVERRIDES) +# --------------------------------------------------------------------------- + + +async def get_importers( + symbol_id: Any, + project: str, + branch: Optional[str] = None, + limit: int = 50, +) -> list[dict[str, Any]]: + return await _neighbors_payload( + project, branch, symbol_id, "IMPORTS", "IN", limit + ) + + +async def get_overrides( + symbol_id: Any, + project: str, + branch: Optional[str] = None, + direction: str = "BOTH", + limit: int = 50, +) -> list[dict[str, Any]]: + direction = (direction or "BOTH").upper() + if direction in ("IN", "OUT"): + return await _neighbors_payload( + project, branch, symbol_id, "OVERRIDES", direction, limit + ) + if direction != "BOTH": + raise ValueError( + f"direction must be 'IN', 'OUT', or 'BOTH', got: {direction!r}" + ) + seen: set[Any] = set() + out: list[dict[str, Any]] = [] + for d in ("OUT", "IN"): + rows = await _neighbors_payload( + project, branch, symbol_id, "OVERRIDES", d, limit + ) + for row in rows: + key = (row.get("id"), row.get("direction")) + if key in seen: + continue + seen.add(key) + out.append(row) + if len(out) >= limit: + return out + return out + + # --------------------------------------------------------------------------- # T7 — find_path # --------------------------------------------------------------------------- @@ -414,8 +1011,10 @@ async def get_dependencies( name="find_path", description=( "Return up to `max_paths` CALLS-path sequences from `source_id` to " - "`dest_id`. Useful for 'how does A reach B' questions. Returns an " - "empty list when no path exists." + "`dest_id` ('how does A reach B'). Use to confirm whether a suspected " + "entry point actually reaches a suspected buggy function, and through " + "which intermediaries. Returns an empty list when no STATIC path exists " + "(dynamic dispatch is not captured)." ), ) async def find_path( @@ -441,7 +1040,7 @@ async def find_path( paths: list[dict[str, Any]] = [] for entry in raw: node_seq = [ - _node_summary(x) + _node_summary(x, rel_to=project, with_label=True) for x in entry # Discriminate on ``labels``: ``encode_node`` emits a top-level # ``labels`` key, while ``encode_edge`` does not (edges carry @@ -463,32 +1062,227 @@ async def find_path( @app.tool( name="search_code", description=( - "Prefix-search for symbols (functions, classes, files) whose name " - "starts with `prefix`. Backed by FalkorDB's full-text index. The " - "agent typically calls this first to discover symbol ids for the " - "navigation tools (`get_callers`, `find_path`, ...)." + "Localize a bug/feature to its files from a CONCEPTUAL free-text query. " + "Phrase it as a natural-language description of the behavior and area " + "involved (e.g. 'face centroid computation uses node connectivity' or " + "'tagging a library entry duplicates tags') — the issue title plus a " + "phrase about what the code DOES. Backticked symbol/error names are fine " + "as seasoning, but DO NOT pass a bare list of identifiers: a pile of exact " + "symbol names collapses the ranking onto their single definition file and " + "hides the related files you didn't know to name (use grep if you already " + "know the exact symbol). " + "Ranks every indexed file by a hybrid of (exact symbol-name match, " + "path-token overlap, BM25 over symbol names + docstrings, and call-graph " + "centrality), de-prioritizing test/vendored paths. Returns the top files " + "as {file, file_id, score, name, line, snippet} — best candidates first, " + "so the usual top 3-5 are where to start. Unlike a name lookup, it " + "surfaces the right file even when you don't know the exact symbol. Feed a " + "result's `file_id` to `get_file_neighbors` to reveal the files " + "structurally coupled to it (imports/calls) — co-change candidates a " + "textual search misses." ), ) async def search_code( - prefix: str, + query: str, project: str, branch: Optional[str] = None, - limit: int = 20, + limit: int = 10, ) -> list[dict[str, Any]]: g = _project_arg(project, branch) try: - # Push the caller's ``limit`` down to the DB so it is actually honored - # (the underlying full-text query is otherwise capped at its default). - raw = await g.prefix_search(prefix, limit=limit) + ranked = await _hybrid_rank(g, query, project) finally: await g.close() - return [_node_summary(node) for node in raw] + out: list[dict[str, Any]] = [] + for r in ranked[:limit]: + rec: dict[str, Any] = { + "file": r["file"], + "file_id": r["file_id"], + "score": r["score"], + "name": r["name"], + "line": r["src_start"], + "label": "File", + } + snip = _read_snippet( + r["abs_path"], + r["src_start"] if r["src_start"] is not None else 1, + r["src_end"] if r["src_end"] is not None else r["src_start"], + SEARCH_SNIPPET_LINES, + ) + if snip: + rec["snippet"] = snip + out.append(rec) + return out # --------------------------------------------------------------------------- -# T6 — impact_analysis (variable-depth Cypher with DISTINCT for cycle safety) +# T8b — get_file_neighbors (file-level structural coupling) # --------------------------------------------------------------------------- +FILE_NEIGHBOR_RELS = ("IMPORTS", "CALLS", "EXTENDS", "OVERRIDES") +FILE_NEIGHBOR_MAX = 100 +"""Hard cap on returned neighbor files. The default is intentionally high: +the value of this tool is a candidate set GUARANTEED to contain the coupled +file, so truncating it (and possibly dropping that file) defeats the purpose. +``truncated`` is surfaced so the agent knows when the cap bit.""" + + +async def _resolve_file( + g, file: Any, project: Optional[str] +) -> tuple[Optional[int], Optional[str]]: + """Resolve a File handle to ``(file_node_id, abs_path)``. + + ``file`` is either the integer File-node id that ``search_code`` returns, + or a repo-relative path. Path resolution compares ``_relativize`` of each + stored ``File.path`` rather than reconstructing an absolute path (worktree + roots vary), so a relative path matches regardless of the indexing root. + """ + try: + fid = _coerce_node_id(file) + except ValueError: + fid = None + if fid is not None: + res = await g._query( + "MATCH (f:File) WHERE ID(f) = $id RETURN f.path", {"id": fid} + ) + if res.result_set and res.result_set[0][0]: + return fid, res.result_set[0][0] + return None, None + target = str(file).strip().lstrip("/") + res = await g._query("MATCH (f:File) RETURN ID(f), f.path") + for nid, ap in res.result_set: + if ap and (_relativize(ap, project) == target or ap == file): + return nid, ap + return None, None + + +@app.tool( + name="get_file_neighbors", + description=( + "Files structurally coupled to a FILE — the import/call/inheritance " + "dependencies that a textual search misses and that must often change " + "together. Run after `search_code`, passing a hit's `file_id` (or a " + "repo-relative path). Expands EVERY symbol in the file (not just the " + "representative one) and unions 1-hop IMPORTS/CALLS/EXTENDS/OVERRIDES " + "edges in both directions, deduped to files and ordered by coupling " + "strength (edge count). Returns {file, total_neighbors, truncated, " + "neighbors:[{file, file_id, edge_count, relations, snippet}]}. Each " + "neighbor carries a `file_id` you can recurse on and a code `snippet`. " + "Use this to reach co-change files after localizing the primary hit." + ), +) +async def get_file_neighbors( + file: Any, + project: str, + branch: Optional[str] = None, + limit: int = FILE_NEIGHBOR_MAX, +) -> dict[str, Any]: + g = _project_arg(project, branch) + try: + fid, abs_path = await _resolve_file(g, file, project) + if abs_path is None: + return { + "file": _relativize(str(file), project), + "file_id": None, + "total_neighbors": 0, + "truncated": False, + "neighbors": [], + } + + sres = await g._query( + "MATCH (n) WHERE (n:Function OR n:Class) AND n.path = $p RETURN ID(n)", + {"p": abs_path}, + ) + ids = [row[0] for row in sres.result_set] + if fid is not None: + ids.append(fid) + if not ids: + return { + "file": _relativize(abs_path, project), + "file_id": fid, + "total_neighbors": 0, + "truncated": False, + "neighbors": [], + } + + relq = "|".join(FILE_NEIGHBOR_RELS) + # Aggregate first (group by neighbor file), sort second, limit last — so + # high-degree symbols can't crowd the coupled file out before ranking. + agg: dict[str, dict[str, Any]] = {} + for direction in ("OUT", "IN"): + if direction == "OUT": + q = ( + f"MATCH (n)-[e:{relq}]->(d) WHERE ID(n) IN $ids " + "AND d.path IS NOT NULL AND d.path <> $self " + "RETURN d.path AS p, type(e) AS rel" + ) + else: + q = ( + f"MATCH (s)-[e:{relq}]->(n) WHERE ID(n) IN $ids " + "AND s.path IS NOT NULL AND s.path <> $self " + "RETURN s.path AS p, type(e) AS rel" + ) + res = await g._query(q, {"ids": ids, "self": abs_path}) + for ap, rel in res.result_set: + rp = _relativize(ap, project) + slot = agg.get(rp) + if slot is None: + slot = {"abs": ap, "count": 0, "rels": Counter()} + agg[rp] = slot + slot["count"] += 1 + slot["rels"][f"{rel}:{direction}"] += 1 + + total = len(agg) + eff_limit = min(int(limit or FILE_NEIGHBOR_MAX), FILE_NEIGHBOR_MAX) + ordered = sorted(agg.items(), key=lambda kv: (-kv[1]["count"], kv[0])) + kept = ordered[:eff_limit] + kept_abs = [slot["abs"] for _, slot in kept] + + # Batch the File-id and representative-symbol (lowest src_start) lookups + # for only the kept neighbors. + fid_of: dict[str, int] = {} + if kept_abs: + fres = await g._query( + "MATCH (f:File) WHERE f.path IN $paths RETURN f.path, ID(f)", + {"paths": kept_abs}, + ) + fid_of = {row[0]: row[1] for row in fres.result_set} + rep_of: dict[str, tuple] = {} + if kept_abs: + rres = await g._query( + "MATCH (n) WHERE (n:Function OR n:Class) AND n.path IN $paths " + "RETURN n.path, n.src_start, n.src_end ORDER BY n.src_start", + {"paths": kept_abs}, + ) + for ap, s0, s1 in rres.result_set: + rep_of.setdefault(ap, (s0, s1)) + + neighbors: list[dict[str, Any]] = [] + for rp, slot in kept: + ap = slot["abs"] + entry: dict[str, Any] = { + "file": rp, + "file_id": fid_of.get(ap), + "edge_count": slot["count"], + "relations": dict(slot["rels"]), + } + rep = rep_of.get(ap) + if rep: + snip = _read_snippet(ap, rep[0], rep[1], NEIGHBOR_SNIPPET_LINES) + if snip: + entry["snippet"] = snip + neighbors.append(entry) + + return { + "file": _relativize(abs_path, project), + "file_id": fid, + "total_neighbors": total, + "truncated": total > eff_limit, + "neighbors": neighbors, + } + finally: + await g.close() + # Hard cap on traversal depth — passed values above this are silently # clamped. Prevents pathological queries (e.g. depth=999) from hammering @@ -515,13 +1309,15 @@ def _clamp_depth(depth: Any) -> int: @app.tool( name="impact_analysis", description=( - "Transitive call-graph impact for refactoring: " - "`direction='IN'` returns all upstream callers (what breaks if you " - "change this symbol); `direction='OUT'` returns all downstream " - "callees (what this symbol indirectly depends on). Traverses only " - f"CALLS edges. Depth is clamped to {IMPACT_MAX_DEPTH}; cycles are " - "deduplicated via Cypher DISTINCT (each node appears at most once). " - "`limit` bounds the number of impacted symbols returned." + "Find the files/functions connected to a symbol through the call graph " + "— the files that must change TOGETHER when fixing or modifying it. " + "`direction='IN'` = upstream callers (who depends on this); " + "`direction='OUT'` = downstream callees (what this relies on). " + f"Traverses only CALLS edges (static; dynamic dispatch not captured); " + f"depth clamped to {IMPACT_MAX_DEPTH}; cycles deduplicated via Cypher " + "DISTINCT. `limit` bounds the number of impacted symbols returned. " + "For localization, run IN on a confirmed-buggy symbol to " + "surface co-change files." ), ) async def impact_analysis( @@ -565,7 +1361,7 @@ async def impact_analysis( out: list[dict[str, Any]] = [] for row in res.result_set: - entry = _node_summary(row[0]) + entry = _node_summary(row[0], rel_to=project, with_label=False) entry["direction"] = direction out.append(entry) return out diff --git a/api/migrations/per_branch.py b/api/migrations/per_branch.py index b263369e..d70ee5b0 100644 --- a/api/migrations/per_branch.py +++ b/api/migrations/per_branch.py @@ -20,11 +20,11 @@ from __future__ import annotations import logging -import os from typing import Iterable from falkordb import FalkorDB +from ..db import create_falkordb from ..graph import ( DEFAULT_BRANCH, compose_graph_name, @@ -37,12 +37,7 @@ def _connect() -> FalkorDB: - return FalkorDB( - host=os.getenv("FALKORDB_HOST", "localhost"), - port=os.getenv("FALKORDB_PORT", 6379), - username=os.getenv("FALKORDB_USERNAME", None), - password=os.getenv("FALKORDB_PASSWORD", None), - ) + return create_falkordb() def _legacy_graphs(all_graphs: Iterable[str]) -> list[str]: diff --git a/benchmarks/mcp_context_benchmark.py b/benchmarks/mcp_context_benchmark.py new file mode 100644 index 00000000..e04e33b2 --- /dev/null +++ b/benchmarks/mcp_context_benchmark.py @@ -0,0 +1,203 @@ +"""Estimate context/token savings from CodeGraph MCP-style lookups. + +This benchmark is intentionally deterministic: it does not call an LLM and +does not require provider API keys. It compares the approximate context size an +agent would send when answering 100 code-navigation questions in two modes: + +* graph: compact structured context from CodeGraph MCP tools +* plain: raw source-file context from the likely file an agent would inspect + +Run after indexing the repo/folder, for example: + + CODE_GRAPH_DB_BACKEND=lite FALKORDB_LITE_PATH=/tmp/code-graph-lite.rdb \ + uv run --extra light cgraph index api --repo code-graph-api --branch local-test + + CODE_GRAPH_DB_BACKEND=lite FALKORDB_LITE_PATH=/tmp/code-graph-lite.rdb \ + uv run --extra light python benchmarks/mcp_context_benchmark.py \ + --project code-graph-api --branch local-test --root api +""" + +from __future__ import annotations + +import argparse +import asyncio +import json +import logging +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +logging.disable(logging.CRITICAL) + +from api.mcp.tools.structural import find_symbol, get_file_neighbors, get_neighbors # noqa: E402 + + +@dataclass(frozen=True) +class Target: + subject: str + symbol: str + file: str + + +@dataclass(frozen=True) +class Question: + text: str + target: Target + + +TARGETS = [ + Target("Graph class ownership and methods", "Graph", "api/graph.py"), + Target("AsyncGraphQuery neighbor reads", "AsyncGraphQuery", "api/graph.py"), + Target("regular-vs-Lite backend selection", "create_falkordb", "api/db.py"), + Target("GraphRAG host/port configuration", "graphrag_connection_kwargs", "api/db.py"), + Target("MCP repository indexing", "index_repo", "api/mcp/tools/structural.py"), + Target("symbol resolution before traversal", "find_symbol", "api/mcp/tools/structural.py"), + Target("single-hop graph traversal", "get_neighbors", "api/mcp/tools/structural.py"), + Target("project source analysis orchestration", "Project", "api/project.py"), + Target("local folder analyzer pipeline", "SourceAnalyzer", "api/analyzers/source_analyzer.py"), + Target("C# analyzer restore behavior", "CSharpAnalyzer", "api/analyzers/csharp/analyzer.py"), +] + +QUESTION_TEMPLATES = [ + "What does {subject} do?", + "Which symbols are directly related to {subject}?", + "What would likely be impacted if {subject} changed?", + "Which callers or incoming relationships does {subject} have?", + "Which dependencies or outgoing relationships does {subject} have?", + "Where is {subject} defined and what is its local context?", + "What methods or child symbols does {subject} define?", + "How should an agent navigate from {subject} to related code?", + "What source file should be inspected for {subject}?", + "Summarize {subject} using only relevant structural context.", +] + +QUESTIONS = [ + Question(template.format(subject=target.subject), target) + for target in TARGETS + for template in QUESTION_TEMPLATES +] + + +def estimate_tokens(text: str) -> int: + """Cheap token approximation for relative comparisons.""" + return max(1, len(text) // 4) + + +def compact_json(data: Any) -> str: + return json.dumps(data, ensure_ascii=False, sort_keys=True, default=str) + + +def read_plain_context(root: Path, file_hint: str) -> str: + path = root.parent / file_hint if file_hint.startswith(f"{root.name}/") else root / file_hint + if not path.exists(): + path = root.parent / file_hint + return path.read_text(errors="replace") + + +async def graph_context(question: Question, project: str, branch: str | None) -> dict[str, Any]: + symbols = await find_symbol( + name=question.target.symbol, + project=project, + branch=branch, + file=question.target.file, + limit=3, + ) + payload: dict[str, Any] = {"symbols": symbols} + if not symbols: + return payload + + symbol_id = symbols[0]["symbol_id"] + payload["defines_out"] = await get_neighbors( + symbol_id=symbol_id, + project=project, + branch=branch, + relation="DEFINES", + direction="OUT", + limit=30, + ) + payload["calls_in"] = await get_neighbors( + symbol_id=symbol_id, + project=project, + branch=branch, + relation="CALLS", + direction="IN", + limit=15, + ) + payload["calls_out"] = await get_neighbors( + symbol_id=symbol_id, + project=project, + branch=branch, + relation="CALLS", + direction="OUT", + limit=15, + ) + payload["file_neighbors"] = await get_file_neighbors( + file=question.target.file, + project=project, + branch=branch, + limit=20, + ) + return payload + + +async def run(project: str, branch: str | None, root: Path) -> None: + rows = [] + graph_total = 0 + plain_total = 0 + graph_contexts: dict[Target, str] = {} + + for question in QUESTIONS: + if question.target not in graph_contexts: + graph_contexts[question.target] = compact_json( + await graph_context(question, project, branch) + ) + graph_text = graph_contexts[question.target] + plain_text = read_plain_context(root, question.target.file) + graph_tokens = estimate_tokens(graph_text) + plain_tokens = estimate_tokens(plain_text) + graph_total += graph_tokens + plain_total += plain_tokens + rows.append( + { + "question": question.text, + "symbol": question.target.symbol, + "file": question.target.file, + "graph_tokens": graph_tokens, + "plain_tokens": plain_tokens, + "delta_tokens": plain_tokens - graph_tokens, + "reduction_pct": round((1 - graph_tokens / plain_tokens) * 100, 1) + if plain_tokens + else 0, + } + ) + + summary = { + "project": project, + "branch": branch, + "questions": len(QUESTIONS), + "graph_tokens": graph_total, + "plain_tokens": plain_total, + "delta_tokens": plain_total - graph_total, + "reduction_pct": round((1 - graph_total / plain_total) * 100, 1) + if plain_total + else 0, + } + print(json.dumps({"summary": summary, "rows": rows}, indent=2)) + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--project", required=True, help="Indexed CodeGraph project name") + parser.add_argument("--branch", default=None, help="Indexed branch name") + parser.add_argument( + "--root", + type=Path, + default=Path("api"), + help="Source root used for the plain-context baseline", + ) + args = parser.parse_args() + asyncio.run(run(args.project, args.branch, args.root.resolve())) + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index c4eddd28..37c30978 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,9 @@ cgraph = "api.cli:app" cgraph-mcp = "api.mcp.server:main" [project.optional-dependencies] +light = [ + "falkordblite>=0.10.0,<1.0.0", +] test = [ "pytest>=9.0.2,<10.0.0", "ruff>=0.11.0,<1.0.0", diff --git a/tests/mcp/test_find_symbol.py b/tests/mcp/test_find_symbol.py new file mode 100644 index 00000000..20427a4f --- /dev/null +++ b/tests/mcp/test_find_symbol.py @@ -0,0 +1,221 @@ +"""T9 — find_symbol MCP tool tests. + +Covers name resolution (simple + dotted qualname reduction), the ``file`` +disambiguation flag + ordering, response-shape stability, deterministic order, +and ``limit`` / empty-name validation. +""" + +from __future__ import annotations + +import uuid + +import pytest + + +pytestmark = pytest.mark.anyio + + +@pytest.fixture +def anyio_backend() -> str: + return "asyncio" + + +# --------------------------------------------------------------------------- +# Validation — no FalkorDB required (these raise before touching the graph) +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("bad", ["", " ", "Foo.", "Bar. "]) +async def test_find_symbol_rejects_empty_name(bad): + from api.mcp.tools.structural import find_symbol + + with pytest.raises(ValueError, match="non-empty symbol name"): + await find_symbol(name=bad, project="any") + + +@pytest.mark.parametrize("bad", ["not-a-number", 1.5, None, True]) +async def test_find_symbol_rejects_garbage_limit(bad): + from api.mcp.tools.structural import find_symbol + + with pytest.raises(ValueError, match="limit"): + await find_symbol(name="entrypoint", project="any", limit=bad) + + +# --------------------------------------------------------------------------- +# Integration — sample_project fixture +# --------------------------------------------------------------------------- + + +async def test_find_symbol_resolves_simple_name(indexed_fixture): + from api.mcp.tools.structural import find_symbol + + rows = await find_symbol( + name="entrypoint", + project=indexed_fixture.project, + branch=indexed_fixture.branch, + ) + assert rows, "entrypoint must resolve" + r = rows[0] + assert isinstance(r["symbol_id"], int) + assert r["name"] == "entrypoint" + assert r["label"] == "Function" + assert r["file"].endswith("entrypoint.py") + assert isinstance(r["line"], int) + # file_match is part of the documented shape and present even without a + # ``file`` filter (False when none requested). + assert r["file_match"] is False + + +async def test_find_symbol_reduces_dotted_qualname(indexed_fixture): + """``Class.method`` must resolve identically to the bare ``method``.""" + from api.mcp.tools.structural import find_symbol + + dotted = await find_symbol( + name="BaseRepo.repo", + project=indexed_fixture.project, + branch=indexed_fixture.branch, + ) + simple = await find_symbol( + name="repo", + project=indexed_fixture.project, + branch=indexed_fixture.branch, + ) + assert {r["symbol_id"] for r in dotted} == {r["symbol_id"] for r in simple} + assert all(r["name"] == "repo" for r in dotted) + # Fixture defines repo() on BaseRepo, UserRepo and OrderRepo. + assert len(simple) == 3 + + +async def test_find_symbol_file_match_shape_stable(indexed_fixture): + """Every row carries a boolean ``file_match`` regardless of the args.""" + from api.mcp.tools.structural import find_symbol + + rows = await find_symbol( + name="repo", + project=indexed_fixture.project, + branch=indexed_fixture.branch, + ) + assert rows + for r in rows: + assert isinstance(r["file_match"], bool) + assert set(r) >= {"symbol_id", "name", "label", "file", "line", "file_match"} + + +async def test_find_symbol_file_filter_flags_matches(indexed_fixture): + """A matching ``file`` flags every in-file candidate; a non-matching one + still returns the global matches but flags none.""" + from api.mcp.tools.structural import find_symbol + + in_file = await find_symbol( + name="repo", + project=indexed_fixture.project, + branch=indexed_fixture.branch, + file="repo.py", + ) + assert in_file and all(r["file_match"] for r in in_file) + + no_file = await find_symbol( + name="repo", + project=indexed_fixture.project, + branch=indexed_fixture.branch, + file="does_not_exist.py", + ) + # Search is not silently widened away — global matches still surface… + assert {r["symbol_id"] for r in no_file} == {r["symbol_id"] for r in in_file} + # …but the agent can see none were in the requested file. + assert all(r["file_match"] is False for r in no_file) + + +async def test_find_symbol_deterministic_order(indexed_fixture): + """Repeated calls return rows in the same order (no FalkorDB row-order + dependence).""" + from api.mcp.tools.structural import find_symbol + + a = await find_symbol( + name="repo", + project=indexed_fixture.project, + branch=indexed_fixture.branch, + ) + b = await find_symbol( + name="repo", + project=indexed_fixture.project, + branch=indexed_fixture.branch, + ) + assert [r["symbol_id"] for r in a] == [r["symbol_id"] for r in b] + + +async def test_find_symbol_honors_limit(indexed_fixture): + from api.mcp.tools.structural import find_symbol + + rows = await find_symbol( + name="repo", + project=indexed_fixture.project, + branch=indexed_fixture.branch, + limit="1", # stringified ints are accepted and coerced + ) + assert len(rows) == 1 + + +async def test_find_symbol_negative_limit_clamped_not_tail_slice(indexed_fixture): + """A negative ``limit`` must clamp to 1, NOT flip ``rows[:limit]`` into a + surprising tail slice.""" + from api.mcp.tools.structural import find_symbol + + full = await find_symbol( + name="repo", + project=indexed_fixture.project, + branch=indexed_fixture.branch, + ) + clamped = await find_symbol( + name="repo", + project=indexed_fixture.project, + branch=indexed_fixture.branch, + limit=-2, + ) + assert len(clamped) == 1 + assert clamped[0]["symbol_id"] == full[0]["symbol_id"] + + +async def test_find_symbol_registered(): + from api.mcp.server import app + + names = {t.name for t in await app.list_tools()} + assert "find_symbol" in names + + +# --------------------------------------------------------------------------- +# Ordering across files — purpose-built graph (two same-named symbols) +# --------------------------------------------------------------------------- + + +@pytest.fixture +async def two_file_graph(): + """Two ``foo`` Functions in different files so ``file`` disambiguation can + actually reorder the result. Unique branch keeps it isolated; not torn down + (matches the ``indexed_fixture`` / ``cycle_graph`` pattern).""" + from api.graph import Graph + + project = "find_symbol_order_test" + branch = f"order-{uuid.uuid4().hex[:8]}" + g = Graph(project, branch=branch) + g.g.query( + """ + CREATE + (:Function:Searchable {name: 'foo', path: '/tmp/aaa.py', src_start: 1}), + (:Function:Searchable {name: 'foo', path: '/tmp/bbb.py', src_start: 1}) + """ + ) + yield project, branch + + +async def test_find_symbol_orders_in_file_first(two_file_graph): + from api.mcp.tools.structural import find_symbol + + project, branch = two_file_graph + + rows = await find_symbol(name="foo", project=project, branch=branch, file="bbb.py") + assert len(rows) == 2 + # The requested file sorts first and is the only flagged match. + assert rows[0]["file"].endswith("bbb.py") + assert rows[0]["file_match"] is True + assert rows[1]["file_match"] is False diff --git a/tests/mcp/test_impact_analysis.py b/tests/mcp/test_impact_analysis.py index 2a275273..a7414da2 100644 --- a/tests/mcp/test_impact_analysis.py +++ b/tests/mcp/test_impact_analysis.py @@ -69,17 +69,25 @@ async def test_impact_analysis_registered_via_app(): async def _find_id(indexed_fixture, name: str) -> int: - from api.mcp.tools.structural import search_code + """Resolve a symbol name to its int node id directly from the graph. - rows = await search_code( - prefix=name, - project=indexed_fixture.project, - branch=indexed_fixture.branch, - ) - for r in rows: - if r["name"] == name: - return r["id"] - raise AssertionError(f"symbol {name!r} not found") + ``search_code`` is file-oriented and no longer returns per-symbol ids. + """ + from api.mcp.tools.structural import _project_arg + + g = _project_arg(indexed_fixture.project, indexed_fixture.branch) + try: + res = await g._query( + "MATCH (n) WHERE (n:Function OR n:Class) AND n.name = $name " + "RETURN ID(n)", + {"name": name}, + ) + finally: + await g.close() + rows = res.result_set + assert rows, f"symbol {name!r} not found in graph" + assert len(rows) == 1, f"ambiguous symbol {name!r}: {len(rows)} matches" + return rows[0][0] async def test_impact_upstream_of_db(indexed_fixture, expected_contract): diff --git a/tests/mcp/test_query_tools.py b/tests/mcp/test_query_tools.py index c3d0981a..f0ff4744 100644 --- a/tests/mcp/test_query_tools.py +++ b/tests/mcp/test_query_tools.py @@ -28,21 +28,23 @@ def anyio_backend() -> str: async def test_search_code_finds_entrypoint(indexed_fixture, expected_contract): from api.mcp.tools.structural import search_code + # search_code is file-oriented: a free-text query naming a symbol must + # surface the file that defines it. + symbol = expected_contract["search_prefixes"]["ent"]["must_include"][0] results = await search_code( - prefix="ent", + query=symbol, project=indexed_fixture.project, branch=indexed_fixture.branch, ) - names = {r["name"] for r in results} - for required in expected_contract["search_prefixes"]["ent"]["must_include"]: - assert required in names, f"expected {required} in {names}" + files = [r["file"] for r in results if r.get("file")] + assert any(f.endswith(f"{symbol}.py") for f in files), files async def test_search_code_honors_limit(indexed_fixture): from api.mcp.tools.structural import search_code results = await search_code( - prefix="r", # broad prefix + query="entrypoint service repo db", # broad: matches several files project=indexed_fixture.project, branch=indexed_fixture.branch, limit=1, @@ -54,7 +56,7 @@ async def test_search_code_empty_for_nonsense(indexed_fixture): from api.mcp.tools.structural import search_code results = await search_code( - prefix="zzz_no_such_symbol_zzz", + query="zzz_no_such_symbol_zzz", project=indexed_fixture.project, branch=indexed_fixture.branch, ) @@ -65,31 +67,97 @@ async def test_search_code_result_serialisable(indexed_fixture): from api.mcp.tools.structural import search_code results = await search_code( - prefix="serv", + query="service", project=indexed_fixture.project, branch=indexed_fixture.branch, ) json.dumps(results) # must not raise -# --------------------------------------------------------------------------- -# get_callers / get_callees / get_dependencies (T5) -# --------------------------------------------------------------------------- +def test_relativize_strips_worktree_prefix(): + from api.mcp.tools.structural import _relativize + proj = "django__django-18854__loc" + abs_path = f"/Users/x/.worktrees/bench/worktrees/code_graph/{proj}/django/db/models/fields/__init__.py" + assert _relativize(abs_path, proj) == "django/db/models/fields/__init__.py" -async def _find_id(indexed_fixture, name: str) -> int: - """Helper: resolve a symbol name to its int node id via search_code.""" + +def test_relativize_keeps_nested_repeat_of_project_name(): + """A nested dir repeating the project name must survive (first-match strip).""" + from api.mcp.tools.structural import _relativize + + proj = "myproj" + abs_path = f"/root/{proj}/src/vendor/{proj}/file.py" + assert _relativize(abs_path, proj) == f"src/vendor/{proj}/file.py" + + +def test_relativize_noops_without_marker_or_rel_to(): + from api.mcp.tools.structural import _relativize + + assert _relativize("/abs/path/no/marker.py", "absent") == "/abs/path/no/marker.py" + assert _relativize("/abs/path.py", None) == "/abs/path.py" + assert _relativize(None, "proj") is None + + +async def test_search_code_returns_relative_paths(indexed_fixture): from api.mcp.tools.structural import search_code - rows = await search_code( - prefix=name, + results = await search_code( + query="entrypoint", project=indexed_fixture.project, branch=indexed_fixture.branch, ) - for r in rows: - if r["name"] == name: - return r["id"] - raise AssertionError(f"symbol {name!r} not found via search_code") + for r in results: + if r.get("file"): + assert not r["file"].startswith("/"), f"path not relativized: {r['file']}" + assert f"/{indexed_fixture.project}/" not in r["file"] + + +async def test_search_code_ranks_exact_match_within_limit(indexed_fixture, expected_contract): + """A query naming a symbol must surface that symbol's file as the top hit, + with the matching symbol as the file's representative.""" + from api.mcp.tools.structural import search_code + + symbol = next(iter(expected_contract["search_prefixes"]["ent"]["must_include"])) + results = await search_code( + query=symbol, + project=indexed_fixture.project, + branch=indexed_fixture.branch, + limit=1, + ) + assert results, f"no results for query {symbol!r}" + assert results[0]["file"].endswith(f"{symbol}.py") + assert results[0]["name"] == symbol + + +# --------------------------------------------------------------------------- +# get_callers / get_callees / get_dependencies (T5) +# --------------------------------------------------------------------------- + + +async def _find_id(indexed_fixture, name: str) -> int: + """Resolve a symbol name to its int node id directly from the graph. + + ``search_code`` is file-oriented and no longer returns per-symbol ids, so + the neighbor/path/impact tests resolve the id straight from FalkorDB. The + names used here (entrypoint/service/db) are unique Functions in the fixture, + so a uniqueness assertion guards against silently picking the wrong node. + """ + from api.mcp.tools.structural import _project_arg + + g = _project_arg(indexed_fixture.project, indexed_fixture.branch) + try: + res = await g._query( + "MATCH (n) WHERE (n:Function OR n:Class) AND n.name = $name " + "RETURN ID(n)", + {"name": name}, + ) + finally: + await g.close() + rows = res.result_set + assert rows, f"symbol {name!r} not found in graph" + assert len(rows) == 1, f"ambiguous symbol {name!r}: {len(rows)} matches" + return rows[0][0] async def test_get_callees_of_entrypoint(indexed_fixture, expected_contract): @@ -266,8 +334,138 @@ async def test_all_query_tools_registered(): tools = {t.name for t in await app.list_tools()} assert { "search_code", - "get_callers", - "get_callees", - "get_dependencies", + "get_neighbors", "find_path", }.issubset(tools) + # The per-edge tools were consolidated into get_neighbors and must no + # longer be advertised on the MCP surface. + assert tools.isdisjoint( + {"get_callers", "get_callees", "get_dependencies", "get_importers", "get_overrides"} + ) + + +async def test_get_neighbors_matches_legacy_callers_callees(indexed_fixture): + from api.mcp.tools.structural import get_callees, get_callers, get_neighbors + + project = indexed_fixture.project + branch = indexed_fixture.branch + entry_id = await _find_id(indexed_fixture, "entrypoint") + + legacy_callees = await get_callees(symbol_id=entry_id, project=project, branch=branch) + unified_callees = await get_neighbors( + symbol_id=entry_id, project=project, branch=branch, relation="CALLS", direction="OUT" + ) + assert {n["id"] for n in unified_callees} == {n["id"] for n in legacy_callees} + + service_id = await _find_id(indexed_fixture, "service") + legacy_callers = await get_callers(symbol_id=service_id, project=project, branch=branch) + unified_callers = await get_neighbors( + symbol_id=service_id, project=project, branch=branch, relation="CALLS", direction="IN" + ) + assert {n["id"] for n in unified_callers} == {n["id"] for n in legacy_callers} + + +async def test_get_neighbors_matches_legacy_dependencies(indexed_fixture): + from api.mcp.tools.structural import get_dependencies, get_neighbors + + project = indexed_fixture.project + branch = indexed_fixture.branch + entry_id = await _find_id(indexed_fixture, "entrypoint") + + legacy = await get_dependencies(symbol_id=entry_id, project=project, branch=branch) + unified = await get_neighbors( + symbol_id=entry_id, + project=project, + branch=branch, + relation=["IMPORTS", "CALLS", "DEFINES"], + direction="OUT", + ) + assert {n["id"] for n in unified} == {n["id"] for n in legacy} + + +# --------------------------------------------------------------------------- +# get_file_neighbors (T8b) — file-level structural coupling +# --------------------------------------------------------------------------- + + +async def test_get_file_neighbors_reaches_cross_file_dependency(indexed_fixture): + """File-level expansion must reach a cross-file dependency. + + ``entrypoint()`` calls ``service()`` in another file, so ``service.py`` + must surface as a neighbor — and crucially via expanding ALL symbols in + the file, not only the lowest-``src_start`` representative one. + """ + from api.mcp.tools.structural import get_file_neighbors + + res = await get_file_neighbors( + "python/entrypoint.py", + project=indexed_fixture.project, + branch=indexed_fixture.branch, + ) + assert isinstance(res, dict) + assert res["file"] == "python/entrypoint.py" + assert res["truncated"] is False + assert res["total_neighbors"] == len(res["neighbors"]) + + names = [n["file"] for n in res["neighbors"]] + assert any(n.endswith("service.py") for n in names), names + assert "python/entrypoint.py" not in names # self excluded + + for n in res["neighbors"]: + assert isinstance(n["file_id"], int) # recursable handle + assert n["edge_count"] >= 1 + assert n["relations"] # relation:direction breakdown present + + +async def test_get_file_neighbors_id_and_path_agree(indexed_fixture): + """Resolving by File-node id and by repo-relative path yields the same set.""" + from api.mcp.tools.structural import get_file_neighbors + + by_path = await get_file_neighbors( + "python/service.py", + project=indexed_fixture.project, + branch=indexed_fixture.branch, + ) + assert by_path["file_id"] is not None + by_id = await get_file_neighbors( + by_path["file_id"], + project=indexed_fixture.project, + branch=indexed_fixture.branch, + ) + assert [n["file"] for n in by_id["neighbors"]] == [ + n["file"] for n in by_path["neighbors"] + ] + + +async def test_get_file_neighbors_unknown_file_is_empty(indexed_fixture): + from api.mcp.tools.structural import get_file_neighbors + + res = await get_file_neighbors( + "python/does_not_exist.py", + project=indexed_fixture.project, + branch=indexed_fixture.branch, + ) + assert res["total_neighbors"] == 0 + assert res["neighbors"] == [] + assert res["truncated"] is False + + +async def test_search_code_returns_file_id(indexed_fixture): + """Every search_code hit must carry a File-node ``file_id`` to hop from.""" + from api.mcp.tools.structural import search_code + + rows = await search_code( + query="service repo db", + project=indexed_fixture.project, + branch=indexed_fixture.branch, + ) + assert rows, "expected at least one hit" + for r in rows: + assert isinstance(r["file_id"], int) + + +async def test_get_file_neighbors_registered(): + from api.mcp.server import app + + tools = await app.list_tools() + assert "get_file_neighbors" in {t.name for t in tools} diff --git a/uv.lock b/uv.lock index de4f5a05..d9a2d15a 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.12, <3.14" [[package]] @@ -372,6 +372,9 @@ dependencies = [ ] [package.optional-dependencies] +light = [ + { name = "falkordblite" }, +] test = [ { name = "anyio" }, { name = "httpx" }, @@ -391,6 +394,7 @@ requires-dist = [ { name = "anyio", marker = "extra == 'test'", specifier = ">=4.0,<5.0" }, { name = "falkordb", specifier = ">=1.1.3,<2.0.0" }, { name = "falkordb-multilspy", specifier = ">=0.1.0,<1.0.0" }, + { name = "falkordblite", marker = "extra == 'light'", specifier = ">=0.10.0,<1.0.0" }, { name = "fastapi", specifier = ">=0.115.0,<1.0.0" }, { name = "graphrag-sdk", specifier = ">=0.8.1,<0.9.0" }, { name = "httpx", marker = "extra == 'test'", specifier = ">=0.28.0,<1.0.0" }, @@ -411,7 +415,7 @@ requires-dist = [ { name = "uvicorn", extras = ["standard"], specifier = ">=0.34.0,<1.0.0" }, { name = "validators", specifier = ">=0.35.0,<0.36.0" }, ] -provides-extras = ["test"] +provides-extras = ["light", "test"] [package.metadata.requires-dev] dev = [ @@ -435,6 +439,26 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/84/15/97032c229a031b29795c2d0a74646bc24b1bdbb71930c44d1d37ef9d98c6/falkordb_multilspy-0.1.0-py3-none-any.whl", hash = "sha256:3429f11a83c4fbf06c2b8f0078b8a257e0870da9d50ebd964c6de2ad33164f57", size = 129555, upload-time = "2026-03-23T14:29:32.936Z" }, ] +[[package]] +name = "falkordblite" +version = "0.10.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "falkordb" }, + { name = "psutil" }, + { name = "redis" }, + { name = "setuptools" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/70/03/afbb6e0f03c302aa9b64a38e1a2c43664f86921f0dcbf03ec9e31da06ac6/falkordblite-0.10.0.tar.gz", hash = "sha256:65a72abafd30711f699c15571df6959edb8901605053ce940ccdd837832e709b", size = 23675620, upload-time = "2026-05-02T13:12:29.429Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0e/43/39d8cf13964784447676d24f0cefa3bdc99c10e647c71e6a4172d302dcac/falkordblite-0.10.0-cp312-cp312-macosx_10_13_x86_64.macosx_15_0_arm64.whl", hash = "sha256:741fda166170513db1815d5369870e47d44da2e9b85320fddbc50c88a7338d51", size = 17752268, upload-time = "2026-05-02T13:11:57.906Z" }, + { url = "https://files.pythonhosted.org/packages/61/17/3ec180ca7c2e79a0c3e9912ac04b446e733745cd944845fe4306be016efd/falkordblite-0.10.0-cp312-cp312-manylinux_2_39_aarch64.whl", hash = "sha256:bfe1f47ae03decdc0ad111ec82606bac82ee6cdd7fea04cbcff1283608cc2d2e", size = 34579293, upload-time = "2026-05-02T13:12:01.601Z" }, + { url = "https://files.pythonhosted.org/packages/8a/5e/d343c5249bd24c6614e8c1b718a73bb9f68beb2cf90464bc51c9436d25af/falkordblite-0.10.0-cp312-cp312-manylinux_2_39_x86_64.whl", hash = "sha256:ee9659bd0c7cdf0c2532977f2ec8bf1d1ab01cb3b6776e50c67046481f3162c7", size = 36154066, upload-time = "2026-05-02T13:12:05.435Z" }, + { url = "https://files.pythonhosted.org/packages/f6/1e/18612dc9e75e5c02a281a49d97b56b5504eb5907f971c68e8a5801033f36/falkordblite-0.10.0-cp313-cp313-macosx_10_13_x86_64.macosx_15_0_arm64.whl", hash = "sha256:54a051cbc0bb25b30e65f46ddb1b63149a80bb7004c97615d39c491d492a6d9b", size = 17752240, upload-time = "2026-05-02T13:12:08.837Z" }, + { url = "https://files.pythonhosted.org/packages/29/b3/63f9b1168c4d1abbee4418a0a165496f57cd4b93a261cb481272ce353890/falkordblite-0.10.0-cp313-cp313-manylinux_2_39_aarch64.whl", hash = "sha256:ce1bd408ebaafc3dbdf39e860888382aae69de2b3c949dadfd132779130b99b2", size = 34579294, upload-time = "2026-05-02T13:12:12.069Z" }, + { url = "https://files.pythonhosted.org/packages/2e/cf/d792cf46292f7756f96727916e77ebf9821371e22bff65ed8e90db6b80b9/falkordblite-0.10.0-cp313-cp313-manylinux_2_39_x86_64.whl", hash = "sha256:7d52a3665f6de30cbffc5809a1c6d722492c95bf8dae4a7ee108ca58da939b25", size = 36154069, upload-time = "2026-05-02T13:12:15.742Z" }, +] + [[package]] name = "fastapi" version = "0.135.1" @@ -1615,6 +1639,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fe/4e/cd76eca6db6115604b7626668e891c9dd03330384082e33662fb0f113614/ruff-0.15.5-py3-none-win_arm64.whl", hash = "sha256:b498d1c60d2fe5c10c45ec3f698901065772730b411f164ae270bb6bfcc4740b", size = 10965572, upload-time = "2026-03-05T20:06:16.984Z" }, ] +[[package]] +name = "setuptools" +version = "82.0.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/4f/db/cfac1baf10650ab4d1c111714410d2fbb77ac5a616db26775db562c8fab2/setuptools-82.0.1.tar.gz", hash = "sha256:7d872682c5d01cfde07da7bccc7b65469d3dca203318515ada1de5eda35efbf9", size = 1152316, upload-time = "2026-03-09T12:47:17.221Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9d/76/f789f7a86709c6b087c5a2f52f911838cad707cc613162401badc665acfe/setuptools-82.0.1-py3-none-any.whl", hash = "sha256:a59e362652f08dcd477c78bb6e7bd9d80a7995bc73ce773050228a348ce2e5bb", size = 1006223, upload-time = "2026-03-09T12:47:15.026Z" }, +] + [[package]] name = "shellingham" version = "1.5.4"