From 858800a943420a074f802f85e92c98c72211a080 Mon Sep 17 00:00:00 2001 From: Filippo Menghi Date: Wed, 10 Jun 2026 09:15:11 +0200 Subject: [PATCH 1/4] =?UTF-8?q?feat:=20enum-aware=20schema=20prompts=20?= =?UTF-8?q?=E2=80=94=20legal=20enum=20values=20+=20column=20comments=20in?= =?UTF-8?q?=20the=20prompt?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introspection now reads col_description() and pg_enum labels for every column; format_schema renders them, so the generator filters on real states instead of guessing them. New generation rules pin down answer shape: exactly the columns asked for, no speculative filters, INNER JOIN by default, status columns over timestamp inference. --- src/promptquery/prompts.py | 21 ++++++++++-- src/promptquery/schema.py | 20 +++++++++++- tests/test_prompts_enums.py | 65 +++++++++++++++++++++++++++++++++++++ 3 files changed, 102 insertions(+), 4 deletions(-) create mode 100644 tests/test_prompts_enums.py diff --git a/src/promptquery/prompts.py b/src/promptquery/prompts.py index f0cb3cf..5ded81f 100644 --- a/src/promptquery/prompts.py +++ b/src/promptquery/prompts.py @@ -11,10 +11,18 @@ 2. Use ONLY the tables and columns listed in the schema below. Do not invent columns. 3. NEVER write INSERT, UPDATE, DELETE, DROP, CREATE, ALTER, TRUNCATE, GRANT, COPY, or any DDL/DML. 4. Prefer explicit JOINs over implicit ones. Always qualify columns when joining. -5. If the schema below is insufficient to answer the question, output a single SELECT that +5. SELECT exactly the columns the question asks for — no extra id, timestamp, or + metadata columns added "to be helpful". +6. Apply only the filters the question states. Do not add speculative conditions + (e.g. deleted_at IS NULL, paid_at IS NOT NULL) the question didn't ask for. +7. When a table has a status-like enum column, filter on it using ONLY the values listed + for that column in the schema. Never invent enum values, and don't infer state from + timestamps when a status column exists. +8. Use INNER JOIN unless the question asks to keep rows without a match. +9. If the schema below is insufficient to answer the question, output a single SELECT that returns an error message string (e.g. SELECT 'insufficient schema: missing X' AS error) rather than guessing. -6. Output ONLY the SQL inside a ```sql code block. No prose, no explanation. +10. Output ONLY the SQL inside a ```sql code block. No prose, no explanation. Available schema: {schema} @@ -34,7 +42,14 @@ def format_schema(tables: list[Table]) -> str: if not c.nullable: modifiers.append("NOT NULL") mod_str = f" [{', '.join(modifiers)}]" if modifiers else "" - lines.append(f" {c.name} {c.data_type}{mod_str}") + notes = [] + if c.comment: + notes.append(c.comment) + if c.enum_values: + values = ", ".join(f"'{v}'" for v in c.enum_values) + notes.append(f"one of: {values}") + note_str = f" -- {'; '.join(notes)}" if notes else "" + lines.append(f" {c.name} {c.data_type}{mod_str}{note_str}") for fk in t.foreign_keys: ref = fk.referenced_table if fk.referenced_schema and fk.referenced_schema != "public": diff --git a/src/promptquery/schema.py b/src/promptquery/schema.py index 9258611..9e526c7 100644 --- a/src/promptquery/schema.py +++ b/src/promptquery/schema.py @@ -13,6 +13,10 @@ class Column: data_type: str nullable: bool is_primary_key: bool + comment: str | None = None + # For enum-typed columns: every legal label, in enum order. The generator must see + # these — an enum value it can't see is an enum value it will invent. + enum_values: tuple[str, ...] | None = None def to_dict(self) -> dict: return { @@ -20,15 +24,20 @@ def to_dict(self) -> dict: "data_type": self.data_type, "nullable": self.nullable, "is_primary_key": self.is_primary_key, + "comment": self.comment, + "enum_values": list(self.enum_values) if self.enum_values else None, } @classmethod def from_dict(cls, data: dict) -> "Column": + enum_values = data.get("enum_values") return cls( name=data["name"], data_type=data["data_type"], nullable=bool(data["nullable"]), is_primary_key=bool(data["is_primary_key"]), + comment=data.get("comment"), + enum_values=tuple(enum_values) if enum_values else None, ) @@ -130,10 +139,16 @@ def from_dict(cls, data: dict) -> "Schema": WHERE i.indrelid = c.oid AND i.indisprimary AND a.attnum = ANY(i.indkey) - ) AS is_primary_key + ) AS is_primary_key, + col_description(c.oid, a.attnum) AS comment, + CASE WHEN t.typtype = 'e' THEN ( + SELECT array_agg(e.enumlabel ORDER BY e.enumsortorder) + FROM pg_enum e WHERE e.enumtypid = t.oid + ) END AS enum_values FROM pg_attribute a JOIN pg_class c ON c.oid = a.attrelid JOIN pg_namespace n ON n.oid = c.relnamespace +JOIN pg_type t ON t.oid = a.atttypid WHERE a.attnum > 0 AND NOT a.attisdropped AND c.relkind IN ('r', 'v', 'm', 'p', 'f') @@ -185,11 +200,14 @@ def introspect(db: "Database") -> Schema: table = tables.get(key) if table is None: continue + enum_values = row.get("enum_values") table.columns.append(Column( name=row["name"], data_type=row["data_type"], nullable=bool(row["nullable"]), is_primary_key=bool(row["is_primary_key"]), + comment=row.get("comment"), + enum_values=tuple(enum_values) if enum_values else None, )) for row in fks_rows: diff --git a/tests/test_prompts_enums.py b/tests/test_prompts_enums.py new file mode 100644 index 0000000..0ae8089 --- /dev/null +++ b/tests/test_prompts_enums.py @@ -0,0 +1,65 @@ +"""Tests for enum-aware schema serialization (schema.Column + prompts.format_schema).""" +from __future__ import annotations + +from promptquery.prompts import SYSTEM_PROMPT, format_schema +from promptquery.schema import Column, Table + + +def _orders_table() -> Table: + return Table( + schema="sales", + name="orders", + comment="customer orders", + columns=[ + Column("id", "bigint", False, True), + Column( + "status", "order_status", False, False, + comment="current order state", + enum_values=("pending", "paid", "shipped", "refunded"), + ), + Column("placed_at", "timestamptz", True, False), + ], + ) + + +def test_format_schema_lists_enum_values(): + out = format_schema([_orders_table()]) + assert "one of: 'pending', 'paid', 'shipped', 'refunded'" in out + + +def test_format_schema_renders_column_comment(): + out = format_schema([_orders_table()]) + assert "current order state" in out + + +def test_format_schema_plain_column_unchanged(): + out = format_schema([_orders_table()]) + assert "placed_at timestamptz" in out + # No note suffix on a column without comment or enum: + line = next(ln for ln in out.splitlines() if ln.strip().startswith("placed_at")) + assert "--" not in line + + +def test_column_roundtrip_preserves_enum_values_and_comment(): + c = Column( + "status", "order_status", False, False, + comment="state", enum_values=("a", "b"), + ) + restored = Column.from_dict(c.to_dict()) + assert restored == c + + +def test_column_from_dict_tolerates_old_payloads(): + # Payloads serialized before 0.3.0 carry neither field. + c = Column.from_dict( + {"name": "x", "data_type": "int", "nullable": True, "is_primary_key": False} + ) + assert c.comment is None + assert c.enum_values is None + + +def test_system_prompt_carries_answer_shape_rules(): + assert "exactly the columns the question asks for" in SYSTEM_PROMPT + assert "Never invent enum values" in SYSTEM_PROMPT + assert "INNER JOIN" in SYSTEM_PROMPT + assert "speculative conditions" in SYSTEM_PROMPT From 6d388fb781838f592fa02927fd7b645d39676cfe Mon Sep 17 00:00:00 2001 From: Filippo Menghi Date: Wed, 10 Jun 2026 09:15:11 +0200 Subject: [PATCH 2/4] feat: execution-guided self-repair (--max-repair, default 1) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When the database rejects a query, feed the SQL plus the database's own error back to the model for a bounded number of corrected attempts. Repaired SQL is re-validated by the sqlglot guard and re-confirmed in the REPL before it runs. Empty results never trigger repair — empty is often the right answer. --- src/promptquery/cli.py | 55 ++++++++++++----- src/promptquery/repair.py | 96 ++++++++++++++++++++++++++++ tests/test_repair.py | 127 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 264 insertions(+), 14 deletions(-) create mode 100644 src/promptquery/repair.py create mode 100644 tests/test_repair.py diff --git a/src/promptquery/cli.py b/src/promptquery/cli.py index 2644c15..bb1bdf8 100644 --- a/src/promptquery/cli.py +++ b/src/promptquery/cli.py @@ -18,6 +18,7 @@ from .llm import LLMClient, LLMError, extract_sql, make_client from .prompts import build_system_prompt from .render import render_results, render_sql +from .repair import execute_with_repair from .retrieval import TfIdfRetriever, expand_via_fks, llm_select_tables from .safety import UnsafeQuery, validate_select_only from .schema import Schema, introspect @@ -87,6 +88,7 @@ def run_question( select_n: int, max_tables: int, confirm: bool, + max_repair: int = 1, progress: Console | None = None, prompt_for_confirm=None, ) -> QueryResult: @@ -139,12 +141,30 @@ def run_question( if answer not in {"y", "yes"}: return QueryResult(Outcome.SKIPPED, sql, [], [], "user declined") - try: - cols, rows = db.execute(sql) - except Exception as e: - return QueryResult(Outcome.EXEC_ERROR, sql, [], [], str(e)) + def _confirm_repaired(repaired_sql: str) -> bool: + if not confirm: + return True + if progress: + render_sql(progress, repaired_sql) + if prompt_for_confirm is None: + return False + return prompt_for_confirm().strip().lower() in {"y", "yes"} + + result = execute_with_repair( + db, llm, system_prompt, question, sql, + max_repair=max_repair, + confirm_cb=_confirm_repaired if confirm else None, + progress_cb=(lambda msg: progress.print(f"[yellow]{msg}[/yellow]")) if progress else None, + ) + if result.declined: + return QueryResult(Outcome.SKIPPED, result.sql, [], [], "user declined") + if result.error is not None: + return QueryResult(Outcome.EXEC_ERROR, result.sql, [], [], result.error) + if result.attempts and progress and not confirm: + # Show the repaired SQL that actually ran (confirm mode already displayed it). + render_sql(progress, result.sql) - return QueryResult(Outcome.OK, sql, cols, rows) + return QueryResult(Outcome.OK, result.sql, result.cols, result.rows) # -- CLI ------------------------------------------------------------------- @@ -194,6 +214,13 @@ def run_question( show_default=True, help="Maximum tables sent to the LLM after FK expansion.", ) +@click.option( + "--max-repair", + default=1, + show_default=True, + help="Repair rounds when a query fails: feed the database error back to the model " + "and retry. 0 disables repair.", +) @click.option( "--no-selector", is_flag=True, @@ -208,7 +235,7 @@ def run_question( @click.version_option(__version__, prog_name="promptquery") def main(dsn: str, model: str | None, selector_model: str | None, query: str | None, out_format: str | None, - top_k: int, select_n: int, max_tables: int, + top_k: int, select_n: int, max_tables: int, max_repair: int, no_selector: bool, yes: bool) -> None: """PromptQuery — natural-language SQL for Postgres. @@ -276,27 +303,27 @@ def main(dsn: str, model: str | None, selector_model: str | None, question=query, schema=schema, retriever=retriever, llm=llm, selector_llm=selector_llm, db=db_ctx, top_k=top_k, select_n=select_n, max_tables=max_tables, - out_fmt=out_fmt, progress=progress, + max_repair=max_repair, out_fmt=out_fmt, progress=progress, )) _run_repl( schema=schema, retriever=retriever, llm=llm, selector_llm=selector_llm, db=db_ctx, top_k=top_k, select_n=select_n, max_tables=max_tables, - out_fmt=out_fmt, yes=yes, progress=progress, + max_repair=max_repair, out_fmt=out_fmt, yes=yes, progress=progress, ) finally: db_ctx.close() def _run_one_shot(*, question, schema, retriever, llm, selector_llm, db, - top_k, select_n, max_tables, out_fmt: OutputFormat, - progress: Console) -> int: + top_k, select_n, max_tables, max_repair, + out_fmt: OutputFormat, progress: Console) -> int: result = run_question( question=question, schema=schema, retriever=retriever, llm=llm, selector_llm=selector_llm, db=db, top_k=top_k, select_n=select_n, max_tables=max_tables, - confirm=False, progress=progress, + max_repair=max_repair, confirm=False, progress=progress, ) if result.outcome is Outcome.LLM_ERROR: @@ -325,8 +352,8 @@ def _run_one_shot(*, question, schema, retriever, llm, selector_llm, db, def _run_repl(*, schema, retriever, llm, selector_llm, db, - top_k, select_n, max_tables, out_fmt: OutputFormat, - yes: bool, progress: Console) -> None: + top_k, select_n, max_tables, max_repair, + out_fmt: OutputFormat, yes: bool, progress: Console) -> None: session: PromptSession[str] = PromptSession(history=InMemoryHistory()) progress.print( "\n[bold]PromptQuery[/bold] — ask a question in plain English, " @@ -348,7 +375,7 @@ def _run_repl(*, schema, retriever, llm, selector_llm, db, question=question, schema=schema, retriever=retriever, llm=llm, selector_llm=selector_llm, db=db, top_k=top_k, select_n=select_n, max_tables=max_tables, - confirm=not yes, progress=progress, + max_repair=max_repair, confirm=not yes, progress=progress, prompt_for_confirm=(lambda: session.prompt("Run? [y/N] ")) if not yes else None, ) diff --git a/src/promptquery/repair.py b/src/promptquery/repair.py new file mode 100644 index 0000000..f71c504 --- /dev/null +++ b/src/promptquery/repair.py @@ -0,0 +1,96 @@ +"""Execution-guided repair: run the generated SQL and, if the database rejects it, +feed the database's own error back to the model for a bounded number of repair rounds. + +Only hard execution errors trigger a repair. Empty results don't — an empty result is +often the correct answer, and "fixing" it risks replacing a right answer with a wrong one. +""" +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Callable + +from .llm import LLMClient, extract_sql +from .safety import UnsafeQuery, validate_select_only + +if TYPE_CHECKING: + from .db import Database + + +REPAIR_PROMPT = """\ +The query below failed when run against the database. Fix it and return a corrected +query, following all the rules above (including the schema's enum value lists). + +Question: {question} + +Failed SQL: +{sql} + +Database error: +{error} + +Output ONLY the corrected SQL inside a ```sql code block.""" + + +@dataclass +class RepairResult: + sql: str # the SQL that was last attempted + cols: list[str] = field(default_factory=list) + rows: list[tuple] = field(default_factory=list) + error: str | None = None # None = executed successfully + attempts: int = 0 # repair rounds actually used + declined: bool = False # user declined a repaired query (REPL confirm) + + +def execute_with_repair( + db: "Database", + llm: LLMClient, + system_prompt: str, + question: str, + sql: str, + *, + max_repair: int = 1, + confirm_cb: Callable[[str], bool] | None = None, + progress_cb: Callable[[str], None] | None = None, +) -> RepairResult: + """Execute `sql`; on a database error, ask the model to repair it (≤ max_repair rounds). + + Repaired SQL goes through validate_select_only before it is ever executed — the + safety layer applies to every attempt, not just the first. If a repair fails safety, + comes back empty, or the model errors, the original failure is returned unchanged. + `confirm_cb` (REPL confirm mode) is asked before any repaired query runs. + """ + attempts = 0 + current = sql + while True: + try: + cols, rows = db.execute(current) + return RepairResult(current, cols, rows, None, attempts) + except Exception as e: # psycopg errors carry the message we want to feed back + error = str(e) + + if attempts >= max_repair: + return RepairResult(current, error=error, attempts=attempts) + attempts += 1 + if progress_cb: + progress_cb(f"query failed ({error.splitlines()[0]}); repairing " + f"[{attempts}/{max_repair}]") + + try: + raw = llm.generate( + system_prompt, + REPAIR_PROMPT.format(question=question, sql=current, error=error), + ) + except Exception: + return RepairResult(current, error=error, attempts=attempts) + + repaired = extract_sql(raw) + if not repaired or repaired == current: + return RepairResult(current, error=error, attempts=attempts) + try: + validate_select_only(repaired) + except UnsafeQuery: + return RepairResult(current, error=error, attempts=attempts) + + if confirm_cb is not None and not confirm_cb(repaired): + return RepairResult(repaired, error=error, attempts=attempts, declined=True) + current = repaired diff --git a/tests/test_repair.py b/tests/test_repair.py new file mode 100644 index 0000000..a043889 --- /dev/null +++ b/tests/test_repair.py @@ -0,0 +1,127 @@ +"""Tests for the execution-guided repair loop (promptquery.repair).""" +from __future__ import annotations + +from promptquery.repair import REPAIR_PROMPT, execute_with_repair + + +class _ScriptedLLM: + """Returns canned responses in order; records every prompt it was sent.""" + name = "fake" + model = "fake-1" + + def __init__(self, responses: list[str]): + self._responses = list(responses) + self.calls: list[tuple[str, str]] = [] + + def generate(self, system: str, user: str) -> str: + self.calls.append((system, user)) + if not self._responses: + raise RuntimeError("no scripted responses left") + return self._responses.pop(0) + + +class _FlakyDB: + """Raises for SQL in `failing`, succeeds otherwise.""" + + def __init__(self, failing: dict[str, str], rows=None): + self._failing = failing + self._rows = rows or [(1,)] + self.executed: list[str] = [] + + def execute(self, sql: str): + self.executed.append(sql) + if sql in self._failing: + raise RuntimeError(self._failing[sql]) + return (["count"], self._rows) + + +def test_no_repair_needed_executes_once(): + db = _FlakyDB(failing={}) + llm = _ScriptedLLM([]) + result = execute_with_repair(db, llm, "sys", "q", "SELECT 1", max_repair=1) + assert result.error is None + assert result.attempts == 0 + assert result.sql == "SELECT 1" + assert llm.calls == [] # no LLM round when the query just works + + +def test_repairs_on_db_error_and_succeeds(): + bad = "SELECT count(*) FROM t WHERE status = 'overdue'" + good = "SELECT count(*) FROM t WHERE status = 'past_due'" + db = _FlakyDB(failing={bad: 'invalid input value for enum invoice_status: "overdue"'}) + llm = _ScriptedLLM([f"```sql\n{good}\n```"]) + + result = execute_with_repair(db, llm, "sys", "overdue invoices?", bad, max_repair=1) + + assert result.error is None + assert result.attempts == 1 + assert result.sql == good + assert db.executed == [bad, good] + # The repair prompt carried the failed SQL and the DB's own error message. + _, user = llm.calls[0] + assert bad in user and "invalid input value" in user + + +def test_gives_up_after_max_repair(): + bad = "SELECT broken" + db = _FlakyDB(failing={bad: "syntax error", "SELECT also_broken": "still broken"}) + llm = _ScriptedLLM(["```sql\nSELECT also_broken\n```"]) + + result = execute_with_repair(db, llm, "sys", "q", bad, max_repair=1) + + assert result.error == "still broken" + assert result.attempts == 1 + + +def test_max_repair_zero_disables_repair(): + bad = "SELECT broken" + db = _FlakyDB(failing={bad: "syntax error"}) + llm = _ScriptedLLM(["```sql\nSELECT 1\n```"]) + + result = execute_with_repair(db, llm, "sys", "q", bad, max_repair=0) + + assert result.error == "syntax error" + assert result.attempts == 0 + assert llm.calls == [] + + +def test_unsafe_repair_is_never_executed(): + bad = "SELECT broken" + db = _FlakyDB(failing={bad: "syntax error"}) + llm = _ScriptedLLM(["```sql\nDELETE FROM t\n```"]) + + result = execute_with_repair(db, llm, "sys", "q", bad, max_repair=1) + + assert result.error == "syntax error" # original failure, unchanged + assert db.executed == [bad] # the DELETE never reached the database + + +def test_identical_repair_bails_out(): + bad = "SELECT broken" + db = _FlakyDB(failing={bad: "syntax error"}) + llm = _ScriptedLLM([f"```sql\n{bad}\n```"]) + + result = execute_with_repair(db, llm, "sys", "q", bad, max_repair=3) + + assert result.error == "syntax error" + assert db.executed == [bad] # no point re-running the same statement + + +def test_confirm_callback_can_decline_repaired_sql(): + bad = "SELECT broken" + good = "SELECT 1" + db = _FlakyDB(failing={bad: "syntax error"}) + llm = _ScriptedLLM([f"```sql\n{good}\n```"]) + + result = execute_with_repair( + db, llm, "sys", "q", bad, max_repair=1, confirm_cb=lambda sql: False, + ) + + assert result.declined is True + assert db.executed == [bad] # declined repair never ran + + +def test_repair_prompt_mentions_enum_rule(): + # The repair message nudges the model back to the schema's enum lists — the + # dominant hard-error class the loop exists to fix. + assert "enum" in REPAIR_PROMPT From a2a99a5286761066609ce7adb17f92e30d931473 Mon Sep 17 00:00:00 2001 From: Filippo Menghi Date: Wed, 10 Jun 2026 09:15:11 +0200 Subject: [PATCH 3/4] fix: infer OpenAI provider for bare o4-* model names --- src/promptquery/llm.py | 2 +- tests/test_make_client.py | 27 +++++++++++++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) create mode 100644 tests/test_make_client.py diff --git a/src/promptquery/llm.py b/src/promptquery/llm.py index de30afa..761a483 100644 --- a/src/promptquery/llm.py +++ b/src/promptquery/llm.py @@ -103,7 +103,7 @@ def make_client(model_spec: str | None = None) -> LLMClient: # No explicit provider — guess from model name if provider.startswith("claude"): return AnthropicClient(model=provider) - if provider.startswith("gpt") or provider.startswith("o1") or provider.startswith("o3"): + if provider.startswith(("gpt", "o1", "o3", "o4")): return OpenAIClient(model=provider) raise LLMError( f"Cannot infer provider from model {provider!r}. " diff --git a/tests/test_make_client.py b/tests/test_make_client.py new file mode 100644 index 0000000..f5580f1 --- /dev/null +++ b/tests/test_make_client.py @@ -0,0 +1,27 @@ +"""Tests for make_client provider inference (llm.make_client).""" +from __future__ import annotations + +import pytest + +from promptquery.llm import LLMError, make_client + + +def test_bare_o4_model_infers_openai(monkeypatch): + pytest.importorskip("openai") + monkeypatch.setenv("OPENAI_API_KEY", "test-key") + client = make_client("o4-mini") + assert client.name == "openai" + assert client.model == "o4-mini" + + +def test_bare_gpt_model_infers_openai(monkeypatch): + pytest.importorskip("openai") + monkeypatch.setenv("OPENAI_API_KEY", "test-key") + client = make_client("gpt-4o") + assert client.name == "openai" + assert client.model == "gpt-4o" + + +def test_unknown_bare_model_still_raises(): + with pytest.raises(LLMError, match="Cannot infer provider"): + make_client("mistral-large") From 29fa4a6d87fe2b3d184476cb59379ddcc9e5e7bb Mon Sep 17 00:00:00 2001 From: Filippo Menghi Date: Wed, 10 Jun 2026 09:15:11 +0200 Subject: [PATCH 4/4] =?UTF-8?q?chore:=200.3.0=20=E2=80=94=20changelog,=20p?= =?UTF-8?q?ipeline=20docs,=20reconcile=20stale=20v0.1-era=20claims?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ARCHITECTURE.md | 20 ++++++++++++++------ CHANGELOG.md | 29 +++++++++++++++++++++++++++++ README.md | 17 ++++++++++++++--- pyproject.toml | 2 +- src/promptquery/__init__.py | 2 +- 5 files changed, 59 insertions(+), 11 deletions(-) diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index b74848d..d169138 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -26,7 +26,8 @@ question ▼ ┌──────────────────┐ │ prompts.py │ Render the chosen tables into the system prompt: -│ format_schema │ TABLE name, columns with PK/NOT NULL flags, FKs. +│ format_schema │ TABLE name, columns with PK/NOT NULL flags, column +│ │ comments, the legal values of every enum column, FKs. └────────┬─────────┘ │ ▼ @@ -66,7 +67,12 @@ question │ │ statement_timeout, so even if safety.py failed │ │ the database itself would refuse a write. └────────┬─────────┘ - │ + │ on a database error: repair.py feeds the failed SQL plus the + │ database's own error message back to the model for up to + │ --max-repair rounds (default 1). Every repaired query goes + │ through validate_select_only — and the confirm prompt, in the + │ REPL — before it is executed. Empty results never trigger a + │ repair: an empty result is often the right answer. ▼ ┌──────────────────┐ │ render.py │ Format as a rich.Table. NULLs styled, large blobs @@ -86,6 +92,7 @@ question | `retrieval.py` | Tokenizer, TF-IDF ranker, FK-graph expander. | | `llm.py` | Provider clients (Anthropic, OpenAI), SQL extractor, provider factory. | | `prompts.py` | System prompt template and schema-to-prompt formatter. | +| `repair.py` | Execution-guided repair: bounded retry loop fed by the database's errors. | | `safety.py` | The sqlglot-based query guard. | | `render.py` | SQL syntax rendering and result-table rendering with rich. | @@ -96,12 +103,13 @@ question - **`cli.py` ↔ everything else** — the only file that knows the full pipeline. New stages (query history, post-execution feedback) wire in here. - **`safety.py` ↔ `llm.py`** — `extract_sql` runs first; `validate_select_only` runs second. Together they handle the case where the model returns malformed output. - **`db.py` ↔ `safety.py`** — two layers, intentionally redundant. Either alone is insufficient; both together make a write impossible. +- **`repair.py` ↔ `safety.py`** — every repaired query is re-validated before execution. The repair loop widens what the model can fix; it never widens what can run. ## Design bets ### Why TF-IDF, not embeddings (yet) -TF-IDF works the moment you connect to a database. No model to download, no GPU, no API call to compute embeddings for hundreds of tables. The cost is that it cannot reason about synonyms — "customer" and "user" are different tokens to TF-IDF. That is the tradeoff v0.2 will revisit by adding embedding-based ranking as an optional layer on top. +TF-IDF works the moment you connect to a database. No model to download, no GPU, no API call to compute embeddings for hundreds of tables. The cost is that it cannot reason about synonyms — "customer" and "user" are different tokens to TF-IDF. Since 0.2 the LLM table-selector covers that gap (it sees the TF-IDF candidates and picks semantically), and measured retrieval recall on a 211-table benchmark sits at 98–100% without embeddings. Embeddings stay off the default path until the data shows a gap they would close. ### Why a separate FK-expansion pass @@ -117,9 +125,9 @@ This redundancy is not paranoia. AI-generated SQL is, by construction, less pred `pg_catalog` is faster, more complete, and lets us read table comments via `obj_description`. We use `LATERAL unnest WITH ORDINALITY` to join `pg_constraint.conkey` with `pg_constraint.confkey` by ordinal position — the only correct way to handle composite foreign keys. -## What's not in v0.1 +## What's intentionally not here -These are intentionally out of scope for the MVP. They are tracked in the [roadmap](README.md#roadmap). +These are out of scope for now. They are tracked in the [roadmap](README.md#roadmap). - MySQL / SQLite support — needs an adapter abstraction first. - Multi-database sessions in one REPL. @@ -136,7 +144,7 @@ Run the test suite: pytest ``` -All tests in v0.1 are pure Python — no live database required. The integration test harness (docker-compose + a real Postgres) is queued for v0.2 alongside the public benchmark suite against Spider / BIRD. +All tests are pure Python — no live database or API key required. The repair loop, prompt serialization, safety guard, retrieval, and CLI outcome logic are covered with in-memory fakes. The most safety-critical file is `tests/test_safety.py`. Cases there encode "things the validator MUST reject." Add cases when you discover new attack vectors; do not delete cases during refactors. diff --git a/CHANGELOG.md b/CHANGELOG.md index a887b23..299bffe 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,34 @@ # Changelog +## 0.3.0 — 2026-06-10 + +Generation-quality release. The retrieval side has been measuring at 98–100% recall on a 211-table +benchmark for a while; the misses were in the SQL itself. This release attacks the three error +classes that benchmarking surfaced. + +- **Enum-aware schema prompts.** The schema sent to the model now includes column comments and the + full legal value list of every enum column (read from `pg_catalog`, cached with the schema). The + single biggest failure class in our benchmarking was the model inventing enum values + (`'overdue'`, `'churned'`) or dodging a status column it couldn't see into timestamp guesses + (`delivered_at IS NOT NULL` instead of `status = 'delivered'`). Now it sees the real vocabulary. + (`src/promptquery/schema.py`, `src/promptquery/prompts.py`) +- **Execution-guided self-repair.** When the database rejects a generated query, the SQL plus the + database's own error message go back to the model for a corrected attempt — `--max-repair` rounds, + default 1, `0` disables. Repaired SQL passes through the same sqlglot safety validator (and the + REPL confirm prompt) before it ever runs. Empty results deliberately do *not* trigger repair: an + empty result is often the correct answer, and "fixing" it risks replacing a right answer with a + wrong one. (`src/promptquery/repair.py`, new) +- **Tighter generation rules.** The system prompt now pins down answer shape: return exactly the + columns asked for, no speculative filters (`deleted_at IS NULL` nobody asked about), INNER JOIN + unless the question implies otherwise, and filter state via the status column when one exists. + (`src/promptquery/prompts.py`) +- **Fix:** bare `o4-*` model names now infer the OpenAI provider, matching `o1`/`o3` + (`--model o4-mini` works without the `openai/` prefix). (`src/promptquery/llm.py`) +- **Tests** 48 → 66: repair-loop behavior (including "unsafe repairs never execute" and + "declined repairs never run"), enum serialization round-trips, provider inference. +- **Docs**: ARCHITECTURE.md pipeline updated for the repair stage and enum-aware prompts; stale + v0.1-era claims (embeddings "queued for v0.2", old test counts) reconciled with reality. + ## 0.2.2 — 2026-06-04 - **Deterministic by default.** SQL generation now runs at `temperature = 0` (plus a fixed `seed` on diff --git a/README.md b/README.md index 7d59d92..423dc4e 100644 --- a/README.md +++ b/README.md @@ -122,8 +122,9 @@ question │ ▼ ┌───────────────────┐ -│ SQL generator │ Your real LLM call. Receives ~25 tables, not 675. -│ (frontier model) │ +│ SQL generator │ Your real LLM call. Receives ~25 tables, not 675 — +│ (frontier model) │ with column comments and every enum's legal values, +│ │ so it filters on real states instead of guessing them. └────────┬──────────┘ │ ▼ @@ -134,6 +135,15 @@ question │ ▼ "Run? [y/N]" → execute against a read-only Postgres session + │ + ▼ +┌───────────────────┐ +│ Self-repair │ If the database rejects the query, the error message +│ (on error only) │ goes back to the model for one corrected attempt +│ │ (--max-repair). Repaired SQL is re-validated and +│ │ re-confirmed before it runs. Empty results are never +│ │ "repaired" — empty is often the right answer. +└───────────────────┘ ``` See [ARCHITECTURE.md](ARCHITECTURE.md) for the deep dive (file inventory, design bets, the patent-landmine non-goals). @@ -149,6 +159,7 @@ See [ARCHITECTURE.md](ARCHITECTURE.md) for the deep dive (file inventory, design | `--top-k` | 50 | TF-IDF candidates passed to the LLM selector | | `--select` | 15 | Tables the LLM selector picks from those candidates | | `--max-tables` | 25 | Cap after FK expansion — what the SQL generator actually sees | +| `--max-repair` | 1 | Repair rounds when the database rejects a query (0 disables) | | `--no-selector` | — | Skip the LLM selector (v0.1 behaviour: TF-IDF + FK only) | | `-y, --yes` | — | Skip the confirmation prompt before running | @@ -255,7 +266,7 @@ python3.12 -m venv .venv .venv/bin/python -m eval.retrieval ``` -37 tests, all pure-Python — no live database or API key required for the core suite. +66 tests, all pure-Python — no live database or API key required for the core suite. --- diff --git a/pyproject.toml b/pyproject.toml index beca4d9..145bc3d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "promptquery" -version = "0.2.2" +version = "0.3.0" description = "Natural-language SQL for production-scale Postgres schemas" readme = "README.md" license = { text = "Apache-2.0" } diff --git a/src/promptquery/__init__.py b/src/promptquery/__init__.py index b5fdc75..493f741 100644 --- a/src/promptquery/__init__.py +++ b/src/promptquery/__init__.py @@ -1 +1 @@ -__version__ = "0.2.2" +__version__ = "0.3.0"