From d06ff168afef1b5826e2ded56227449e6a2d4fde Mon Sep 17 00:00:00 2001 From: Wondr Date: Fri, 5 Jun 2026 06:46:04 +0100 Subject: [PATCH] feat: add sqlite schema adapter --- ARCHITECTURE.md | 23 +++--- README.md | 18 ++--- src/promptquery/cli.py | 16 +++-- src/promptquery/db.py | 90 ++++++++++++++++++++++++ src/promptquery/safety.py | 5 +- src/promptquery/schema.py | 57 ++++++++++++++- tests/test_safety.py | 14 ++++ tests/test_schema_adapters.py | 127 ++++++++++++++++++++++++++++++++++ 8 files changed, 319 insertions(+), 31 deletions(-) create mode 100644 tests/test_schema_adapters.py diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index b74848d..03ad6c7 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -42,7 +42,7 @@ question │ ▼ ┌──────────────────┐ -│ safety.py │ Parse with sqlglot in Postgres dialect. Reject anything +│ safety.py │ Parse with sqlglot in the database dialect. Reject anything │ validate_select │ that is not a single SELECT / WITH / UNION / │ _only │ INTERSECT / EXCEPT. Reject CTEs that hide DML │ │ (WITH x AS (DELETE ... RETURNING ...) SELECT ...). @@ -61,10 +61,10 @@ question │ ▼ ┌──────────────────┐ -│ db.py │ Execute. The Postgres session was opened with -│ Database.execute │ default_transaction_read_only = on and a 60s -│ │ statement_timeout, so even if safety.py failed -│ │ the database itself would refuse a write. +│ db.py │ Execute. Postgres sessions use +│ Database.execute │ default_transaction_read_only = on; SQLite files open +│ │ mode=ro and use PRAGMA query_only = ON, so even if +│ │ safety.py failed the database itself refuses writes. └────────┬─────────┘ │ ▼ @@ -81,8 +81,8 @@ question | `__init__.py` | Package version. | | `__main__.py` | `python -m promptquery` entry point. | | `cli.py` | Click command and prompt-toolkit REPL. Orchestrates the whole pipeline. | -| `db.py` | psycopg3 connection wrapper. Sets the read-only session. | -| `schema.py` | Dataclasses + the `pg_catalog` queries that introspect them. | +| `db.py` | Postgres and SQLite connection wrappers. Sets the read-only session. | +| `schema.py` | Dataclasses + database-specific introspection adapters. | | `retrieval.py` | Tokenizer, TF-IDF ranker, FK-graph expander. | | `llm.py` | Provider clients (Anthropic, OpenAI), SQL extractor, provider factory. | | `prompts.py` | System prompt template and schema-to-prompt formatter. | @@ -95,7 +95,7 @@ question - **`schema.py` ↔ `prompts.py`** — `format_schema` walks the same dataclasses. Schema additions usually need a prompt update. - **`cli.py` ↔ everything else** — the only file that knows the full pipeline. New stages (query history, post-execution feedback) wire in here. - **`safety.py` ↔ `llm.py`** — `extract_sql` runs first; `validate_select_only` runs second. Together they handle the case where the model returns malformed output. -- **`db.py` ↔ `safety.py`** — two layers, intentionally redundant. Either alone is insufficient; both together make a write impossible. +- **`db.py` ↔ `safety.py`** — two layers, intentionally redundant. Either alone is insufficient; together they keep execution read-only at both the SQL-parser and database-session layers. ## Design bets @@ -109,7 +109,7 @@ Many natural-language questions reference one core entity but require joins thro ### Why two safety layers -`safety.py` is the primary guard. It parses every statement and rejects anything other than a SELECT. The Postgres session-level `default_transaction_read_only = on` is the fallback: if a malicious prompt somehow produces SQL that the validator misclassifies (a parser bug, an unknown construct), Postgres itself refuses the write. +`safety.py` is the primary guard. It parses every statement in the selected database dialect and rejects anything other than a SELECT. The session-level read-only mode is the fallback: Postgres uses `default_transaction_read_only = on`, while SQLite files open with `mode=ro` and use `PRAGMA query_only = ON`. If a malicious prompt somehow produces SQL that the validator misclassifies (a parser bug, an unknown construct), the database itself refuses the write. This redundancy is not paranoia. AI-generated SQL is, by construction, less predictable than human-written SQL. The cost of one accidental `DELETE` is high enough that doubling up is the only sensible default. @@ -121,7 +121,7 @@ This redundancy is not paranoia. AI-generated SQL is, by construction, less pred These are intentionally out of scope for the MVP. They are tracked in the [roadmap](README.md#roadmap). -- MySQL / SQLite support — needs an adapter abstraction first. +- MySQL support — needs an adapter implementation and optional driver decision. - Multi-database sessions in one REPL. - Data visualization (charts, plots). - Query-history persistence between sessions. @@ -136,7 +136,7 @@ Run the test suite: pytest ``` -All tests in v0.1 are pure Python — no live database required. The integration test harness (docker-compose + a real Postgres) is queued for v0.2 alongside the public benchmark suite against Spider / BIRD. +All core tests are pure Python — no live external database required. SQLite adapter tests use temporary local database files. The integration test harness (docker-compose + a real Postgres) is queued for v0.2 alongside the public benchmark suite against Spider / BIRD. The most safety-critical file is `tests/test_safety.py`. Cases there encode "things the validator MUST reject." Add cases when you discover new attack vectors; do not delete cases during refactors. @@ -160,6 +160,7 @@ PromptQuery/ │ ├── safety.py │ └── render.py └── tests/ pytest suite + ├── test_schema_adapters.py ├── test_safety.py └── test_retrieval.py ``` diff --git a/README.md b/README.md index 7d59d92..119b23c 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,13 @@ # PromptQuery -> **Natural-language SQL for production-scale Postgres schemas.** +> **Natural-language SQL for production-scale Postgres and SQLite schemas.** [![PyPI](https://img.shields.io/pypi/v/promptquery.svg)](https://pypi.org/project/promptquery/) [![CI](https://github.com/Cyberfilo/promptquery/actions/workflows/ci.yml/badge.svg)](https://github.com/Cyberfilo/promptquery/actions/workflows/ci.yml) [![License: Apache 2.0](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](LICENSE) [![Python: 3.10+](https://img.shields.io/badge/python-3.10%2B-blue.svg)](pyproject.toml) -PromptQuery is an open-source CLI that lets you query Postgres in plain English — engineered for **real production schemas with hundreds of tables**, not toy demos. It introspects your schema, generates SQL, shows it for confirmation, and runs it read-only. +PromptQuery is an open-source CLI that lets you query Postgres and SQLite in plain English — engineered for **real production schemas with hundreds of tables**, not toy demos. It introspects your schema, generates SQL, shows it for confirmation, and runs it read-only.

