diff --git a/CHANGELOG.md b/CHANGELOG.md index f47a176..97a7938 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,19 @@ # Changelog +## [Unreleased] + +### Added +- **Decorator extraction in parser**: Functions and classes now store decorators/annotations in `node.extra["decorators"]` (Python, Java/Kotlin/C#, TypeScript) +- **Expanded framework decorator patterns**: Dead code detection and entry point discovery now recognize pytest fixtures, Django signals, SQLAlchemy events, Spring annotations, Celery tasks, NestJS/Angular decorators, pydantic-ai agent tools +- **Type annotation reference tracking**: Classes referenced in function parameter types or return types (e.g. Pydantic schemas) are no longer flagged as dead code +- **Per-symbol IMPORTS_FROM edges**: JS/TS/TSX named imports (`import { A, B } from './mod'`) now create edges targeting individual functions/classes, not just the file -- eliminates ~320 FPs from frontend codebases +- **ORM/framework base class exclusion**: Classes inheriting from known framework bases (Base, DeclarativeBase, BaseModel, BaseSettings, etc.) are no longer flagged as dead code + +### Fixed +- **Dead code false positives**: Dunder methods (`__init__`, `__str__`, etc.) excluded from dead code results -- they are runtime-invoked and never have explicit callers +- **Dead code false positives**: Decorated entry points (e.g. `@app.get`, `@pytest.fixture`) now correctly excluded via parser-populated decorator metadata +- **Dead code false positives**: Alembic `upgrade`/`downgrade` and FastAPI `lifespan`/`get_db` recognized as entry points + ## [2.1.0] - 2026-04-03 ### Added diff --git a/CLAUDE.md b/CLAUDE.md index cb82907..b2bb98d 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -38,7 +38,7 @@ ```bash # Development -uv run pytest tests/ --tb=short -q # Run tests (572 tests) +uv run pytest tests/ --tb=short -q # Run tests (609 tests) uv run ruff check code_review_graph/ # Lint uv run mypy code_review_graph/ --ignore-missing-imports --no-strict-optional diff --git a/code_review_graph/flows.py b/code_review_graph/flows.py index 8c7374c..d722c96 100644 --- a/code_review_graph/flows.py +++ b/code_review_graph/flows.py @@ -25,14 +25,32 @@ # Decorator patterns that indicate a function is a framework entry point. _FRAMEWORK_DECORATOR_PATTERNS: list[re.Pattern[str]] = [ - re.compile(r"app\.(get|post|put|delete|patch|route|websocket)", re.IGNORECASE), + # Python web frameworks + re.compile(r"app\.(get|post|put|delete|patch|route|websocket|on_event)", re.IGNORECASE), re.compile(r"router\.(get|post|put|delete|patch|route)", re.IGNORECASE), re.compile(r"blueprint\.(route|before_request|after_request)", re.IGNORECASE), + re.compile(r"(before|after)_(request|response)", re.IGNORECASE), + # CLI frameworks re.compile(r"click\.(command|group)", re.IGNORECASE), - re.compile(r"celery\.(task|shared_task)", re.IGNORECASE), + # Task queues + re.compile(r"(celery\.)?(task|shared_task|periodic_task)", re.IGNORECASE), + # Django + re.compile(r"receiver", re.IGNORECASE), re.compile(r"api_view", re.IGNORECASE), re.compile(r"\baction\b", re.IGNORECASE), - re.compile(r"@(Get|Post|Put|Delete|Patch|RequestMapping)", re.IGNORECASE), + # Testing + re.compile(r"pytest\.(fixture|mark)"), + re.compile(r"(override_settings|modify_settings)", re.IGNORECASE), + # SQLAlchemy / event systems + re.compile(r"(event\.)?listens_for", re.IGNORECASE), + # Java Spring + re.compile(r"(Get|Post|Put|Delete|Patch|RequestMapping)Mapping", re.IGNORECASE), + re.compile(r"(Scheduled|EventListener|Bean|Configuration)", re.IGNORECASE), + # JS/TS frameworks + re.compile(r"(Component|Injectable|Controller|Module|Guard|Pipe)", re.IGNORECASE), + re.compile(r"(Subscribe|Mutation|Query|Resolver)", re.IGNORECASE), + # AI/agent frameworks (pydantic-ai, langchain, etc.) + re.compile(r"\w+\.tool\b", re.IGNORECASE), ] # Name patterns that indicate conventional entry points. @@ -43,6 +61,12 @@ re.compile(r"^Test[A-Z]"), re.compile(r"^on_"), re.compile(r"^handle_"), + # Alembic migration entry points + re.compile(r"^upgrade$"), + re.compile(r"^downgrade$"), + # FastAPI lifecycle / dependency injection + re.compile(r"^lifespan$"), + re.compile(r"^get_db$"), ] diff --git a/code_review_graph/parser.py b/code_review_graph/parser.py index bded99f..95078f7 100644 --- a/code_review_graph/parser.py +++ b/code_review_graph/parser.py @@ -1456,6 +1456,40 @@ def _extract_js_field_function( ) return True + @staticmethod + def _extract_decorators(child) -> list[str]: + """Extract decorator/annotation names from a definition node. + + Handles Python (decorated_definition parent), Java/Kotlin/C# + (annotation in modifiers child), and TypeScript (decorator child). + """ + decorators: list[str] = [] + + # Python: parent is decorated_definition wrapping the definition + parent = child.parent + if parent and parent.type == "decorated_definition": + for sibling in parent.children: + if sibling.type == "decorator": + text = sibling.text.decode("utf-8", errors="replace") + decorators.append(text.lstrip("@").strip()) + return decorators + + # Java/Kotlin/C#: annotations inside a modifiers child + for sub in child.children: + if sub.type == "modifiers": + for mod in sub.children: + if mod.type in ("annotation", "marker_annotation"): + text = mod.text.decode("utf-8", errors="replace") + decorators.append(text.lstrip("@").strip()) + + # TypeScript: decorator children directly on class/method node + for sub in child.children: + if sub.type == "decorator": + text = sub.text.decode("utf-8", errors="replace") + decorators.append(text.lstrip("@").strip()) + + return decorators + def _extract_classes( self, child, @@ -1477,6 +1511,7 @@ def _extract_classes( if not name: return False + decorators = self._extract_decorators(child) node = NodeInfo( kind="Class", name=name, @@ -1485,6 +1520,7 @@ def _extract_classes( line_end=child.end_point[0] + 1, language=language, parent_name=enclosing_class, + extra={"decorators": decorators} if decorators else {}, ) nodes.append(node) @@ -1545,6 +1581,7 @@ def _extract_functions( qualified = self._qualify(name, file_path, enclosing_class) params = self._get_params(child, language, source) ret_type = self._get_return_type(child, language, source) + decorators = self._extract_decorators(child) node = NodeInfo( kind=kind, @@ -1557,6 +1594,7 @@ def _extract_functions( params=params, return_type=ret_type, is_test=is_test, + extra={"decorators": decorators} if decorators else {}, ) nodes.append(node) @@ -1614,14 +1652,56 @@ def _extract_imports( resolved = self._resolve_module_to_file( imp_target, file_path, language, ) + target = resolved if resolved else imp_target edges.append(EdgeInfo( kind="IMPORTS_FROM", source=file_path, - target=resolved if resolved else imp_target, + target=target, file_path=file_path, line=child.start_point[0] + 1, )) + # Per-symbol IMPORTS_FROM edges for JS/TS/TSX named imports. + # This lets dead-code detection see that individual functions/ + # classes in the source file are referenced by importers. + if resolved and language in ("javascript", "typescript", "tsx"): + for name in self._get_js_import_names(child): + edges.append(EdgeInfo( + kind="IMPORTS_FROM", + source=file_path, + target=f"{resolved}::{name}", + file_path=file_path, + line=child.start_point[0] + 1, + )) + + @staticmethod + def _get_js_import_names(node) -> list[str]: + """Extract imported symbol names from a JS/TS import statement. + + For ``import { A, B as C } from './mod'``, returns ``["A", "B"]`` + (original export names, not local aliases). For default imports + like ``import D from './mod'``, returns ``["D"]``. + """ + names: list[str] = [] + for child in node.children: + if child.type == "import_clause": + for sub in child.children: + if sub.type == "identifier": + # Default import + names.append(sub.text.decode("utf-8", errors="replace")) + elif sub.type == "named_imports": + for spec in sub.children: + if spec.type == "import_specifier": + idents = [ + s.text.decode("utf-8", errors="replace") + for s in spec.children + if s.type in ("identifier", "property_identifier") + ] + # First identifier is the original name + if idents: + names.append(idents[0]) + return names + def _extract_calls( self, child, diff --git a/code_review_graph/refactor.py b/code_review_graph/refactor.py index 3dc1026..ca2e47d 100644 --- a/code_review_graph/refactor.py +++ b/code_review_graph/refactor.py @@ -9,6 +9,7 @@ from __future__ import annotations import logging +import re import threading import time import uuid @@ -20,6 +21,14 @@ logger = logging.getLogger(__name__) +# Base class names that indicate a framework-managed class (ORM models, +# Pydantic schemas, settings). Classes inheriting from these are invoked +# via metaclass/framework magic and should not be flagged as dead code. +_FRAMEWORK_BASE_CLASSES = frozenset({ + "Base", "DeclarativeBase", "Model", "BaseModel", "BaseSettings", + "db.Model", "TableBase", +}) + # --------------------------------------------------------------------------- # Thread-safe pending refactors storage # --------------------------------------------------------------------------- @@ -173,6 +182,22 @@ def _is_entry_point(node: Any) -> bool: return False +# Matches identifiers inside type annotations (e.g. "GoalCreate" in +# "body: GoalCreate", "Optional[UserResponse]", "list[Item]"). +_TYPE_IDENT_RE = re.compile(r"[A-Z][A-Za-z0-9_]*") + + +def _collect_type_referenced_names(store: GraphStore) -> set[str]: + """Collect class names that appear in function params or return types.""" + funcs = store.get_nodes_by_kind(kinds=["Function", "Test"]) + names: set[str] = set() + for f in funcs: + for text in (f.params, f.return_type): + if text: + names.update(_TYPE_IDENT_RE.findall(text)) + return names + + def find_dead_code( store: GraphStore, kind: Optional[str] = None, @@ -197,6 +222,9 @@ def find_dead_code( file_pattern=file_pattern, ) + # Build set of class names referenced in function type annotations. + type_ref_names = _collect_type_referenced_names(store) + dead: list[dict[str, Any]] = [] for node in candidates: @@ -205,10 +233,28 @@ def find_dead_code( if node.is_test: continue + # Skip dunder methods -- invoked by runtime, never have explicit callers. + if node.name.startswith("__") and node.name.endswith("__"): + continue + # Skip entry points (by name pattern or decorator, not just "uncalled"). if _is_entry_point(node): continue + # Skip classes referenced in type annotations (Pydantic schemas, etc.). + if node.kind == "Class" and node.name in type_ref_names: + continue + + # Skip classes inheriting from known framework bases (ORM models, etc.). + if node.kind == "Class": + outgoing = store.get_edges_by_source(node.qualified_name) + base_names = { + e.target_qualified.rsplit("::", 1)[-1] + for e in outgoing if e.kind == "INHERITS" + } + if base_names & _FRAMEWORK_BASE_CLASSES: + continue + # Check for callers (CALLS), test refs (TESTED_BY), importers (IMPORTS_FROM). incoming = store.get_edges_by_target(node.qualified_name) has_callers = any(e.kind == "CALLS" for e in incoming) diff --git a/tests/test_flows.py b/tests/test_flows.py index f7743b4..54c325a 100644 --- a/tests/test_flows.py +++ b/tests/test_flows.py @@ -108,6 +108,61 @@ def test_detect_entry_points_name_pattern(self): assert "handle_request" in ep_names assert "regular_func" not in ep_names + # --------------------------------------------------------------- + # detect_entry_points -- expanded decorator patterns + # --------------------------------------------------------------- + + def test_detect_entry_points_pytest_fixture(self): + """pytest.fixture decorator marks function as entry point.""" + self._add_func("my_fixture", extra={"decorators": ["pytest.fixture"]}) + eps = detect_entry_points(self.store) + ep_names = {ep.name for ep in eps} + assert "my_fixture" in ep_names + + def test_detect_entry_points_django_receiver(self): + """Django signal receiver decorator marks function as entry point.""" + self._add_func("on_save", extra={"decorators": ["receiver(post_save)"]}) + eps = detect_entry_points(self.store) + ep_names = {ep.name for ep in eps} + assert "on_save" in ep_names + + def test_detect_entry_points_spring_scheduled(self): + """Java Spring @Scheduled marks function as entry point.""" + self._add_func("cleanup_job", extra={"decorators": ["Scheduled(cron='0 0 * * *')"]}) + eps = detect_entry_points(self.store) + ep_names = {ep.name for ep in eps} + assert "cleanup_job" in ep_names + + def test_detect_entry_points_celery_task(self): + """Bare @task decorator marks function as entry point.""" + self._add_func("process_data", extra={"decorators": ["task"]}) + eps = detect_entry_points(self.store) + ep_names = {ep.name for ep in eps} + assert "process_data" in ep_names + + def test_detect_entry_points_agent_tool(self): + """@agent.tool decorator marks function as entry point.""" + self._add_func("query_health", extra={"decorators": ["health_agent.tool"]}) + eps = detect_entry_points(self.store) + ep_names = {ep.name for ep in eps} + assert "query_health" in ep_names + + def test_detect_entry_points_alembic(self): + """upgrade/downgrade functions are entry points.""" + self._add_func("upgrade") + self._add_func("downgrade") + eps = detect_entry_points(self.store) + ep_names = {ep.name for ep in eps} + assert "upgrade" in ep_names + assert "downgrade" in ep_names + + def test_detect_entry_points_lifespan(self): + """FastAPI lifespan function is an entry point.""" + self._add_func("lifespan") + eps = detect_entry_points(self.store) + ep_names = {ep.name for ep in eps} + assert "lifespan" in ep_names + # --------------------------------------------------------------- # trace_flows # --------------------------------------------------------------- diff --git a/tests/test_parser.py b/tests/test_parser.py index 79f4a95..9fece89 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -440,3 +440,147 @@ def test_non_test_file_describe_not_special(self): ) finally: tmp_path.unlink(missing_ok=True) + + def test_python_decorator_extraction(self): + """Decorated Python functions should have decorators in extra.""" + import tempfile + code = b"""\ +from fastapi import APIRouter + +router = APIRouter() + +@router.get("/users") +def get_users(): + return [] + +@router.post("/users") +@some_validator +def create_user(body): + pass + +def plain_func(): + pass +""" + with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as f: + f.write(code) + tmp_path = Path(f.name) + try: + nodes, _ = self.parser.parse_file(tmp_path) + funcs = {n.name: n for n in nodes if n.kind == "Function"} + + assert "get_users" in funcs + assert funcs["get_users"].extra.get("decorators") == [ + 'router.get("/users")', + ] + + assert "create_user" in funcs + decos = funcs["create_user"].extra.get("decorators") + assert len(decos) == 2 + assert 'router.post("/users")' in decos + assert "some_validator" in decos + + assert "plain_func" in funcs + assert not funcs["plain_func"].extra.get("decorators") + finally: + tmp_path.unlink(missing_ok=True) + + def test_python_class_decorator_extraction(self): + """Decorated Python classes should have decorators in extra.""" + import tempfile + code = b"""\ +import dataclasses + +@dataclasses.dataclass +class MyModel: + name: str + +class PlainClass: + pass +""" + with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as f: + f.write(code) + tmp_path = Path(f.name) + try: + nodes, _ = self.parser.parse_file(tmp_path) + classes = {n.name: n for n in nodes if n.kind == "Class"} + + assert "MyModel" in classes + assert classes["MyModel"].extra.get("decorators") == [ + "dataclasses.dataclass", + ] + + assert "PlainClass" in classes + assert not classes["PlainClass"].extra.get("decorators") + finally: + tmp_path.unlink(missing_ok=True) + + def test_tsx_named_import_creates_per_symbol_edges(self): + """import { A, B } from './mod' should create per-symbol IMPORTS_FROM edges.""" + import tempfile + tmp_dir = Path(tempfile.mkdtemp()) + try: + # Source module with exported functions + mod_file = tmp_dir / "mod.ts" + mod_file.write_bytes(b"export function getUsers() { return []; }\n" + b"export function getItems() { return []; }\n") + + # Importer + importer = tmp_dir / "page.tsx" + importer.write_bytes( + b"import { getUsers, getItems } from './mod';\n" + b"export function Page() { return getUsers(); }\n" + ) + + nodes, edges = self.parser.parse_file(importer) + import_edges = [e for e in edges if e.kind == "IMPORTS_FROM"] + + targets = {e.target for e in import_edges} + resolved_mod = str(mod_file.resolve()) + # File-level edge + assert resolved_mod in targets + # Per-symbol edges + assert f"{resolved_mod}::getUsers" in targets + assert f"{resolved_mod}::getItems" in targets + finally: + import shutil + shutil.rmtree(tmp_dir) + + def test_tsx_default_import_creates_per_symbol_edge(self): + """import Foo from './mod' should create a per-symbol IMPORTS_FROM edge.""" + import tempfile + tmp_dir = Path(tempfile.mkdtemp()) + try: + mod_file = tmp_dir / "mod.ts" + mod_file.write_bytes(b"export default function Foo() {}\n") + + importer = tmp_dir / "app.tsx" + importer.write_bytes(b"import Foo from './mod';\n") + + nodes, edges = self.parser.parse_file(importer) + import_edges = [e for e in edges if e.kind == "IMPORTS_FROM"] + targets = {e.target for e in import_edges} + resolved_mod = str(mod_file.resolve()) + assert f"{resolved_mod}::Foo" in targets + finally: + import shutil + shutil.rmtree(tmp_dir) + + def test_tsx_aliased_import_uses_original_name(self): + """import { A as B } should create edge to ::A (original name).""" + import tempfile + tmp_dir = Path(tempfile.mkdtemp()) + try: + mod_file = tmp_dir / "util.ts" + mod_file.write_bytes(b"export function helper() {}\n") + + importer = tmp_dir / "main.tsx" + importer.write_bytes(b"import { helper as h } from './util';\n") + + nodes, edges = self.parser.parse_file(importer) + import_edges = [e for e in edges if e.kind == "IMPORTS_FROM"] + targets = {e.target for e in import_edges} + resolved_mod = str(mod_file.resolve()) + assert f"{resolved_mod}::helper" in targets + finally: + import shutil + shutil.rmtree(tmp_dir) diff --git a/tests/test_refactor.py b/tests/test_refactor.py index 2ee55c4..bf83fc0 100644 --- a/tests/test_refactor.py +++ b/tests/test_refactor.py @@ -191,6 +191,121 @@ def test_find_dead_code_file_pattern(self): dead = find_dead_code(self.store, file_pattern="nonexistent") assert len(dead) == 0 + def test_find_dead_code_excludes_dunder(self): + """Dunder methods are not flagged as dead code.""" + self.store.upsert_node(NodeInfo( + kind="Function", name="__init__", file_path="/repo/app.py", + line_start=90, line_end=95, language="python", + parent_name="MyClass", + )) + self.store.commit() + dead = find_dead_code(self.store) + dead_names = {d["name"] for d in dead} + assert "__init__" not in dead_names + + def test_find_dead_code_excludes_decorated_entry(self): + """Functions with framework decorators are not flagged as dead code.""" + self.store.upsert_node(NodeInfo( + kind="Function", name="get_users", file_path="/repo/app.py", + line_start=90, line_end=95, language="python", + extra={"decorators": ["app.get('/users')"]}, + )) + self.store.commit() + dead = find_dead_code(self.store) + dead_names = {d["name"] for d in dead} + assert "get_users" not in dead_names + + def test_find_dead_code_excludes_type_referenced_class(self): + """Classes referenced in function type annotations are not dead code.""" + self.store.upsert_node(NodeInfo( + kind="Class", name="UserSchema", file_path="/repo/app.py", + line_start=5, line_end=15, language="python", + )) + # A function that uses UserSchema in its params + self.store.upsert_node(NodeInfo( + kind="Function", name="create_user", file_path="/repo/app.py", + line_start=20, line_end=30, language="python", + params="body: UserSchema", + )) + self.store.commit() + dead = find_dead_code(self.store) + dead_names = {d["name"] for d in dead} + assert "UserSchema" not in dead_names + + def test_find_dead_code_excludes_return_type_reference(self): + """Classes referenced in return types are not dead code.""" + self.store.upsert_node(NodeInfo( + kind="Class", name="UserResponse", file_path="/repo/app.py", + line_start=5, line_end=15, language="python", + )) + self.store.upsert_node(NodeInfo( + kind="Function", name="get_user", file_path="/repo/app.py", + line_start=20, line_end=30, language="python", + return_type="Optional[UserResponse]", + )) + self.store.commit() + dead = find_dead_code(self.store) + dead_names = {d["name"] for d in dead} + assert "UserResponse" not in dead_names + + def test_find_dead_code_excludes_orm_model(self): + """Classes inheriting from known ORM bases are not dead code.""" + self.store.upsert_node(NodeInfo( + kind="Class", name="User", file_path="/repo/app.py", + line_start=5, line_end=20, language="python", + )) + self.store.upsert_edge(EdgeInfo( + kind="INHERITS", source="/repo/app.py::User", + target="Base", file_path="/repo/app.py", line=5, + )) + self.store.commit() + dead = find_dead_code(self.store) + dead_names = {d["name"] for d in dead} + assert "User" not in dead_names + + def test_find_dead_code_excludes_pydantic_settings(self): + """Classes inheriting from BaseSettings are not dead code.""" + self.store.upsert_node(NodeInfo( + kind="Class", name="AppConfig", file_path="/repo/app.py", + line_start=5, line_end=15, language="python", + )) + self.store.upsert_edge(EdgeInfo( + kind="INHERITS", source="/repo/app.py::AppConfig", + target="BaseSettings", file_path="/repo/app.py", line=5, + )) + self.store.commit() + dead = find_dead_code(self.store) + dead_names = {d["name"] for d in dead} + assert "AppConfig" not in dead_names + + def test_find_dead_code_excludes_agent_tool(self): + """Functions with @agent.tool decorator are not dead code.""" + self.store.upsert_node(NodeInfo( + kind="Function", name="query_data", file_path="/repo/app.py", + line_start=10, line_end=20, language="python", + extra={"decorators": ["health_agent.tool"]}, + )) + self.store.commit() + dead = find_dead_code(self.store) + dead_names = {d["name"] for d in dead} + assert "query_data" not in dead_names + + def test_find_dead_code_excludes_alembic_upgrade(self): + """upgrade() and downgrade() in alembic files are not dead code.""" + self.store.upsert_node(NodeInfo( + kind="Function", name="upgrade", file_path="/repo/alembic/versions/001.py", + line_start=5, line_end=15, language="python", + )) + self.store.upsert_node(NodeInfo( + kind="Function", name="downgrade", file_path="/repo/alembic/versions/001.py", + line_start=20, line_end=30, language="python", + )) + self.store.commit() + dead = find_dead_code(self.store) + dead_names = {d["name"] for d in dead} + assert "upgrade" not in dead_names + assert "downgrade" not in dead_names + class TestSuggestRefactorings: """Tests for suggest_refactorings."""