Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions ARCHITECTURE.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ question
┌──────────────────┐
│ prompts.py │ Render the chosen tables into the system prompt:
│ format_schema │ TABLE name, columns with PK/NOT NULL flags, FKs.
│ format_schema │ TABLE name, columns with PK/NOT NULL flags, column
│ │ comments, the legal values of every enum column, FKs.
└────────┬─────────┘
Expand Down Expand Up @@ -66,7 +67,12 @@ question
│ │ statement_timeout, so even if safety.py failed
│ │ the database itself would refuse a write.
└────────┬─────────┘
│ on a database error: repair.py feeds the failed SQL plus the
│ database's own error message back to the model for up to
│ --max-repair rounds (default 1). Every repaired query goes
│ through validate_select_only — and the confirm prompt, in the
│ REPL — before it is executed. Empty results never trigger a
│ repair: an empty result is often the right answer.
┌──────────────────┐
│ render.py │ Format as a rich.Table. NULLs styled, large blobs
Expand All @@ -86,6 +92,7 @@ question
| `retrieval.py` | Tokenizer, TF-IDF ranker, FK-graph expander. |
| `llm.py` | Provider clients (Anthropic, OpenAI), SQL extractor, provider factory. |
| `prompts.py` | System prompt template and schema-to-prompt formatter. |
| `repair.py` | Execution-guided repair: bounded retry loop fed by the database's errors. |
| `safety.py` | The sqlglot-based query guard. |
| `render.py` | SQL syntax rendering and result-table rendering with rich. |

Expand All @@ -96,12 +103,13 @@ question
- **`cli.py` ↔ everything else** — the only file that knows the full pipeline. New stages (query history, post-execution feedback) wire in here.
- **`safety.py` ↔ `llm.py`** — `extract_sql` runs first; `validate_select_only` runs second. Together they handle the case where the model returns malformed output.
- **`db.py` ↔ `safety.py`** — two layers, intentionally redundant. Either alone is insufficient; both together make a write impossible.
- **`repair.py` ↔ `safety.py`** — every repaired query is re-validated before execution. The repair loop widens what the model can fix; it never widens what can run.

## Design bets

### Why TF-IDF, not embeddings (yet)

TF-IDF works the moment you connect to a database. No model to download, no GPU, no API call to compute embeddings for hundreds of tables. The cost is that it cannot reason about synonyms — "customer" and "user" are different tokens to TF-IDF. That is the tradeoff v0.2 will revisit by adding embedding-based ranking as an optional layer on top.
TF-IDF works the moment you connect to a database. No model to download, no GPU, no API call to compute embeddings for hundreds of tables. The cost is that it cannot reason about synonyms — "customer" and "user" are different tokens to TF-IDF. Since 0.2 the LLM table-selector covers that gap (it sees the TF-IDF candidates and picks semantically), and measured retrieval recall on a 211-table benchmark sits at 98–100% without embeddings. Embeddings stay off the default path until the data shows a gap they would close.

### Why a separate FK-expansion pass

Expand All @@ -117,9 +125,9 @@ This redundancy is not paranoia. AI-generated SQL is, by construction, less pred

`pg_catalog` is faster, more complete, and lets us read table comments via `obj_description`. We use `LATERAL unnest WITH ORDINALITY` to join `pg_constraint.conkey` with `pg_constraint.confkey` by ordinal position — the only correct way to handle composite foreign keys.

## What's not in v0.1
## What's intentionally not here