PromptQuery turning the plain-English question 'orders over 1000 euros with the customer name and status' into a correct multi-table JOIN — showing the SQL, asking for confirmation, then printing the result rows @@ -55,6 +55,7 @@ export ANTHROPIC_API_KEY=... # Connect and start asking: prq postgresql://localhost/mydb +prq sqlite:///local.db ``` `prq` and `pquery` are short aliases for `promptquery`. All three commands work identically. @@ -67,6 +68,7 @@ prq postgresql://localhost/mydb prq --query "how many users in Italy" postgresql://localhost/mydb # JSON to stdout prq --query "top 10 orders by total" --out csv postgresql://... > out.csv prq --query "..." --out table postgresql://... # rich-formatted table +prq --query "top customers by spend" sqlite:///local.db # SQLite local file ``` Exit codes: `0` success · `1` LLM/connection error · `2` safety-guard rejection · `3` execution error. @@ -133,7 +135,7 @@ question └────────┬──────────┘ │ ▼ - "Run? [y/N]" → execute against a read-only Postgres session + "Run? [y/N]" → execute against a read-only database session ``` See [ARCHITECTURE.md](ARCHITECTURE.md) for the deep dive (file inventory, design bets, the patent-landmine non-goals). @@ -167,8 +169,8 @@ If both are set, Anthropic is preferred. Override either with `--model anthropic PromptQuery has **two independent layers** so a write is impossible, even if one layer fails: -1. **Session-level**: every Postgres session opens with `default_transaction_read_only = on` and a 60-second `statement_timeout`. The database itself refuses non-SELECT operations. -2. **Pre-execution**: every generated query is parsed with `sqlglot` and rejected unless it's a single `SELECT` / `WITH` / `UNION` / `INTERSECT` / `EXCEPT`. The validator also catches CTEs that hide DML (`WITH x AS (DELETE …) SELECT * FROM x`) and dangerous-function calls (`pg_terminate_backend`, `set_config`, `lo_export`, `dblink_exec`). +1. **Session-level**: every Postgres session opens with `default_transaction_read_only = on` and a 60-second `statement_timeout`; SQLite files open with `mode=ro` and enable `PRAGMA query_only = ON`. The database itself refuses non-SELECT operations. +2. **Pre-execution**: every generated query is parsed with `sqlglot` in the selected database dialect and rejected unless it's a single `SELECT` / `WITH` / `UNION` / `INTERSECT` / `EXCEPT`. The validator also catches CTEs that hide DML (`WITH x AS (DELETE …) SELECT * FROM x`) and dangerous-function calls (`pg_terminate_backend`, `set_config`, `lo_export`, `dblink_exec`, `load_extension`). Every query is also shown to you before it runs. Confirm with `y`. @@ -226,7 +228,7 @@ See [`eval/END_TO_END.md`](eval/END_TO_END.md) for the harness internals. ## What PromptQuery does NOT do (yet) - **No writes.** `SELECT` only, by design and by belt-and-suspenders. -- **Postgres only.** MySQL and SQLite are on the v0.4 roadmap. +- **Full multi-dialect coverage.** Postgres remains the reference implementation and SQLite local files are supported; MySQL is still on the roadmap. - **One database at a time.** No multi-DB sessions. - **No data visualisation.** Rows out, that's it. Pipe to `csv` / `jq` / your tool of choice. @@ -236,7 +238,7 @@ See [`eval/END_TO_END.md`](eval/END_TO_END.md) for the harness internals. - **v0.2 (shipped)** — LLM-assisted table selector, stemmed TF-IDF. - **v0.3** — local LLMs (Ollama), schema anonymisation (GDPR-by-default), query-history-as-few-shot. -- **v0.4** — MySQL + SQLite adapters, MCP server mode, public competitor benchmark. +- **v0.4** — MySQL adapter, MCP server mode, public competitor benchmark. --- @@ -255,7 +257,7 @@ python3.12 -m venv .venv .venv/bin/python -m eval.retrieval ``` -37 tests, all pure-Python — no live database or API key required for the core suite. +55 tests, all pure-Python — no live database or API key required for the core suite. --- diff --git a/src/promptquery/cli.py b/src/promptquery/cli.py index 2644c15..8a508fa 100644 --- a/src/promptquery/cli.py +++ b/src/promptquery/cli.py @@ -14,7 +14,7 @@ from rich.console import Console from . import __version__ -from .db import Database +from .db import Database, SQLiteDatabase, make_database from .llm import LLMClient, LLMError, extract_sql, make_client from .prompts import build_system_prompt from .render import render_results, render_sql @@ -81,7 +81,7 @@ def run_question( retriever: TfIdfRetriever, llm: LLMClient, selector_llm: LLMClient | None, - db: Database, + db: Database | SQLiteDatabase, *, top_k: int, select_n: int, @@ -125,7 +125,7 @@ def run_question( return QueryResult(Outcome.EMPTY_SQL, None, [], [], "LLM returned an empty response.") try: - validate_select_only(sql) + validate_select_only(sql, dialect=getattr(db, "dialect", "postgres")) except UnsafeQuery as e: return QueryResult(Outcome.UNSAFE, sql, [], [], str(e)) @@ -210,14 +210,15 @@ def main(dsn: str, model: str | None, selector_model: str | None, query: str | None, out_format: str | None, top_k: int, select_n: int, max_tables: int, no_selector: bool, yes: bool) -> None: - """PromptQuery — natural-language SQL for Postgres. + """PromptQuery — natural-language SQL for Postgres and SQLite. - DSN is a libpq connection string, e.g. postgresql://user:pass@host/db. + DSN is a libpq connection string or sqlite:///path/to.db. Examples: Interactive REPL: promptquery postgresql://localhost/mydb + promptquery sqlite:///local.db One-shot query (machine-friendly JSON to stdout, progress to stderr): promptquery -q "how many users in Italy" postgresql://localhost/mydb @@ -250,7 +251,7 @@ def main(dsn: str, model: str | None, selector_model: str | None, progress.print(f"[dim]Connecting to[/dim] {_redact(dsn)} [dim]...[/dim]") try: - db_ctx = Database(dsn).__enter__() + db_ctx = make_database(dsn).__enter__() except Exception as e: progress.print(f"[red]Connection failed:[/red] {e}") sys.exit(1) @@ -267,7 +268,8 @@ def main(dsn: str, model: str | None, selector_model: str | None, else (" (selector: same)" if selector_llm is not None else " (selector: off)") ) progress.print(f"[green]✓[/green] {len(schema.tables)} tables found " - f"[dim](sql: {llm.name}/{llm.model}{selector_info})[/dim]") + f"[dim](db: {db_ctx.dialect}, sql: " + f"{llm.name}/{llm.model}{selector_info})[/dim]") retriever = TfIdfRetriever(schema) diff --git a/src/promptquery/db.py b/src/promptquery/db.py index da51d64..2eaa088 100644 --- a/src/promptquery/db.py +++ b/src/promptquery/db.py @@ -1,10 +1,16 @@ from __future__ import annotations +import sqlite3 +from urllib.parse import quote + import psycopg from psycopg.rows import dict_row class Database: + dialect = "postgres" + default_schema = "public" + def __init__(self, dsn: str): self.dsn = dsn self.conn: psycopg.Connection | None = None @@ -47,3 +53,87 @@ def __enter__(self) -> "Database": def __exit__(self, exc_type, exc, tb) -> None: self.close() + + +class SQLiteDatabase: + dialect = "sqlite" + default_schema = "main" + + def __init__(self, dsn: str): + self.dsn = dsn + self.path = _sqlite_path_from_dsn(dsn) + self.conn: sqlite3.Connection | None = None + + def connect(self) -> None: + if self.path == ":memory:": + self.conn = sqlite3.connect(self.path) + else: + self.conn = sqlite3.connect( + f"file:{quote(self.path, safe='/')}?mode=ro", + uri=True, + ) + self.conn.row_factory = sqlite3.Row + with self.conn: + self.conn.execute("PRAGMA foreign_keys = ON") + self.conn.execute("PRAGMA query_only = ON") + self.conn.execute("PRAGMA busy_timeout = 60000") + + def close(self) -> None: + if self.conn is not None: + self.conn.close() + self.conn = None + + def _require_conn(self) -> sqlite3.Connection: + if self.conn is None: + raise RuntimeError("Database is not connected") + return self.conn + + def fetch_dicts(self, sql: str) -> list[dict]: + conn = self._require_conn() + cur = conn.execute(sql) + return [dict(row) for row in cur.fetchall()] + + def execute(self, sql: str) -> tuple[list[str], list[tuple]]: + conn = self._require_conn() + cur = conn.execute(sql) + if cur.description is None: + return [], [] + cols = [d[0] for d in cur.description] + rows = [tuple(row) for row in cur.fetchall()] + return cols, rows + + def pragma_dicts(self, name: str, argument: str) -> list[dict]: + conn = self._require_conn() + cur = conn.execute(f"PRAGMA {name}({_quote_sqlite_literal(argument)})") + return [dict(row) for row in cur.fetchall()] + + def __enter__(self) -> "SQLiteDatabase": + self.connect() + return self + + def __exit__(self, exc_type, exc, tb) -> None: + self.close() + + +def make_database(dsn: str) -> Database | SQLiteDatabase: + if dsn.startswith("sqlite:///"): + return SQLiteDatabase(dsn) + return Database(dsn) + + +def _sqlite_path_from_dsn(dsn: str) -> str: + if not dsn.startswith("sqlite:///"): + raise ValueError("SQLite DSNs must use sqlite:///path/to.db") + + path = dsn[len("sqlite:///"):] + if not path: + raise ValueError("SQLite DSN is missing a database path") + if path == ":memory:": + return path + if dsn.startswith("sqlite:////"): + return "/" + dsn[len("sqlite:////"):] + return path + + +def _quote_sqlite_literal(value: str) -> str: + return "'" + value.replace("'", "''") + "'" diff --git a/src/promptquery/safety.py b/src/promptquery/safety.py index a166768..5fffa3e 100644 --- a/src/promptquery/safety.py +++ b/src/promptquery/safety.py @@ -40,16 +40,17 @@ class UnsafeQuery(Exception): "lo_export", "dblink_exec", "set_config", + "load_extension", } -def validate_select_only(sql: str) -> None: +def validate_select_only(sql: str, *, dialect: str = "postgres") -> None: sql = (sql or "").strip() if not sql: raise UnsafeQuery("empty SQL") try: - statements = sqlglot.parse(sql, read="postgres") + statements = sqlglot.parse(sql, read=dialect) except Exception as e: raise UnsafeQuery(f"could not parse SQL: {e}") from e diff --git a/src/promptquery/schema.py b/src/promptquery/schema.py index 9258611..83c485f 100644 --- a/src/promptquery/schema.py +++ b/src/promptquery/schema.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from .db import Database + from .db import Database, SQLiteDatabase @dataclass(frozen=True) @@ -67,7 +67,7 @@ class Table: @property def qualified_name(self) -> str: - if self.schema == "public": + if self.schema in {"public", "main"}: return self.name return f"{self.schema}.{self.name}" @@ -166,7 +166,13 @@ def from_dict(cls, data: dict) -> "Schema": """ -def introspect(db: "Database") -> Schema: +def introspect(db: "Database | SQLiteDatabase") -> Schema: + if getattr(db, "dialect", "postgres") == "sqlite": + return _introspect_sqlite(db) + return _introspect_postgres(db) + + +def _introspect_postgres(db: "Database") -> Schema: tables_rows = db.fetch_dicts(INTROSPECT_TABLES) columns_rows = db.fetch_dicts(INTROSPECT_COLUMNS) fks_rows = db.fetch_dicts(INTROSPECT_FKS) @@ -205,3 +211,48 @@ def introspect(db: "Database") -> Schema: )) return Schema(tables=list(tables.values())) + + +def _introspect_sqlite(db: "SQLiteDatabase") -> Schema: + tables_rows = db.fetch_dicts(""" + SELECT 'main' AS schema, + name, + NULL AS comment, + type + FROM sqlite_schema + WHERE type IN ('table', 'view') + AND name NOT LIKE 'sqlite_%' + ORDER BY name + """) + + tables: dict[tuple[str, str], Table] = {} + for row in tables_rows: + key = (row["schema"], row["name"]) + tables[key] = Table( + schema=row["schema"], + name=row["name"], + comment=row["comment"], + ) + + for table in tables.values(): + for row in db.pragma_dicts("table_xinfo", table.name): + if row["hidden"] == 1: + continue + table.columns.append(Column( + name=row["name"], + data_type=row["type"] or "UNKNOWN", + nullable=not bool(row["notnull"]) and not bool(row["pk"]), + is_primary_key=bool(row["pk"]), + )) + + for table in tables.values(): + fks_rows = db.pragma_dicts("foreign_key_list", table.name) + for row in sorted(fks_rows, key=lambda r: (r["id"], r["seq"])): + table.foreign_keys.append(ForeignKey( + column=row["from"], + referenced_schema="main", + referenced_table=row["table"], + referenced_column=row["to"], + )) + + return Schema(tables=list(tables.values())) diff --git a/tests/test_safety.py b/tests/test_safety.py index 17d0f38..1c4d826 100644 --- a/tests/test_safety.py +++ b/tests/test_safety.py @@ -56,3 +56,17 @@ def test_rejects_cte_with_dml(): validate_select_only( "WITH deleted AS (DELETE FROM users RETURNING id) SELECT * FROM deleted" ) + + +def test_accepts_sqlite_dialect_queries(): + validate_select_only( + 'SELECT json_extract(payload, "$.city") FROM events LIMIT 10', + dialect="sqlite", + ) + + +def test_sqlite_dialect_still_rejects_mutating_statements(): + with pytest.raises(UnsafeQuery): + validate_select_only("ATTACH DATABASE 'other.db' AS other", dialect="sqlite") + with pytest.raises(UnsafeQuery): + validate_select_only("SELECT load_extension('mod_spatialite')", dialect="sqlite") diff --git a/tests/test_schema_adapters.py b/tests/test_schema_adapters.py new file mode 100644 index 0000000..da0258e --- /dev/null +++ b/tests/test_schema_adapters.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +import sqlite3 + +import pytest + +from promptquery.db import SQLiteDatabase, make_database +from promptquery.schema import introspect + + +class _FakePostgresDB: + dialect = "postgres" + + def fetch_dicts(self, sql: str) -> list[dict]: + if "FROM pg_class" in sql: + return [ + {"schema": "public", "name": "users", "comment": "accounts"}, + {"schema": "sales", "name": "orders", "comment": None}, + ] + if "FROM pg_attribute" in sql: + return [ + { + "schema": "public", + "table_name": "users", + "name": "id", + "data_type": "bigint", + "nullable": False, + "is_primary_key": True, + }, + { + "schema": "sales", + "table_name": "orders", + "name": "user_id", + "data_type": "bigint", + "nullable": False, + "is_primary_key": False, + }, + ] + if "FROM pg_constraint" in sql: + return [ + { + "schema": "sales", + "table_name": "orders", + "column_name": "user_id", + "referenced_schema": "public", + "referenced_table": "users", + "referenced_column": "id", + }, + ] + raise AssertionError(f"unexpected query: {sql}") + + +def test_postgres_introspection_preserves_existing_row_mapping(): + schema = introspect(_FakePostgresDB()) + + assert [table.qualified_name for table in schema.tables] == ["users", "sales.orders"] + users = schema.tables[0] + orders = schema.tables[1] + + assert users.comment == "accounts" + assert users.columns[0].name == "id" + assert users.columns[0].is_primary_key is True + assert orders.foreign_keys[0].referenced_table == "users" + + +def test_make_database_selects_sqlite_for_sqlite_dsn(tmp_path): + db_file = tmp_path / "shop.db" + + db = make_database(f"sqlite:///{db_file}") + + assert isinstance(db, SQLiteDatabase) + assert db.path == str(db_file) + + +def test_sqlite_database_does_not_create_missing_file(tmp_path): + db_file = tmp_path / "missing.db" + db = SQLiteDatabase(f"sqlite:///{db_file}") + + with pytest.raises(sqlite3.OperationalError): + db.connect() + + assert not db_file.exists() + + +def test_sqlite_introspection_reads_tables_views_columns_and_foreign_keys(tmp_path): + db_file = tmp_path / "shop.db" + with sqlite3.connect(db_file) as conn: + conn.execute("PRAGMA foreign_keys = ON") + conn.execute(""" + CREATE TABLE authors ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL + ) + """) + conn.execute(""" + CREATE TABLE books ( + id INTEGER PRIMARY KEY, + author_id INTEGER NOT NULL REFERENCES authors(id), + title TEXT, + slug TEXT GENERATED ALWAYS AS (lower(title)) VIRTUAL + ) + """) + conn.execute("CREATE VIEW book_titles AS SELECT title FROM books") + + with SQLiteDatabase(f"sqlite:///{db_file}") as db: + schema = introspect(db) + with pytest.raises(sqlite3.OperationalError, match="readonly"): + db.execute("INSERT INTO authors (name) VALUES ('blocked')") + + tables = {table.name: table for table in schema.tables} + assert set(tables) == {"authors", "book_titles", "books"} + assert tables["authors"].qualified_name == "authors" + + author_columns = {column.name: column for column in tables["authors"].columns} + assert author_columns["id"].is_primary_key is True + assert author_columns["id"].nullable is False + assert author_columns["name"].data_type == "TEXT" + assert author_columns["name"].nullable is False + + book_columns = {column.name: column for column in tables["books"].columns} + assert book_columns["slug"].data_type == "TEXT" + assert book_columns["title"].nullable is True + + assert tables["books"].foreign_keys[0].column == "author_id" + assert tables["books"].foreign_keys[0].referenced_schema == "main" + assert tables["books"].foreign_keys[0].referenced_table == "authors" + assert tables["books"].foreign_keys[0].referenced_column == "id"