These are intentionally out of scope for the MVP. They are tracked in the [roadmap](README.md#roadmap).
These are out of scope for now. They are tracked in the [roadmap](README.md#roadmap).

- MySQL / SQLite support — needs an adapter abstraction first.
- Multi-database sessions in one REPL.
Expand All @@ -136,7 +144,7 @@ Run the test suite:
pytest
```

All tests in v0.1 are pure Python — no live database required. The integration test harness (docker-compose + a real Postgres) is queued for v0.2 alongside the public benchmark suite against Spider / BIRD.
All tests are pure Python — no live database or API key required. The repair loop, prompt serialization, safety guard, retrieval, and CLI outcome logic are covered with in-memory fakes.

The most safety-critical file is `tests/test_safety.py`. Cases there encode "things the validator MUST reject." Add cases when you discover new attack vectors; do not delete cases during refactors.

Expand Down
29 changes: 29 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,34 @@
# Changelog

## 0.3.0 — 2026-06-10

Generation-quality release. The retrieval side has been measuring at 98–100% recall on a 211-table
benchmark for a while; the misses were in the SQL itself. This release attacks the three error
classes that benchmarking surfaced.

- **Enum-aware schema prompts.** The schema sent to the model now includes column comments and the
full legal value list of every enum column (read from `pg_catalog`, cached with the schema). The
single biggest failure class in our benchmarking was the model inventing enum values
(`'overdue'`, `'churned'`) or dodging a status column it couldn't see into timestamp guesses
(`delivered_at IS NOT NULL` instead of `status = 'delivered'`). Now it sees the real vocabulary.
(`src/promptquery/schema.py`, `src/promptquery/prompts.py`)
- **Execution-guided self-repair.** When the database rejects a generated query, the SQL plus the
database's own error message go back to the model for a corrected attempt — `--max-repair` rounds,
default 1, `0` disables. Repaired SQL passes through the same sqlglot safety validator (and the
REPL confirm prompt) before it ever runs. Empty results deliberately do *not* trigger repair: an
empty result is often the correct answer, and "fixing" it risks replacing a right answer with a
wrong one. (`src/promptquery/repair.py`, new)
- **Tighter generation rules.** The system prompt now pins down answer shape: return exactly the
columns asked for, no speculative filters (`deleted_at IS NULL` nobody asked about), INNER JOIN
unless the question implies otherwise, and filter state via the status column when one exists.
(`src/promptquery/prompts.py`)
- **Fix:** bare `o4-*` model names now infer the OpenAI provider, matching `o1`/`o3`
(`--model o4-mini` works without the `openai/` prefix). (`src/promptquery/llm.py`)
- **Tests** 48 → 66: repair-loop behavior (including "unsafe repairs never execute" and
"declined repairs never run"), enum serialization round-trips, provider inference.
- **Docs**: ARCHITECTURE.md pipeline updated for the repair stage and enum-aware prompts; stale
v0.1-era claims (embeddings "queued for v0.2", old test counts) reconciled with reality.

## 0.2.2 — 2026-06-04

- **Deterministic by default.** SQL generation now runs at `temperature = 0` (plus a fixed `seed` on
Expand Down
17 changes: 14 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,9 @@ question
┌───────────────────┐
│ SQL generator │ Your real LLM call. Receives ~25 tables, not 675.
│ (frontier model) │
│ SQL generator │ Your real LLM call. Receives ~25 tables, not 675 —
│ (frontier model) │ with column comments and every enum's legal values,
│ │ so it filters on real states instead of guessing them.
└────────┬──────────┘
Expand All @@ -134,6 +135,15 @@ question
"Run? [y/N]" → execute against a read-only Postgres session
┌───────────────────┐
│ Self-repair │ If the database rejects the query, the error message
│ (on error only) │ goes back to the model for one corrected attempt
│ │ (--max-repair). Repaired SQL is re-validated and
│ │ re-confirmed before it runs. Empty results are never
│ │ "repaired" — empty is often the right answer.
└───────────────────┘
```

See [ARCHITECTURE.md](ARCHITECTURE.md) for the deep dive (file inventory, design bets, the patent-landmine non-goals).
Expand All @@ -149,6 +159,7 @@ See [ARCHITECTURE.md](ARCHITECTURE.md) for the deep dive (file inventory, design
| `--top-k` | 50 | TF-IDF candidates passed to the LLM selector |
| `--select` | 15 | Tables the LLM selector picks from those candidates |
| `--max-tables` | 25 | Cap after FK expansion — what the SQL generator actually sees |
| `--max-repair` | 1 | Repair rounds when the database rejects a query (0 disables) |
| `--no-selector` | — | Skip the LLM selector (v0.1 behaviour: TF-IDF + FK only) |
| `-y, --yes` | — | Skip the confirmation prompt before running |

Expand Down Expand Up @@ -255,7 +266,7 @@ python3.12 -m venv .venv
.venv/bin/python -m eval.retrieval
```

37 tests, all pure-Python — no live database or API key required for the core suite.
66 tests, all pure-Python — no live database or API key required for the core suite.

---

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "promptquery"
version = "0.2.2"
version = "0.3.0"
description = "Natural-language SQL for production-scale Postgres schemas"
readme = "README.md"
license = { text = "Apache-2.0" }
Expand Down
2 changes: 1 addition & 1 deletion src/promptquery/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.2.2"
__version__ = "0.3.0"
55 changes: 41 additions & 14 deletions src/promptquery/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .llm import LLMClient, LLMError, extract_sql, make_client
from .prompts import build_system_prompt
from .render import render_results, render_sql
from .repair import execute_with_repair
from .retrieval import TfIdfRetriever, expand_via_fks, llm_select_tables
from .safety import UnsafeQuery, validate_select_only
from .schema import Schema, introspect
Expand Down Expand Up @@ -87,6 +88,7 @@ def run_question(
select_n: int,
max_tables: int,
confirm: bool,
max_repair: int = 1,
progress: Console | None = None,
prompt_for_confirm=None,
) -> QueryResult:
Expand Down Expand Up @@ -139,12 +141,30 @@ def run_question(
if answer not in {"y", "yes"}:
return QueryResult(Outcome.SKIPPED, sql, [], [], "user declined")

try:
cols, rows = db.execute(sql)
except Exception as e:
return QueryResult(Outcome.EXEC_ERROR, sql, [], [], str(e))
def _confirm_repaired(repaired_sql: str) -> bool:
if not confirm:
return True
if progress:
render_sql(progress, repaired_sql)
if prompt_for_confirm is None:
return False
return prompt_for_confirm().strip().lower() in {"y", "yes"}

result = execute_with_repair(
db, llm, system_prompt, question, sql,
max_repair=max_repair,
confirm_cb=_confirm_repaired if confirm else None,
progress_cb=(lambda msg: progress.print(f"[yellow]{msg}[/yellow]")) if progress else None,
)
if result.declined:
return QueryResult(Outcome.SKIPPED, result.sql, [], [], "user declined")
if result.error is not None:
return QueryResult(Outcome.EXEC_ERROR, result.sql, [], [], result.error)
if result.attempts and progress and not confirm:
# Show the repaired SQL that actually ran (confirm mode already displayed it).
render_sql(progress, result.sql)

return QueryResult(Outcome.OK, sql, cols, rows)
return QueryResult(Outcome.OK, result.sql, result.cols, result.rows)


# -- CLI -------------------------------------------------------------------
Expand Down Expand Up @@ -194,6 +214,13 @@ def run_question(
show_default=True,
help="Maximum tables sent to the LLM after FK expansion.",
)
@click.option(
"--max-repair",
default=1,
show_default=True,
help="Repair rounds when a query fails: feed the database error back to the model "
"and retry. 0 disables repair.",
)
@click.option(
"--no-selector",
is_flag=True,
Expand All @@ -208,7 +235,7 @@ def run_question(
@click.version_option(__version__, prog_name="promptquery")
def main(dsn: str, model: str | None, selector_model: str | None,
query: str | None, out_format: str | None,
top_k: int, select_n: int, max_tables: int,
top_k: int, select_n: int, max_tables: int, max_repair: int,
no_selector: bool, yes: bool) -> None:
"""PromptQuery — natural-language SQL for Postgres.

Expand Down Expand Up @@ -276,27 +303,27 @@ def main(dsn: str, model: str | None, selector_model: str | None,
question=query, schema=schema, retriever=retriever,
llm=llm, selector_llm=selector_llm, db=db_ctx,
top_k=top_k, select_n=select_n, max_tables=max_tables,
out_fmt=out_fmt, progress=progress,
max_repair=max_repair, out_fmt=out_fmt, progress=progress,
))

_run_repl(
schema=schema, retriever=retriever,
llm=llm, selector_llm=selector_llm, db=db_ctx,
top_k=top_k, select_n=select_n, max_tables=max_tables,
out_fmt=out_fmt, yes=yes, progress=progress,
max_repair=max_repair, out_fmt=out_fmt, yes=yes, progress=progress,
)
finally:
db_ctx.close()


def _run_one_shot(*, question, schema, retriever, llm, selector_llm, db,
top_k, select_n, max_tables, out_fmt: OutputFormat,
progress: Console) -> int:
top_k, select_n, max_tables, max_repair,
out_fmt: OutputFormat, progress: Console) -> int:
result = run_question(
question=question, schema=schema, retriever=retriever,
llm=llm, selector_llm=selector_llm, db=db,
top_k=top_k, select_n=select_n, max_tables=max_tables,
confirm=False, progress=progress,
max_repair=max_repair, confirm=False, progress=progress,
)

if result.outcome is Outcome.LLM_ERROR:
Expand Down Expand Up @@ -325,8 +352,8 @@ def _run_one_shot(*, question, schema, retriever, llm, selector_llm, db,


def _run_repl(*, schema, retriever, llm, selector_llm, db,
top_k, select_n, max_tables, out_fmt: OutputFormat,
yes: bool, progress: Console) -> None:
top_k, select_n, max_tables, max_repair,
out_fmt: OutputFormat, yes: bool, progress: Console) -> None:
session: PromptSession[str] = PromptSession(history=InMemoryHistory())
progress.print(
"\n[bold]PromptQuery[/bold] — ask a question in plain English, "
Expand All @@ -348,7 +375,7 @@ def _run_repl(*, schema, retriever, llm, selector_llm, db,
question=question, schema=schema, retriever=retriever,
llm=llm, selector_llm=selector_llm, db=db,
top_k=top_k, select_n=select_n, max_tables=max_tables,
confirm=not yes, progress=progress,
max_repair=max_repair, confirm=not yes, progress=progress,
prompt_for_confirm=(lambda: session.prompt("Run? [y/N] ")) if not yes else None,
)

Expand Down
2 changes: 1 addition & 1 deletion src/promptquery/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def make_client(model_spec: str | None = None) -> LLMClient:
# No explicit provider — guess from model name
if provider.startswith("claude"):
return AnthropicClient(model=provider)
if provider.startswith("gpt") or provider.startswith("o1") or provider.startswith("o3"):
if provider.startswith(("gpt", "o1", "o3", "o4")):
return OpenAIClient(model=provider)
raise LLMError(
f"Cannot infer provider from model {provider!r}. "
Expand Down
21 changes: 18 additions & 3 deletions src/promptquery/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,18 @@
2. Use ONLY the tables and columns listed in the schema below. Do not invent columns.
3. NEVER write INSERT, UPDATE, DELETE, DROP, CREATE, ALTER, TRUNCATE, GRANT, COPY, or any DDL/DML.
4. Prefer explicit JOINs over implicit ones. Always qualify columns when joining.
5. If the schema below is insufficient to answer the question, output a single SELECT that
5. SELECT exactly the columns the question asks for — no extra id, timestamp, or
metadata columns added "to be helpful".
6. Apply only the filters the question states. Do not add speculative conditions
(e.g. deleted_at IS NULL, paid_at IS NOT NULL) the question didn't ask for.
7. When a table has a status-like enum column, filter on it using ONLY the values listed
for that column in the schema. Never invent enum values, and don't infer state from
timestamps when a status column exists.
8. Use INNER JOIN unless the question asks to keep rows without a match.
9. If the schema below is insufficient to answer the question, output a single SELECT that
returns an error message string (e.g. SELECT 'insufficient schema: missing X' AS error)
rather than guessing.
6. Output ONLY the SQL inside a ```sql code block. No prose, no explanation.
10. Output ONLY the SQL inside a ```sql code block. No prose, no explanation.

Available schema:
{schema}
Expand All @@ -34,7 +42,14 @@ def format_schema(tables: list[Table]) -> str:
if not c.nullable:
modifiers.append("NOT NULL")
mod_str = f" [{', '.join(modifiers)}]" if modifiers else ""
lines.append(f" {c.name} {c.data_type}{mod_str}")
notes = []
if c.comment:
notes.append(c.comment)
if c.enum_values:
values = ", ".join(f"'{v}'" for v in c.enum_values)
notes.append(f"one of: {values}")
note_str = f" -- {'; '.join(notes)}" if notes else ""
lines.append(f" {c.name} {c.data_type}{mod_str}{note_str}")
for fk in t.foreign_keys:
ref = fk.referenced_table
if fk.referenced_schema and fk.referenced_schema != "public":
Expand Down
Loading