From cbe6f83afa75cdb3a809e1b9b93af501fa2f92ed Mon Sep 17 00:00:00 2001 From: Jose Agustin Puente Date: Tue, 2 Jun 2026 10:08:56 +0200 Subject: [PATCH 1/4] =?UTF-8?q?fix(query+ingest):=20value-anchored=20NL?= =?UTF-8?q?=E2=86=92SQL=20grounding,=20hierarchy=20resolution,=20pipeline?= =?UTF-8?q?=20robustness?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Surface column values into the grounding/generation prompts and harden the query + ingestion pipelines so the NL→SQL agents stop guessing literals and the control flow stops reporting wrong-but-empty results as confident answers. Grounding / value anchoring: - profile: store the full distinct value set for low-cardinality columns (was top-5) - retrieval: value_fingerprint() on hits; render a "Column value catalogue" into the grounding and generation prompts; instructions to copy WHERE/CASE literals verbatim, handle tall/EAV layouts and scaled-duplicate measures - structurally detect self-referencing hierarchy columns (value containment, with a row-wise discriminator) and teach "a person's team/reports/structure" -> filter the hierarchy column Control flow / governance: - treat 0-row / degenerate aggregate results as suspicious -> clarification + confidence cap - gate auto-learning on row_count > 0 (stop poisoning the example store) - prefer a candidate that passes the firewall and returns non-empty rows over candidates[0] - exclude CTE names from AST table_refs (firewall no longer rejects valid CTE SQL) - snapshot-scope retrieval; honor and persist conversation snapshot pins - unify /stream, :explain, :validate on the rendered prompts + firewall/scope guards - thread a re-firewalled extra_filter through the semantic-metric bind() Ingestion robustness / determinism: - header-quality scoring (skip numeric spacer rows) + stable leading-blank handling - temperature=0 for the naming/describe agents (cuts re-ingest schema churn ~94%) - rewrite Parquet column names on the no-key 
fallback (fixes profile/sample Binder errors) - index column values into embeddings + content_tsv - warn when the cross-encoder reranker degrades to a no-op --- src/flyquery/core/agents/builder.py | 10 + .../core/agents/column_name_proposer_agent.py | 3 + src/flyquery/core/agents/describe_agent.py | 3 + .../core/agents/rename_detection_agent.py | 3 + .../core/services/examples/auto_learner.py | 8 + .../core/services/execution/ast_classifier.py | 21 +- .../core/services/execution/table_resolver.py | 36 ++- .../ingestion/readers/excel_reader.py | 111 +++++++- .../core/services/ingestion/stages/embed.py | 60 +++- .../core/services/ingestion/stages/parse.py | 11 + .../core/services/ingestion/stages/profile.py | 168 ++++++++++- .../core/services/query/query_service.py | 264 +++++++++++++++++- .../core/services/retrieval/reranker.py | 19 +- .../core/services/retrieval/search_index.py | 150 ++++++++-- .../web/controllers/query_controller.py | 171 +++++++++--- 15 files changed, 955 insertions(+), 83 deletions(-) diff --git a/src/flyquery/core/agents/builder.py b/src/flyquery/core/agents/builder.py index 3044bfb..a15e048 100644 --- a/src/flyquery/core/agents/builder.py +++ b/src/flyquery/core/agents/builder.py @@ -48,6 +48,7 @@ def build_agent( instructions: str, settings: FlyquerySettings, max_output_tokens: int | None = None, + temperature: float | None = None, extra_settings: dict[str, Any] | None = None, ) -> Any: """Construct a :class:`FireflyAgent` with the standard knobs. @@ -67,6 +68,13 @@ def build_agent( requiring callers to plumb the env var themselves. max_output_tokens: Optional override for this specific call. ``None`` falls back to ``settings.agent_max_output_tokens``. + temperature: Optional sampling temperature folded into + ``model_settings``. ``None`` (the default) leaves the + provider default untouched -- grounding / generation / + critic rely on that diversity (generation samples N + candidates). 
Naming / description stages pass ``0.0`` so + near-identical re-ingests yield identical names and avoid + schema-change churn. extra_settings: Optional extra ``model_settings`` entries. Caller-provided keys WIN on conflict so a stage can cap below the global budget (e.g. a 1-token classifier). @@ -83,6 +91,8 @@ def build_agent( resolved_max = resolve_max_output_tokens(settings, override=max_output_tokens) model_settings: dict[str, Any] = {"max_tokens": resolved_max} + if temperature is not None: + model_settings["temperature"] = temperature if extra_settings: # Caller-provided settings win on conflict -- a stage can cap # itself below the default by passing ``max_tokens=128`` in diff --git a/src/flyquery/core/agents/column_name_proposer_agent.py b/src/flyquery/core/agents/column_name_proposer_agent.py index 3eed345..e317931 100644 --- a/src/flyquery/core/agents/column_name_proposer_agent.py +++ b/src/flyquery/core/agents/column_name_proposer_agent.py @@ -116,6 +116,9 @@ def build_column_name_proposer_agent(settings): output_type=ProposedColumnNames, instructions=prompt.instructions, settings=settings, + # Deterministic naming: identical re-ingests must yield identical + # column names, otherwise reconcile sees phantom schema churn. + temperature=0.0, ) diff --git a/src/flyquery/core/agents/describe_agent.py b/src/flyquery/core/agents/describe_agent.py index ac6870e..4333e38 100644 --- a/src/flyquery/core/agents/describe_agent.py +++ b/src/flyquery/core/agents/describe_agent.py @@ -140,4 +140,7 @@ def build_describe_agent(settings): output_type=DescribedObjects, instructions=prompt.instructions, settings=settings, + # Deterministic descriptions/semantic types: identical columns on + # re-ingest must produce identical metadata, no schema-change churn. 
+ temperature=0.0, ) diff --git a/src/flyquery/core/agents/rename_detection_agent.py b/src/flyquery/core/agents/rename_detection_agent.py index 2c1c4a9..dd951a3 100644 --- a/src/flyquery/core/agents/rename_detection_agent.py +++ b/src/flyquery/core/agents/rename_detection_agent.py @@ -62,4 +62,7 @@ def build_rename_detection_agent(settings): settings=settings, # Rename detection is a short task; cap output tokens tightly max_output_tokens=2048, + # Deterministic: the same removed/candidate pair must always + # resolve the same way so re-ingests don't flip-flop renames. + temperature=0.0, ) diff --git a/src/flyquery/core/services/examples/auto_learner.py b/src/flyquery/core/services/examples/auto_learner.py index 48101da..1219058 100644 --- a/src/flyquery/core/services/examples/auto_learner.py +++ b/src/flyquery/core/services/examples/auto_learner.py @@ -28,6 +28,7 @@ class AutoLearner: Skips when: - ``retries > 0`` (query required critic refinement) - PII findings were detected in the result + - the query returned no rows (``row_count`` is 0, when provided) Called by QueryService (Phase D) after a successful execution. """ @@ -46,6 +47,7 @@ async def maybe_propose( retries: int, pii_findings: list[Any], query_id: uuid.UUID, + row_count: int | None = None, ) -> None: """Insert a flyquery_examples row when all criteria pass. @@ -57,11 +59,17 @@ async def maybe_propose( :param retries: number of critic refinement loops (must be 0 to propose) :param pii_findings: any PII signals detected (must be empty to propose) :param query_id: UUID of the parent query record + :param row_count: number of rows the query returned; when provided it + must be > 0 to propose (a valid-but-wrong query returning 0 rows + would otherwise poison grounding). When ``None`` the row gate is + skipped to preserve behaviour for callers that do not pass it. 
""" if retries > 0: return if pii_findings: return + if row_count is not None and row_count <= 0: + return await self._service.create( tenant_id, workspace_id, diff --git a/src/flyquery/core/services/execution/ast_classifier.py b/src/flyquery/core/services/execution/ast_classifier.py index a5f9cd7..dfcd083 100644 --- a/src/flyquery/core/services/execution/ast_classifier.py +++ b/src/flyquery/core/services/execution/ast_classifier.py @@ -76,8 +76,25 @@ def classify(self, sql: str) -> AstClassification: # pyright does not unify with the public ``Expression`` base below. kind = self._kind(stmt) # pyright: ignore[reportArgumentType] - # Collect table refs — skip anonymous subquery aliases - tables = tuple(sorted({t.name for t in stmt.find_all(sqlglot.expressions.Table) if t.name})) + # Collect table refs — skip anonymous subquery aliases AND + # CTE-defined names. sqlglot represents a reference to a CTE + # (``FROM base`` where ``WITH base AS (...)``) as an ``exp.Table`` + # node, so without this filter the CTE alias leaks into + # ``table_refs``; the downstream bad-tables guard then flags it + # as a non-existent table and the (otherwise valid) query is + # rejected — see QueryService bad-tables set-difference. 
+ cte_names = { + cte.alias_or_name for cte in stmt.find_all(sqlglot.expressions.CTE) if cte.alias_or_name + } + tables = tuple( + sorted( + { + t.name + for t in stmt.find_all(sqlglot.expressions.Table) + if t.name and t.name not in cte_names + } + ) + ) columns = tuple(sorted({c.name for c in stmt.find_all(sqlglot.expressions.Column) if c.name})) has_subquery = bool(list(stmt.find_all(sqlglot.expressions.Subquery))) diff --git a/src/flyquery/core/services/execution/table_resolver.py b/src/flyquery/core/services/execution/table_resolver.py index 3bcd6b5..0c18785 100644 --- a/src/flyquery/core/services/execution/table_resolver.py +++ b/src/flyquery/core/services/execution/table_resolver.py @@ -32,6 +32,7 @@ from __future__ import annotations +import json import uuid import sqlalchemy as sa @@ -54,6 +55,7 @@ async def resolve( dataset_id: uuid.UUID, table_names: list[str], object_store_base: str | None = None, + pins: dict[str, str] | None = None, ) -> dict[str, str]: """Return a mapping of table name → absolute parquet path. @@ -65,6 +67,11 @@ async def resolve( :param dataset_id: dataset to scope the lookup :param table_names: unqualified table names from the AST :param object_store_base: override for ``settings.object_store_base`` + :param pins: optional ``{table_name: snapshot_id}`` — a follow-up + drill-down turn pins each table to the snapshot it resolved to + on the first turn, so a mid-conversation re-ingest does not + silently switch the answer to a newer schema. Unpinned tables + fall back to ``current_snapshot_id``. 
:return: ``{name: path}`` dict for all resolvable tables """ if not table_names: @@ -77,12 +84,16 @@ async def resolve( SELECT t.name, ss.parquet_object_key FROM flyquery_tables t JOIN flyquery_schema_snapshots ss - ON ss.id = t.current_snapshot_id + ON ss.table_id = t.id + AND ss.id = COALESCE( + (CAST(:pins AS jsonb) ->> t.name)::uuid, + t.current_snapshot_id + ) WHERE t.dataset_id = :ds AND t.name = ANY(:names) AND t.is_active = true """), - {"ds": dataset_id, "names": list(table_names)}, + {"ds": dataset_id, "names": list(table_names), "pins": json.dumps(pins or {})}, ) out: dict[str, str] = {} @@ -90,3 +101,24 @@ async def resolve( key: str = r["parquet_object_key"] out[r["name"]] = f"{base}/{key}" return out + + async def current_snapshots( + self, dataset_id: uuid.UUID, table_names: list[str] + ) -> dict[str, str]: + """Return ``{table_name: current_snapshot_id}`` for the given tables. + + Used to record THIS turn's snapshot pins so a later drill-down turn + can reproduce the exact schema version it answered against. 
+ """ + if not table_names: + return {} + rows = await self._session.execute( + sa.text(""" + SELECT name, current_snapshot_id + FROM flyquery_tables + WHERE dataset_id = :ds AND name = ANY(:names) AND is_active = true + AND current_snapshot_id IS NOT NULL + """), + {"ds": dataset_id, "names": list(table_names)}, + ) + return {r["name"]: str(r["current_snapshot_id"]) for r in rows.mappings()} diff --git a/src/flyquery/core/services/ingestion/readers/excel_reader.py b/src/flyquery/core/services/ingestion/readers/excel_reader.py index 2423222..c909fac 100644 --- a/src/flyquery/core/services/ingestion/readers/excel_reader.py +++ b/src/flyquery/core/services/ingestion/readers/excel_reader.py @@ -73,6 +73,79 @@ # Heuristic constants _MIN_HEADER_NON_EMPTY = 2 # a row needs >= 2 non-empty cells to be a header _SECTION_BREAK_EMPTY_ROWS = 2 # >= 2 consecutive empty rows close a section +# When the first header candidate looks like a numeric/spacer row, look this +# many rows ahead for a clearly-better string-label row before settling. +_HEADER_LOOKAHEAD = 3 + + +def _normalise_leading_blanks(rows: list[list[Any]]) -> list[list[Any]]: + """Drop leading rows that are *entirely* empty. + + calamine's ``to_python(skip_empty_area=True)`` trims the used-range + bounding box, but the exact number of leading blank rows it keeps can + differ between near-identical re-ingests (HDR-UNSTABLE) -- which flips + every section's absolute row index and causes column-name churn. We + normalise here by consistently removing fully-empty leading rows so the + same logical sheet yields the same header index. This MUST be applied + identically in ``_enumerate_sync`` (where indices are computed) and + ``_materialise_sync`` (where they are sliced) for the indices to line up. + """ + start = 0 + n = len(rows) + while start < n and all(c in ("", None) for c in rows[start]): + start += 1 + # Avoid a needless copy when nothing was trimmed. 
+ return rows if start == 0 else rows[start:] + + +def _looks_like_label_row(row: list[Any]) -> bool: + """True when a row's populated cells are predominantly non-empty STRINGS. + + Header rows hold labels (text); spacer/index rows hold bare numbers + (1, 2, 3, ...). We require a strict string majority so genuinely numeric, + period, or date headers are not misclassified as data. + """ + non_empty = [c for c in row if c not in ("", None)] + if not non_empty: + return False + str_cells = sum(1 for c in non_empty if isinstance(c, str) and c.strip() != "") + return str_cells * 2 > len(non_empty) + + +def _is_numeric_spacer_row(row: list[Any]) -> bool: + """True when a header candidate looks like a numeric spacer, not labels. + + A row is a spacer when every populated cell is numeric (int/float, or a + numeric-looking string) and none is a real text label. This includes the + classic contiguous ``1..k`` column-numbering run Excel exports sometimes + inject above the real header band. A row with any genuine string label is + never a spacer. + """ + non_empty = [c for c in row if c not in ("", None)] + if len(non_empty) < _MIN_HEADER_NON_EMPTY: + return False + + def _as_number(cell: Any) -> float | None: + if isinstance(cell, bool): + return None + if isinstance(cell, (int, float)): + return float(cell) + if isinstance(cell, str): + try: + return float(cell.strip()) + except ValueError: + return None + return None + + numbers = [_as_number(c) for c in non_empty] + if any(num is None for num in numbers): + # Some populated cell is a non-numeric string -> not a spacer. + return False + + # Every populated cell is numeric and none is a text label. This covers + # both the all-numeric case and, as a strict subset, the contiguous + # ``1..k`` column-numbering run -- both are spacers, never label rows. 
+ return True class ExcelReader: @@ -146,9 +219,34 @@ def _extract_sections(rows: list[list[Any]]) -> list[dict[str, Any]]: i += 1 continue - # Multi-cell row = section header + # Multi-cell row = section header candidate. + # + # A naive reader takes the FIRST >=2-non-empty row as the header. + # But dashboard exports sometimes inject a numeric spacer / column- + # numbering row (e.g. ``1 2 3 4``) just above the real label band + # (HDR-MULTIROW). If we treat that spacer as the header, the real + # labels become data and columns get opaque positional names. So + # when this candidate looks like a numeric/contiguous spacer, peek + # a small window ahead and prefer the first following row that is + # clearly a string-label row. We only skip the candidate when such + # a better row exists -- genuinely numeric / period / date headers + # (no string-label row just below) are left untouched, as are + # single-row sheets. header_idx = i - j = i + 1 + if _is_numeric_spacer_row(row) and not _looks_like_label_row(row): + look_end = min(i + 1 + _HEADER_LOOKAHEAD, n) + for la in range(i + 1, look_end): + la_ne = [c for c in rows[la] if c not in ("", None)] + if len(la_ne) == 0: + # A blank row before any label row means the spacer is + # really the last populated row -- stop looking ahead. + break + if len(la_ne) >= _MIN_HEADER_NON_EMPTY and _looks_like_label_row(rows[la]): + # Found a better string-label header just below; treat + # the skipped numeric/title rows as pre-header. + header_idx = la + break + j = header_idx + 1 consecutive_empty = 0 data_end = j # exclusive while j < n: @@ -208,7 +306,10 @@ def _enumerate_sync(source_path: str, rules: TableExtractionRules) -> list[Propo if allow is not None and sheet_name not in allow: continue sheet = wb.get_sheet_by_name(sheet_name) - rows = sheet.to_python(skip_empty_area=True) + # Normalise leading fully-empty rows so section indices are stable + # across re-ingests (HDR-UNSTABLE). 
The SAME normalisation runs in + # ``_materialise_sync`` so the stored indices slice the same rows. + rows = _normalise_leading_blanks(sheet.to_python(skip_empty_area=True)) if not rows: continue @@ -295,7 +396,9 @@ def _materialise_sync( ) wb = CalamineWorkbook.from_path(source_path) sheet = wb.get_sheet_by_name(sheet_name) - rows = sheet.to_python(skip_empty_area=True) + # Apply the SAME leading-blank normalisation used at enumerate time so + # the stored section indices slice the intended rows (HDR-UNSTABLE). + rows = _normalise_leading_blanks(sheet.to_python(skip_empty_area=True)) if sec_start is not None and sec_end is not None: # Section-encoded path -- slice precisely. diff --git a/src/flyquery/core/services/ingestion/stages/embed.py b/src/flyquery/core/services/ingestion/stages/embed.py index cd6e65d..d3d9211 100644 --- a/src/flyquery/core/services/ingestion/stages/embed.py +++ b/src/flyquery/core/services/ingestion/stages/embed.py @@ -94,8 +94,14 @@ async def run_embed( async with session_factory() as s: result = await s.execute( sa.text( + # profile_json/sample_values_json are pulled in so the embed + # text (and thus content_tsv) covers the column's actual + # VALUES, making value-bearing columns retrievable by + # BM25/vector -- e.g. a question for "Total Revenue" finds + # the column whose distinct values include it. 
""" - SELECT id, qualified_name, data_type, description, synonyms_json + SELECT id, qualified_name, data_type, description, synonyms_json, + profile_json, sample_values_json FROM flyquery_schema_objects WHERE snapshot_id = :sid AND tenant_id = :tenant ORDER BY kind, qualified_name @@ -152,9 +158,61 @@ def _build_embed_text(row: dict) -> str: flat = list(synonyms.values()) if flat: parts.append("Synonyms: " + ", ".join(str(s) for s in flat)) + values = _render_values(row) + if values: + parts.append(values) return "\n".join(p for p in parts if p) +def _render_values(row: dict, *, max_chars: int = 300) -> str: + """Compact rendering of a column's actual VALUES for the embed corpus. + + Indexing the values (not just name + description) is what lets a + question like "Total Revenue" retrieve the column whose distinct set + contains that literal. + + PII safety: prefer the PII-gated ``sample_values_json`` -- the pii_tag + stage wipes it to ``[]`` when a redact/reject policy fires, so an empty + list here means "do not surface raw samples". When no gated samples are + present we fall back to ``profile_json.top_values``, the stored distinct + set for low-cardinality columns (aggregate / low-cardinality, so lower + PII risk). The result is capped to ``max_chars`` either way. + """ + seen: set[str] = set() + uniq: list[str] = [] + + # Preferred source: PII-gated samples (empty list = intentionally wiped). + samples = row.get("sample_values_json") + if isinstance(samples, list) and samples: + for v in samples: + if v is None: + continue + s = str(v) + if s not in seen: + seen.add(s) + uniq.append(s) + else: + # Fallback: stored distinct set (aggregate, low-cardinality). 
+ prof = row.get("profile_json") + top_values = (prof or {}).get("top_values") if isinstance(prof, dict) else None + for tv in top_values or []: + v = tv.get("value") if isinstance(tv, dict) else tv + if v is None: + continue + s = str(v) + if s not in seen: + seen.add(s) + uniq.append(s) + + if not uniq: + return "" + body = " | ".join(uniq) + if len(body) > max_chars: + # Trim on a value boundary so we never emit a half-truncated literal. + body = body[:max_chars].rsplit("|", 1)[0].strip() + " …" + return f"Values: {body}" + + async def _update_object( *, object_id: uuid.UUID, diff --git a/src/flyquery/core/services/ingestion/stages/parse.py b/src/flyquery/core/services/ingestion/stages/parse.py index d54a8c2..fea6eba 100644 --- a/src/flyquery/core/services/ingestion/stages/parse.py +++ b/src/flyquery/core/services/ingestion/stages/parse.py @@ -182,6 +182,17 @@ async def _propose_meaningful_column_names( "stage=parse rename_skipped reason=no_api_key fallback_prefix=%s", section_prefix, ) + if fallback == current_names: + return mat_result + # Rewrite the physical Parquet so its column names match the + # section-prefixed fallback we record in the MaterialiseResult -- + # otherwise the persisted schema_objects names diverge from the + # Parquet header and downstream DuckDB stages hit Binder Errors. 
+ await _rename_parquet_columns( + parquet_path=parquet_path, + current_columns=current_names, + proposed_columns=fallback, + ) return _rebuild_mat_result(mat_result, fallback) try: diff --git a/src/flyquery/core/services/ingestion/stages/profile.py b/src/flyquery/core/services/ingestion/stages/profile.py index 1c6be8b..301c724 100644 --- a/src/flyquery/core/services/ingestion/stages/profile.py +++ b/src/flyquery/core/services/ingestion/stages/profile.py @@ -18,7 +18,9 @@ - null_fraction - approx_count_distinct (via approx_count_distinct()) - min / max (numeric + temporal columns only) - - top 5 values (low-cardinality only: distinct_estimate ≤ 100) + - full distinct value set, capped at 100 (low-cardinality only: + distinct_estimate ≤ 100) -- surfaced into the NL→SQL prompts so + filter/CASE literals are copied verbatim from real values Skips the whole column if the snapshot's n_rows_actual exceeds FLYQUERY_PROFILE_ROW_THRESHOLD (default 10M rows). @@ -31,6 +33,7 @@ import asyncio import json import logging +import re import uuid from typing import Any @@ -39,6 +42,49 @@ logger = logging.getLogger(__name__) +# Generic, language-agnostic markers for pre-aggregated subtotal / rollup +# values that can coexist with detail rows in a categorical dimension. +# This is a CURATED set of common total markers (English + Spanish), NOT +# tied to any specific dataset/column name. A value matching this is flagged +# as a likely subtotal so the query prompt can avoid mixing subtotal + detail +# rows (summing across all rows double-counts; filtering to it drops detail). +# Token boundary is a string edge or a non-alphanumeric separator (space, +# underscore, hyphen, etc.) so labels like "Total_Department" are caught while +# words that merely embed a marker (e.g. "allocation", "North America") are not. 
+_SUBTOTAL_MARKERS = ( + r"sub[ _-]?total|grand[ _-]?total|gran[ _-]?total|totals?|totales|" + r"all|todos|todas|suma|consolidad[oa]s?|overall" +) +# A value is flagged ONLY when it is essentially a total *marker by itself* +# (the whole value is the marker), OR a machine-generated pivot label where the +# marker is joined to another token by ``_``/``-`` (e.g. ``Total_Department``, +# ``Department-Total``). This deliberately does NOT match space-separated +# natural-language line items like ``Total Revenue`` / ``Total Nexium`` -- those +# are legitimate measure values the agent must keep, not pre-aggregated rows. +_SUBTOTAL_EXACT = re.compile(rf"^(?:{_SUBTOTAL_MARKERS})$", re.IGNORECASE) +_SUBTOTAL_JOINED = re.compile( + rf"(?:^|[_-])(?:{_SUBTOTAL_MARKERS})(?=[_-])|(?<=[_-])(?:{_SUBTOTAL_MARKERS})$", + re.IGNORECASE, +) +_BLANK_PLACEHOLDERS = {"(blank)", "(empty)", "(null)", "(en blanco)", "(vacío)", "(vacio)"} + + +def _looks_like_subtotal(value: str) -> bool: + """Heuristic: True only for unambiguous total/rollup pivot labels. + + Conservative + side-effect free: matches a standalone total marker, a + separator-joined pivot label (``Total_Department``), or a blank-ish + placeholder -- but NOT space-separated natural-language values such as + ``Total Revenue``. Consumed downstream as a prompt hint only. + """ + if value is None: + return False + stripped = value.strip() + if stripped == "" or stripped.lower() in _BLANK_PLACEHOLDERS: + return True + return bool(_SUBTOTAL_EXACT.match(stripped) or _SUBTOTAL_JOINED.search(stripped)) + + # Data-type compatibility groups _NUMERIC_TYPES = frozenset( { @@ -100,6 +146,9 @@ async def run_profile( columns = await _load_columns(tenant_id, snapshot_id, session_factory) profiled = 0 + # Collect profiles first so a second pass can detect self-referencing + # hierarchy columns (manager->report) before persisting. 
+ computed: dict[str, tuple[uuid.UUID, str, dict[str, Any]]] = {} for col in columns: col_id: uuid.UUID = col["id"] col_name: str = col["qualified_name"].rsplit(".", 1)[-1] @@ -109,8 +158,23 @@ async def run_profile( _profile_column_sync, parquet_key, col_name, data_type, n_rows_actual ) if profile is not None: - await _persist_profile(col_id, profile, tenant_id, session_factory) - profiled += 1 + computed[col_name] = (col_id, data_type, profile) + + # Self-reference detection: annotate columns whose values are mostly + # contained in a higher-cardinality "entity" column of the same table + # (e.g. a manager/owner column whose values are people from the employee + # column). Purely structural (value containment) -- no name/vocabulary + # heuristics -- so it generalises to any self-referencing hierarchy. + try: + refs = await asyncio.to_thread(_detect_self_references, parquet_key, computed, n_rows_actual) + for ref_col, entity_col in refs.items(): + computed[ref_col][2]["references_column"] = entity_col + except Exception as exc: # noqa: BLE001 -- detection is best-effort + logger.warning("stage=profile self-reference detection failed snapshot=%s err=%s", snapshot_id, exc) + + for col_id, _dt, profile in computed.values(): + await _persist_profile(col_id, profile, tenant_id, session_factory) + profiled += 1 logger.info( "stage=profile snapshot_id=%s columns_profiled=%d", @@ -120,6 +184,79 @@ async def run_profile( return {"snapshot_id": str(snapshot_id), "columns_profiled": profiled} +def _detect_self_references( + parquet_key: str, + computed: dict[str, tuple[uuid.UUID, str, dict[str, Any]]], + n_rows: int, +) -> dict[str, str]: + """Find columns that reference a higher-cardinality entity column. + + A reference (foreign-key-like, including a self-referencing org + hierarchy) is detected purely structurally: a text column ``R`` whose + distinct values are mostly a SUBSET of another, higher-cardinality text + column ``E`` in the same table. 
``R`` must be on the many-to-one side + (meaningfully fewer distinct values than ``E``) so two near-duplicate + name columns are not flagged as a hierarchy. + + Returns ``{ref_col: entity_col}``. No names/keywords are inspected -- + this generalises to any dataset's manager/owner/parent columns. + """ + import duckdb + + text_cols = { + name: prof["distinct_estimate"] + for name, (_cid, dt, prof) in computed.items() + if not _is_numeric(dt) and not _is_temporal(dt) and prof.get("distinct_estimate") + } + if len(text_cols) < 2: + return {} + # Entity columns: high-cardinality name/id columns (the "one" side). + entity_cols = [c for c, d in text_cols.items() if d >= max(8, 0.2 * n_rows)] + if not entity_cols: + return {} + + refs: dict[str, str] = {} + con = duckdb.connect() + try: + for ref_col, d_ref in text_cols.items(): + if d_ref < 3: + continue + for ent_col in entity_cols: + if ent_col == ref_col or d_ref > 0.7 * text_cols[ent_col]: + continue # ref must be the many-to-one (smaller) side + s_ref = '"' + ref_col.replace('"', '""') + '"' + s_ent = '"' + ent_col.replace('"', '""') + '"' + try: + row = con.execute( + f"SELECT count(DISTINCT {s_ref}), " + f"count(DISTINCT CASE WHEN {s_ref} IN " + f"(SELECT {s_ent} FROM read_parquet(?)) THEN {s_ref} END) " + f"FROM read_parquet(?) WHERE {s_ref} IS NOT NULL AND {s_ref} <> ''", + [parquet_key, parquet_key], + ).fetchone() + except Exception: # noqa: BLE001 -- skip uncomparable columns + continue + if row and row[0] and (row[1] / row[0]) >= 0.6: + # Discriminate a genuine cross-reference (the value is a + # DIFFERENT entity than the row's own -- a manager, owner, + # parent) from a row-wise DUPLICATE of the entity column (the + # SAME entity copied, e.g. a second name column). On a + # duplicate, ref == entity on most rows; on a real reference + # they differ. Skip duplicates. + eq = con.execute( + f"SELECT avg(CASE WHEN {s_ref} = {s_ent} THEN 1.0 ELSE 0.0 END) " + f"FROM read_parquet(?) 
WHERE {s_ref} IS NOT NULL AND {s_ent} IS NOT NULL", + [parquet_key], + ).fetchone() + if eq and eq[0] is not None and eq[0] > 0.5: + continue # row-wise copy -> not a hierarchy reference + refs[ref_col] = ent_col + break + finally: + con.close() + return refs + + # --------------------------------------------------------------------------- # DuckDB profiling (runs in thread) # --------------------------------------------------------------------------- @@ -172,7 +309,13 @@ def _profile_column_sync( profile["min"] = str(col_min) if col_min is not None else None profile["max"] = str(col_max) if col_max is not None else None - # Top values for low-cardinality columns (distinct ≤ 100) + # Top values for low-cardinality columns (distinct ≤ 100). + # We store the FULL distinct set (capped at 100) rather than + # just the top 5: the NL→SQL grounding/generation agents copy + # WHERE/CASE literals verbatim from these values, so a + # truncated list silently breaks any filter on a value that + # fell outside the top 5 (e.g. the P&L-line members of an + # operating-profit formula, or a brand outside the 5 biggest). if distinct_estimate <= 100 and distinct_estimate > 0: top_rows = conn.execute( f"SELECT {safe_col}, count(*) AS cnt " @@ -180,11 +323,26 @@ def _profile_column_sync( f"WHERE {safe_col} IS NOT NULL " f"GROUP BY {safe_col} " f"ORDER BY cnt DESC " - f"LIMIT 5", + f"LIMIT 100", [parquet_key], ).fetchall() profile["top_values"] = [{"value": str(r[0]), "count": r[1]} for r in top_rows] + # Subtotal detection (heuristic hint, cheap + side-effect free). + # For a low-cardinality TEXT dimension, flag values whose text + # matches a generic total/aggregate marker (en/es) or a blank + # placeholder. Such values are likely pre-aggregated rollup rows + # coexisting with detail rows; the query prompt consumes this to + # avoid mixing subtotal + detail (summing double-counts; filtering + # to the subtotal drops detail). 
No measure join is available + # here, so this stays purely structural/textual. + if not _is_numeric(data_type) and not _is_temporal(data_type): + subtotal_values = [ + str(r[0]) for r in top_rows if _looks_like_subtotal(str(r[0])) + ] + if subtotal_values: + profile["subtotal_values"] = subtotal_values + return profile finally: diff --git a/src/flyquery/core/services/query/query_service.py b/src/flyquery/core/services/query/query_service.py index b11ecf6..ffbd8b4 100644 --- a/src/flyquery/core/services/query/query_service.py +++ b/src/flyquery/core/services/query/query_service.py @@ -128,6 +128,33 @@ def _render_grounding_prompt( out.append(f"- `{qn}` :: {text}") out.append("") + # 2b. The "Column value catalogue" lists EVERY column with its real + # values (distinct set for low-cardinality columns; numeric + # range otherwise). This is the ground truth the agent must copy + # filter/CASE literals from -- it prevents guessing wrong + # literals (`Year IN (2023)` when the values are `FY23`), maps a + # question entity to the column whose values contain it + # (`DAPA` lives in a column's values, not a column name), and + # reveals tall/EAV layouts (P&L line items are VALUES of a single + # column) and scaled-duplicate measures (`FY` vs `FY (Real)`). + inv_columns = [h for h in inventory if (getattr(h, "metadata", {}) or {}).get("kind") == "COLUMN"] + cols_with_values = [h for h in inv_columns if (getattr(h, "metadata", {}) or {}).get("values")] + if cols_with_values: + out.append(f"# Column value catalogue ({len(cols_with_values)} columns)") + out.append( + "Real values per column. When the question names an entity (a brand, " + "year, market, category, P&L line, team…) that is NOT a column name, " + "find the column whose values contain it and filter THAT column. Copy " + "filter/CASE literals VERBATIM from these values (values may be encoded, " + "e.g. a year shown as `FY23`). 
If several columns share members, prefer " + "the one whose values match the question most precisely." + ) + for h in cols_with_values: + md = getattr(h, "metadata", None) or {} + qn = md.get("qualified_name") or "?" + out.append(f"- `{qn}` :: {md.get('values')}") + out.append("") + examples = bundle.get("examples", []) or [] if examples: out.append(f"# Approved Q→SQL examples ({len(examples)})") @@ -214,6 +241,13 @@ def _render_generation_prompt( inv = schema_inventory or [] inv_tables = [h for h in inv if (getattr(h, "metadata", {}) or {}).get("kind") == "TABLE"] + # Map qualified_name -> value fingerprint so the generator copies + # filter/CASE literals verbatim from real values rather than guessing. + value_index: dict[str, str] = {} + for h in inv: + md = getattr(h, "metadata", None) or {} + if md.get("kind") == "COLUMN" and md.get("values"): + value_index[md.get("qualified_name")] = md.get("values") if inv_tables: out.append(f"# Complete dataset catalogue ({len(inv_tables)} tables)") out.append( @@ -243,9 +277,11 @@ def _render_generation_prompt( out.append(f"- `{getattr(t, 'table_qualified_name', t)}`") out.append("") if g_columns: - out.append("## Columns in scope") + out.append("## Columns in scope (with real values — copy literals verbatim)") for c in g_columns: - out.append(f"- `{getattr(c, 'column_qualified_name', c)}`") + cqn = getattr(c, "column_qualified_name", c) + vals = value_index.get(cqn) + out.append(f"- `{cqn}`" + (f" :: {vals}" if vals else "")) out.append("") if g_joins: out.append("## Approved joins") @@ -257,6 +293,20 @@ def _render_generation_prompt( ) out.append("") + # Full column-value catalogue -- the grounding agent may under-select + # columns, so expose every column's real values here too. This is the + # source of truth for WHERE / CASE literals. + if value_index: + out.append(f"# Column value catalogue ({len(value_index)} columns)") + out.append( + "Real values per column. Copy filter/CASE literals VERBATIM from these. 
" + "If the question names an entity that is not a column name, filter the " + "column whose values contain it." + ) + for qn, vals in value_index.items(): + out.append(f"- `{qn}` :: {vals}") + out.append("") + out.append("# Task") out.append( "Generate up to N candidate DuckDB SQL queries that answer the question, " @@ -267,6 +317,36 @@ def _render_generation_prompt( "rather than `FROM orbis_companies.IVI_MALAGA_SL__Activos`.\n" "- Quote any column name that isn't a plain identifier (e.g. date-shaped " 'names like `2024-12-31` must be `"2024-12-31"`).\n' + "- Copy every WHERE / CASE / IN literal VERBATIM from the column value " + "catalogue above -- never invent or reformat a value (a year is `FY23`, " + "not `2023`; a market may be `Brazil` or `44000BR Brazil` -- use exactly " + "what is listed).\n" + "- When the metric the user names is not a column but appears among a " + "column's listed values, filter that column (tall/EAV layout): e.g. P&L " + "line items like `Total Revenue` / `Manpower` are VALUES of a single " + "category column, selected with `CASE WHEN \"\" = 'Total Revenue' …`.\n" + "- When two numeric columns are near-duplicates whose ranges differ by a " + "constant factor (~10^k), they are the same measure at different scales -- " + "prefer the larger-magnitude one for monetary sums.\n" + "- Do NOT add a WHERE filter on a dimension the question did not ask to " + "slice by -- aggregate across ALL of its values, and do NOT drop a row " + "just because a category value's name contains 'total'/'all' (those are " + "usually legitimate, often 'unallocated', buckets). Exclude a value only " + "if you can confirm it is literally the sum of the other rows.\n" + "- Return what is ASKED FOR: if the question asks for names, a list, " + "'who', 'which', or 'dame los nombres/quiénes', SELECT the identifying " + "column(s) (e.g. the name) and return the matching ROWS -- do NOT collapse " + "to a COUNT. Use COUNT/aggregates only when a count or total is requested. 
" + "If BOTH a count and the names are asked, return the names (the count is " + "derivable from the row count).\n" + "- HIERARCHY questions: when the question asks about a person's TEAM, " + "direct reports, the people 'at their charge' / under them, their org, or " + "movements in THEIR structure, it is a self-referencing hierarchy. A " + "column marked 'HIERARCHY: holds entities/people from column X' holds each " + "row's manager/owner. The person's team = the ROWS where such a column " + "equals that person (filter it with case-insensitive LIKE '%name%'), NOT " + "the person's own row. Try EVERY hierarchy column (a person may appear in " + "more than one), and match names tolerantly (accents/spacing).\n" "- Be a SINGLE statement (no multi-statement; no DDL).\n" "- Be a SELECT (DuckDB-flavored)." ) @@ -542,7 +622,11 @@ async def answer( # into the persisted query record for reproducibility. metric_name = grounded.metrics[0].metric_name compiled, metric_version = await self._compiled_metric_sql( - metric_name, dataset_id, tenant_id=tenant_id, workspace_id=workspace_id + metric_name, + dataset_id, + tenant_id=tenant_id, + workspace_id=workspace_id, + extra_filter=getattr(grounded.metrics[0], "extra_filter", None), ) if compiled: chosen_sql = compiled @@ -577,7 +661,17 @@ async def answer( gen_run = await self._generation_agent.run(gen_prompt) gen_out = getattr(gen_run, "output", gen_run) candidates_json = [c.model_dump() for c in gen_out.candidates] - chosen_sql = gen_out.candidates[0].sql + # Don't blindly take the highest-confidence candidate: probe them + # (DuckDB only, no extra LLM) and prefer one that passes the firewall + # and returns non-empty, non-degenerate rows. Empty candidate list + # degrades to "" (FAILED) instead of raising an IndexError. 
+ chosen_sql = await self._select_best_candidate( + [c.sql for c in gen_out.candidates if c.sql], + dataset_id=dataset_id, + scopes=scopes, + dataset_allowlist=dataset_allowlist, + pins=prior_snapshot_pins, + ) # ------------------------------------------------------------------ # 5. AST classify + scope guard @@ -666,6 +760,7 @@ async def answer( attached = await self._table_resolver.resolve( dataset_id, list(ast.table_refs), + pins=prior_snapshot_pins, ) result = await self._executor.execute(chosen_sql, attached) retries = 0 @@ -715,7 +810,9 @@ async def answer( ) retries += 1 continue - attached = await self._table_resolver.resolve(dataset_id, list(ast.table_refs)) + attached = await self._table_resolver.resolve( + dataset_id, list(ast.table_refs), pins=prior_snapshot_pins + ) result = await self._executor.execute(chosen_sql, attached) retries += 1 @@ -747,6 +844,24 @@ async def answer( # 9. Clarification frame (emitted alongside answer when confidence is low) # ------------------------------------------------------------------ clarification = self._clarification(grounded) + # A syntactically-valid query that returns 0 rows (or a single + # all-NULL/zero aggregate) is suspicious: the usual cause is a + # filter literal that doesn't match how the data is encoded. Rather + # than report it as a confident empty answer, surface a clarification + # and downgrade confidence so the caller knows to verify. + suspicious_empty = isinstance(result, ExecutionResult) and self._is_suspicious_empty(result) + if clarification is None and suspicious_empty: + from flyquery.interfaces.query import ClarificationFrame + + clarification = ClarificationFrame( + questions=[ + "The query executed successfully but returned no matching data " + "(0 rows / empty result). The filter values may not match how the " + "data is encoded -- please verify the exact column values (e.g. " + "category labels or period format) or rephrase the question." 
+ ], + reasons=[], + ) clarification_emitted = clarification is not None # ------------------------------------------------------------------ @@ -796,7 +911,12 @@ async def answer( # ------------------------------------------------------------------ # 12. Auto-learn (only on first-shot OK + no PII + no clarification) # ------------------------------------------------------------------ - if execution_status == "OK" and isinstance(result, ExecutionResult) and not clarification_emitted: + if ( + execution_status == "OK" + and isinstance(result, ExecutionResult) + and result.row_count > 0 + and not clarification_emitted + ): await self._auto_learner.maybe_propose( tenant_id=tenant_id, workspace_id=workspace_id, @@ -811,6 +931,16 @@ async def answer( # ------------------------------------------------------------------ # 13. Persist conversation turn (Phase E drill-down) # ------------------------------------------------------------------ + # THIS turn's snapshot pins: the snapshot each resolved table was + # answered against. Tables already pinned by an earlier turn keep + # their pin (prior wins); newly-referenced tables pin to current. + # Persisting THIS turn's pins (not the prior turn's) is what makes + # drill-down reproducible across a mid-conversation re-ingest. 
+ this_turn_pins: dict[str, str] = { + **(await self._table_resolver.current_snapshots(dataset_id, list(ast.table_refs))), + **prior_snapshot_pins, + } + if ( conversation_id is not None and self._conversation_service is not None @@ -824,7 +954,7 @@ async def answer( executed_sql=chosen_sql, summary=explanation_obj.summary if explanation_obj else None, table_qnames_json=list(ast.table_refs), - snapshot_pins_json=prior_snapshot_pins, + snapshot_pins_json=this_turn_pins, elapsed_ms=elapsed, ) @@ -839,13 +969,85 @@ async def answer( chart_hint=explanation_obj.chart_hint if explanation_obj else None, explanation=explanation_obj.summary if explanation_obj else None, clarification=clarification, - grounded_summary=self._grounded_summary(grounded), + grounded_summary=self._grounded_summary( + grounded, confidence_cap=0.4 if suspicious_empty else None + ), + snapshot_pins=this_turn_pins, ) # ------------------------------------------------------------------ # Internal helpers # ------------------------------------------------------------------ + @staticmethod + def _is_suspicious_empty(result) -> bool: + """True when an executed result is empty/degenerate enough to doubt. + + Catches the canonical wrong-literal symptom: a valid query that + matched nothing (0 rows), or a single-row single-column aggregate + whose only value is NULL / 0 / 0.0 (e.g. a SUM/CASE pivot where + every branch missed). + """ + if result.row_count == 0: + return True + rows = getattr(result, "rows", None) or [] + if result.row_count == 1 and len(rows) == 1 and isinstance(rows[0], dict) and len(rows[0]) == 1: + (only_value,) = rows[0].values() + return only_value is None or only_value == 0 + return False + + async def _select_best_candidate( + self, + candidate_sqls: list[str], + *, + dataset_id: uuid.UUID, + scopes: set[str], + dataset_allowlist: set[uuid.UUID] | None, + pins: dict[str, str], + ) -> str: + """Pick the candidate SQL that best answers the question. 
+ + Generation emits N candidates ranked by self-reported confidence, but + the top one sometimes filters on the wrong column (or under-searches a + set of hierarchy columns) and returns 0 rows while a lower-ranked + candidate is correct. So we probe the candidates and prefer the first + that (a) passes the firewall, (b) executes, and (c) returns non-empty, + non-degenerate rows. This is DuckDB-only -- NO extra LLM calls -- and + general: it just prefers a candidate that actually returns data. + + Falls back to the first candidate that executed at all, else the first + candidate (so the existing scope/critic handling downstream is + unchanged when nothing is clearly better). + """ + if len(candidate_sqls) <= 1: + return candidate_sqls[0] if candidate_sqls else "" + + first_executed: str | None = None + for sql in candidate_sqls: + ast = self._ast_classifier.classify(sql) + table_kinds = await self._table_kinds_by_name(list(ast.table_refs), dataset_id) + dataset_of_table = await self._dataset_of_tables(list(ast.table_refs), dataset_id) + try: + self._scope_guard.check( + classification=ast, + scopes=scopes, + table_kinds_by_name=table_kinds, + dataset_allowlist=dataset_allowlist, + dataset_of_table=dataset_of_table, + ) + except ScopeGuardError: + continue # unsafe candidate -- skip + if sorted({t for t in ast.table_refs if t} - set(table_kinds.keys())): + continue # references a table not in the dataset -- skip + attached = await self._table_resolver.resolve(dataset_id, list(ast.table_refs), pins=pins) + result = await self._executor.execute(sql, attached) + if isinstance(result, ExecutionResult): + if not self._is_suspicious_empty(result): + return sql # passes firewall + returns real rows -- best + if first_executed is None: + first_executed = sql # remember first successful-but-empty + return first_executed or candidate_sqls[0] + async def _table_kinds_by_name( self, table_names: list[str], @@ -886,11 +1088,19 @@ async def _compiled_metric_sql( *, tenant_id: 
str, workspace_id: uuid.UUID, + extra_filter: str | None = None, ) -> tuple[str | None, int | None]: """Fetch + bind the compiled SQL for a PUBLISHED metric. Returns ``(bound_sql, current_version)`` so the version can be pinned in the query record, or ``(None, None)`` when no usable metric is found. + + ``extra_filter`` is the per-question slice the grounding agent derived + (e.g. ``Market = 'Brazil' AND Year = 'FY24'``) to be appended to the + metric's WHERE via the compiler's ``{extra_filter_clause}`` slot. It is + an LLM-supplied predicate, so it is re-run through the publish-time + firewall before binding; an unsafe filter is dropped (the metric still + returns its unfiltered value) rather than executed. """ if self._semantic_repo is None: return None, None @@ -903,9 +1113,31 @@ async def _compiled_metric_sql( return None, None if not row or not row.get("compiled_sql_template"): return None, None - bound = SemanticCompiler.bind(row["compiled_sql_template"]) + + safe_filter = self._firewall_extra_filter(row["compiled_sql_template"], extra_filter) + bound = SemanticCompiler.bind(row["compiled_sql_template"], extra_filter=safe_filter) return bound, row.get("current_version") + @staticmethod + def _firewall_extra_filter(template: str, extra_filter: str | None) -> str | None: + """Validate an LLM-supplied metric filter via the publish-time firewall. + + Returns the filter when the bound SQL passes ``assert_safe_template``, + else ``None`` (filter dropped). Defensive: any firewall/parse failure + also drops the filter rather than risking an unsafe predicate. 
+ """ + if not extra_filter: + return None + try: + from flyquery.core.services.semantic.firewall import assert_safe_template + + probe = SemanticCompiler.bind(template, extra_filter=extra_filter) + assert_safe_template(probe) + return extra_filter + except Exception as exc: # noqa: BLE001 -- any failure → drop the filter + logger.warning("dropping unsafe semantic extra_filter %r: %s", extra_filter, exc) + return None + def _clarification(self, grounded) -> Any: """Build a ClarificationFrame if grounding confidence is low.""" from flyquery.interfaces.query import ClarificationFrame @@ -917,11 +1149,19 @@ def _clarification(self, grounded) -> Any: ) return None - def _grounded_summary(self, grounded) -> dict: - """Convert GroundedContext to a summary dict for the response.""" + def _grounded_summary(self, grounded, confidence_cap: float | None = None) -> dict: + """Convert GroundedContext to a summary dict for the response. + + ``confidence_cap`` lets the caller lower the reported confidence when + the executed result is suspicious (e.g. 0 rows from a wrong literal), + so a confidently-wrong empty answer is not surfaced at high confidence. + """ + confidence = grounded.confidence + if confidence_cap is not None: + confidence = min(confidence, confidence_cap) return { "path": grounded.path, - "confidence": grounded.confidence, + "confidence": confidence, "table_count": len(grounded.tables), "missing_info": grounded.missing_info, } diff --git a/src/flyquery/core/services/retrieval/reranker.py b/src/flyquery/core/services/retrieval/reranker.py index d3a9fa4..6c117f7 100644 --- a/src/flyquery/core/services/retrieval/reranker.py +++ b/src/flyquery/core/services/retrieval/reranker.py @@ -21,10 +21,16 @@ from __future__ import annotations +import logging from typing import Any, Protocol from flyquery.core.services.retrieval.search_index import Hit +logger = logging.getLogger(__name__) + +# Guard so the "reranking disabled" warning is emitted at most once per process. 
+_warned_noop_fallback = False + class Reranker(Protocol): """Protocol for a reranking step in the retrieval pipeline.""" @@ -88,10 +94,21 @@ def build_reranker(settings: Any) -> NoopReranker | CrossEncoderReranker: :param settings: ``FlyquerySettings`` instance :return: a ready-to-use reranker """ + global _warned_noop_fallback model_name = getattr(settings, "reranker_model", "") or "" if not model_name: return NoopReranker() try: return CrossEncoderReranker(model_name) - except Exception: # noqa: BLE001 + except Exception as exc: # noqa: BLE001 + if not _warned_noop_fallback: + _warned_noop_fallback = True + logger.warning( + "reranker model=%s unavailable (%s) -- falling back to NoopReranker. " + "Relevance reranking is DISABLED; results are truncated by retrieval " + "order only (install sentence-transformers / make the cross-encoder " + "model loadable to enable it).", + model_name, + exc, + ) return NoopReranker() diff --git a/src/flyquery/core/services/retrieval/search_index.py b/src/flyquery/core/services/retrieval/search_index.py index 8d10f5c..836eed0 100644 --- a/src/flyquery/core/services/retrieval/search_index.py +++ b/src/flyquery/core/services/retrieval/search_index.py @@ -42,6 +42,115 @@ class Hit: metadata: dict = field(default_factory=dict) +def value_fingerprint( + data_type: str | None, + sample_values_json: object, + profile_json: object, + *, + max_values: int = 40, + max_chars: int = 400, +) -> str: + """Compact, human-readable summary of a column's ACTUAL values. + + Surfaced into the grounding/generation prompts so the agents copy + WHERE / CASE literals verbatim from real values instead of guessing + -- e.g. a fiscal year stored as ``FY23`` (not ``2023``), the members + of a tall/EAV category column (``P&L Line System`` rows like + ``Total Revenue`` / ``Manpower``), or the magnitude gap between a + scaled-duplicate measure (``FY`` ~0.02 vs ``FY (Real)`` ~25748). + + Returns ``""`` when there is nothing useful to show. 
+ """ + prof = profile_json if isinstance(profile_json, dict) else {} + + # NOTE: profiling stores ``subtotal_values`` (name-based candidates for + # pre-aggregated rows). We deliberately do NOT surface them as an + # exclusion directive: a "Total_*" value in a dimension is just as often a + # legitimate additive bucket (e.g. unallocated/corporate) as a true rollup, + # and any hint makes the agent wrongly drop it. Whether to exclude requires + # the structural test (does the value's aggregate == the sum of the others?) + # which is not available per-column at profile time. The general "aggregate + # across ALL values / don't drop a value" prompt rule handles this safely. + subtotal_note = "" + + # Self-referencing hierarchy hint: this column's values are entities from + # another (higher-cardinality) column -- e.g. a manager column whose values + # are people from the employee column. Surfaced even for high-cardinality + # columns that have no listable values, because that is exactly when the + # agent cannot otherwise tell who a person reports to. + ref_col = prof.get("references_column") + ref_note = ( + f" | HIERARCHY: holds entities/people from column '{ref_col}' (e.g. each row's " + f"manager/owner/parent). To get a given person's group/team/reports, filter THIS " + f"column to that person (case-insensitive LIKE), not the person's own row." + if ref_col + else "" + ) + + # Categorical: the stored distinct value set (low-cardinality columns). 
+ top_values = prof.get("top_values") or [] + if top_values: + seen: set[str] = set() + uniq: list[str] = [] + for tv in top_values: + v = tv.get("value") if isinstance(tv, dict) else tv + if v is None: + continue + s = str(v) + if s not in seen: + seen.add(s) + uniq.append(s) + shown = uniq[:max_values] + body = " | ".join(shown) + if len(body) > max_chars: + body = body[:max_chars].rsplit("|", 1)[0].strip() + " | …" + more = "" if len(uniq) <= len(shown) else f" (+{len(uniq) - len(shown)} more)" + return f"values: {body}{more}{subtotal_note}{ref_note}" if body else ref_note.strip(" |") + + # Numeric / temporal: range + cardinality (exposes scaled duplicates). + col_min, col_max = prof.get("min"), prof.get("max") + if col_min is not None or col_max is not None: + rng = f"range: {col_min} .. {col_max}" + dist = prof.get("distinct_estimate") + if dist is not None: + rng += f" (~{dist} distinct)" + return rng + ref_note + + # Fallback: a few raw sample values (high-cardinality columns). + samples = sample_values_json if isinstance(sample_values_json, list) else [] + if samples: + seen2: set[str] = set() + uniq2: list[str] = [] + for v in samples: + s = str(v) + if s not in seen2: + seen2.add(s) + uniq2.append(s) + if uniq2: + return "e.g.: " + " | ".join(uniq2[:8]) + ref_note + return ref_note.strip(" |") + + +def _column_hit(r, score: float) -> "Hit": + """Build a ranked schema-object Hit, enriched with a value fingerprint. + + Used by the BM25 + vector column searches so the ranked "Top-ranked + column matches" the grounding agent sees carry real values, not just + name + description. 
+ """ + fp = value_fingerprint(r.data_type, r.sample_values_json, r.profile_json) + text = f"{r.qualified_name}: {r.data_type}\n{r.description or ''}" + if fp: + text += f"\n{fp}" + return Hit( + source_kind="schema_object", + id=r.id, + text=text, + score=score, + metadata={"qualified_name": r.qualified_name, "table_id": str(r.table_id), "values": fp}, + ) + + class SearchIndex: """Read-only query helpers that operate on a shared ``AsyncSession``.""" @@ -60,10 +169,12 @@ async def bm25_schema_objects(self, query: str, dataset_id: uuid.UUID, limit: in sa.text( """ SELECT o.id, o.qualified_name, o.description, o.data_type, o.table_id, + o.sample_values_json, o.profile_json, ts_rank(o.content_tsv, plainto_tsquery('english', :q)) AS score FROM flyquery_schema_objects o JOIN flyquery_tables t ON t.id = o.table_id WHERE t.dataset_id = :ds AND o.is_active = true + AND o.snapshot_id = t.current_snapshot_id AND o.content_tsv @@ plainto_tsquery('english', :q) ORDER BY score DESC LIMIT :lim @@ -71,16 +182,7 @@ async def bm25_schema_objects(self, query: str, dataset_id: uuid.UUID, limit: in ), {"q": query, "ds": dataset_id, "lim": limit}, ) - return [ - Hit( - source_kind="schema_object", - id=r.id, - text=f"{r.qualified_name}: {r.data_type}\n{r.description or ''}", - score=float(r.score), - metadata={"qualified_name": r.qualified_name, "table_id": str(r.table_id)}, - ) - for r in rows.mappings() - ] + return [_column_hit(r, float(r.score)) for r in rows.mappings()] async def vector_schema_objects( self, @@ -99,26 +201,19 @@ async def vector_schema_objects( sa.text( """ SELECT o.id, o.qualified_name, o.description, o.data_type, o.table_id, + o.sample_values_json, o.profile_json, 1 - (o.embedding <=> CAST(:emb AS vector)) AS score FROM flyquery_schema_objects o JOIN flyquery_tables t ON t.id = o.table_id - WHERE t.dataset_id = :ds AND o.is_active = true AND o.embedding IS NOT NULL + WHERE t.dataset_id = :ds AND o.is_active = true + AND o.snapshot_id = t.current_snapshot_id AND 
o.embedding IS NOT NULL ORDER BY o.embedding <=> CAST(:emb AS vector) LIMIT :lim """ ), {"emb": str(query_embedding), "ds": dataset_id, "lim": limit}, ) - return [ - Hit( - source_kind="schema_object", - id=r.id, - text=f"{r.qualified_name}: {r.data_type}\n{r.description or ''}", - score=float(r.score), - metadata={"qualified_name": r.qualified_name, "table_id": str(r.table_id)}, - ) - for r in rows.mappings() - ] + return [_column_hit(r, float(r.score)) for r in rows.mappings()] async def all_schema_objects( self, @@ -158,7 +253,9 @@ async def all_schema_objects( ON c.table_id = o.table_id AND c.kind = 'COLUMN' AND c.is_active = true + AND c.snapshot_id = t.current_snapshot_id WHERE t.dataset_id = :ds AND o.is_active = true AND o.kind = 'TABLE' + AND o.snapshot_id = t.current_snapshot_id GROUP BY o.id, o.qualified_name, o.description, o.table_id, o.kind ORDER BY o.qualified_name """ @@ -175,10 +272,12 @@ async def all_schema_objects( await self._session.execute( sa.text( """ - SELECT o.id, o.qualified_name, o.description, o.data_type, o.table_id, o.kind + SELECT o.id, o.qualified_name, o.description, o.data_type, o.table_id, o.kind, + o.sample_values_json, o.profile_json FROM flyquery_schema_objects o JOIN flyquery_tables t ON t.id = o.table_id WHERE t.dataset_id = :ds AND o.is_active = true AND o.kind = 'COLUMN' + AND o.snapshot_id = t.current_snapshot_id ORDER BY o.qualified_name LIMIT :lim """ @@ -223,16 +322,21 @@ async def all_schema_objects( ) for r in column_rows: + fp = value_fingerprint(r["data_type"], r["sample_values_json"], r["profile_json"]) + text = f"{r['qualified_name']}: {r['data_type'] or ''}\n{r['description'] or ''}" + if fp: + text += f"\n{fp}" hits.append( Hit( source_kind="schema_object", id=r["id"], - text=f"{r['qualified_name']}: {r['data_type'] or ''}\n{r['description'] or ''}", + text=text, score=1.0, metadata={ "qualified_name": r["qualified_name"], "table_id": str(r["table_id"]), "kind": "COLUMN", + "values": fp, }, ) ) diff --git 
a/src/flyquery/web/controllers/query_controller.py b/src/flyquery/web/controllers/query_controller.py index 93169d6..0ddac2b 100644 --- a/src/flyquery/web/controllers/query_controller.py +++ b/src/flyquery/web/controllers/query_controller.py @@ -51,7 +51,12 @@ from flyquery.core.services.execution.scope_guard import ScopeGuard, ScopeGuardError from flyquery.core.services.execution.table_resolver import TableResolver from flyquery.core.services.query.query_repository import QueryRepository -from flyquery.core.services.query.query_service import QueryService +from flyquery.core.services.query.query_service import ( + QueryService, + _render_critic_prompt, + _render_generation_prompt, + _render_grounding_prompt, +) from flyquery.core.services.query.result_uploader import ResultUploader from flyquery.core.services.retrieval.embedder import Embedder from flyquery.core.services.retrieval.hybrid_retriever import HybridRetriever @@ -131,6 +136,54 @@ def __init__( self._scope_guard = ScopeGuard() self._executor = DuckDBExecutor(settings) + async def _guarded_execute( + self, + sql: str, + ast: Any, + table_kinds: dict[str, str], + resolver: TableResolver, + dataset_id: uuid.UUID, + bundle: dict, + ) -> Any: + """Apply ScopeGuard + bad-tables firewall, then execute. + + Mirrors the guards in ``QueryService.answer`` so the streaming path + enforces the same dataset isolation and table-existence checks. + Returns an ``ExecutionResult`` or an ``ExecutionError`` (which the + caller's critic loop can attempt to refine). 
+ """ + from flyquery.core.services.execution.duckdb_executor import ExecutionError + + try: + self._scope_guard.check( + classification=ast, + scopes=_DEFAULT_USER_SCOPES, + table_kinds_by_name=table_kinds, + dataset_allowlist=None, + dataset_of_table={t: str(dataset_id) for t in ast.table_refs}, + ) + except ScopeGuardError as exc: + return ExecutionError(message=f"Rejected by firewall: {exc}") + + ref_set = {t for t in ast.table_refs if t} + bad_tables = sorted(ref_set - set(table_kinds.keys())) + if bad_tables: + real_tables = sorted(table_kinds.keys()) + [ + (getattr(h, "metadata", {}) or {}).get("qualified_name", "").rsplit(".", 1)[-1] + for h in (bundle.get("schema_inventory") or []) + if (getattr(h, "metadata", {}) or {}).get("kind") == "TABLE" + ] + real_tables = [t for t in dict.fromkeys(real_tables) if t] + return ExecutionError( + message=( + f"Table(s) {bad_tables!r} do not exist in this dataset. " + f"Pick ONLY from: {real_tables[:80]!r}." + ) + ) + + attached = await resolver.resolve(dataset_id, list(ast.table_refs)) + return await self._executor.execute(sql, attached) + def _build_service(self, db_session: AsyncSession) -> QueryService: """Build a per-request QueryService around the provided session.""" index = SearchIndex(db_session) @@ -335,24 +388,30 @@ async def explain( bundle["schema_objects"] = reranked grounding_agent = build_grounding_agent(self._settings) - grounded = await grounding_agent.run( - {"question": body.question, "bundle": bundle, "starting_point_sql": None} + grounded_run = await grounding_agent.run( + _render_grounding_prompt( + question=body.question, bundle=bundle, starting_point_sql=None + ) ) + grounded = getattr(grounded_run, "output", grounded_run) generation_agent = build_generation_agent(self._settings) - gen_out = await generation_agent.run( - {"grounded": grounded, "question": body.question, "starting_point_sql": None} + gen_run = await generation_agent.run( + _render_generation_prompt( + body.question, grounded, 
None, schema_inventory=bundle.get("schema_inventory") + ) ) - candidate = gen_out.candidates[0] + gen_out = getattr(gen_run, "output", gen_run) + candidate = gen_out.candidates[0] if gen_out.candidates else None clarification: ClarificationFrame | None = None if grounded.confidence < self._settings.grounding_min_confidence and grounded.missing_info: clarification = ClarificationFrame(questions=grounded.missing_info, reasons=[]) return ExplainResponse( - sql=candidate.sql, - reasoning=candidate.reasoning, - confidence=candidate.confidence, + sql=candidate.sql if candidate else "", + reasoning=candidate.reasoning if candidate else "generation produced no candidate", + confidence=candidate.confidence if candidate else 0.0, grounded_summary={ "path": grounded.path, "confidence": grounded.confidence, @@ -398,15 +457,21 @@ async def validate( bundle["schema_objects"] = reranked grounding_agent = build_grounding_agent(self._settings) - grounded = await grounding_agent.run( - {"question": body.question, "bundle": bundle, "starting_point_sql": None} + grounded_run = await grounding_agent.run( + _render_grounding_prompt( + question=body.question, bundle=bundle, starting_point_sql=None + ) ) + grounded = getattr(grounded_run, "output", grounded_run) generation_agent = build_generation_agent(self._settings) - gen_out = await generation_agent.run( - {"grounded": grounded, "question": body.question, "starting_point_sql": None} + gen_run = await generation_agent.run( + _render_generation_prompt( + body.question, grounded, None, schema_inventory=bundle.get("schema_inventory") + ) ) - chosen_sql = gen_out.candidates[0].sql + gen_out = getattr(gen_run, "output", gen_run) + chosen_sql = gen_out.candidates[0].sql if gen_out.candidates else "" ast = self._ast_classifier.classify(chosen_sql) @@ -494,21 +559,28 @@ async def _stream_events( retriever = HybridRetriever(index=index, embedder=self._embedder, rrf_k=self._settings.rrf_k) reranker = build_reranker(self._settings) - # Stage 1: 
retrieve + ground + # Stage 1: retrieve + ground (same retrieval params + rendered + # prompt as the sync POST /query path, so the streaming path is + # not blind to the column-value catalogue, examples, metrics). bundle = await retriever.retrieve( request.question, dataset_id=request.dataset_id, workspace_id=workspace_id, top_k_schema=self._settings.top_k_schema * 3, + top_k_examples=self._settings.top_k_examples, + top_k_metrics=self._settings.top_k_metrics, ) schema_hits = bundle.get("schema_objects", []) reranked = await reranker.rerank(request.question, schema_hits, top_n=self._settings.top_k_schema) bundle["schema_objects"] = reranked grounding_agent = build_grounding_agent(self._settings) - grounded = await grounding_agent.run( - {"question": request.question, "bundle": bundle, "starting_point_sql": None} + grounded_run = await grounding_agent.run( + _render_grounding_prompt( + question=request.question, bundle=bundle, starting_point_sql=None + ) ) + grounded = getattr(grounded_run, "output", grounded_run) yield _sse_frame( "schema_linked", @@ -537,11 +609,14 @@ async def _stream_events( # Stage 3: generate SQL generation_agent = build_generation_agent(self._settings) - gen_out = await generation_agent.run( - {"grounded": grounded, "question": request.question, "starting_point_sql": None} + gen_run = await generation_agent.run( + _render_generation_prompt( + request.question, grounded, None, schema_inventory=bundle.get("schema_inventory") + ) ) + gen_out = getattr(gen_run, "output", gen_run) candidates = gen_out.candidates - chosen_sql = candidates[0].sql + chosen_sql = candidates[0].sql if candidates else "" yield _sse_frame( "sql_generated", @@ -554,28 +629,38 @@ async def _stream_events( }, ) - # Stage 4: execute (with critic loop) + # Stage 4: AST + firewall/scope guards + execute (with critic loop). 
+ # _guarded_execute applies the SAME ScopeGuard + bad-tables firewall + # the sync path enforces, so streaming callers cannot bypass dataset + # isolation or run SQL against a non-existent/cross-dataset table. ast = self._ast_classifier.classify(chosen_sql) table_resolver = TableResolver(session=db_session, settings=self._settings) - attached = await table_resolver.resolve(request.dataset_id, list(ast.table_refs)) - - exec_result = await self._executor.execute(chosen_sql, attached) + table_kinds = await _table_kinds_by_name(db_session, list(ast.table_refs), request.dataset_id) + exec_result = await self._guarded_execute( + chosen_sql, ast, table_kinds, table_resolver, request.dataset_id, bundle + ) retries = 0 while isinstance(exec_result, ExecutionError) and retries < self._settings.max_refine_retries: critic_agent = build_critic_agent(self._settings) - refined = await critic_agent.run( - { - "sql": chosen_sql, - "error": exec_result.message, - "grounded": grounded, - "question": request.question, - } + refined_run = await critic_agent.run( + _render_critic_prompt( + question=request.question, + failing_sql=chosen_sql, + error_message=exec_result.message, + grounded=grounded, + schema_inventory=bundle.get("schema_inventory"), + ) ) + refined = getattr(refined_run, "output", refined_run) chosen_sql = refined.sql ast = self._ast_classifier.classify(chosen_sql) - attached = await table_resolver.resolve(request.dataset_id, list(ast.table_refs)) - exec_result = await self._executor.execute(chosen_sql, attached) + table_kinds = await _table_kinds_by_name( + db_session, list(ast.table_refs), request.dataset_id + ) + exec_result = await self._guarded_execute( + chosen_sql, ast, table_kinds, table_resolver, request.dataset_id, bundle + ) retries += 1 snapshot_pins: dict = {} @@ -705,6 +790,26 @@ async def _stream_events( yield _sse_frame("final", answer.model_dump(mode="json")) +async def _table_kinds_by_name( + session: AsyncSession, table_names: list[str], dataset_id: 
uuid.UUID +) -> dict[str, str]: + """Return {name: kind} for the active tables in the dataset (firewall input).""" + if not table_names: + return {} + import sqlalchemy as sa + + rows = await session.execute( + sa.text( + """ + SELECT name, kind FROM flyquery_tables + WHERE dataset_id = :ds AND name = ANY(:names) AND is_active = true + """ + ), + {"ds": dataset_id, "names": list(table_names)}, + ) + return {r["name"]: r["kind"] for r in rows.mappings()} + + def _now_ms() -> int: """Return current monotonic time in milliseconds.""" import time From 277f96ae6cf9b56f19cc0936900acdbee26575d8 Mon Sep 17 00:00:00 2001 From: Jose Agustin Puente Date: Tue, 2 Jun 2026 10:28:09 +0200 Subject: [PATCH 2/4] fix(ci): satisfy PR gate (lockstep, lint, unit/integration tests) - builder.py is a lock-step (canon-pinned) module; revert it and pass the determinism setting via the existing extra_settings={"temperature": 0.0} instead of a new parameter, so the lockstep gate stays green - move the table-kinds lookup out of the query controller into TableResolver (service layer) so the "no raw SQL in controllers" gate passes - ruff: inline a negated return (SIM103), drop a forward-ref quote (UP037), and apply ruff format - tests: extend the fake TableResolvers with pins / current_snapshots / table_kinds_by_name; seed current_snapshot_id in the retrieval integration fixture (retrieval is now scoped to current_snapshot_id, as publish sets it) --- src/flyquery/core/agents/builder.py | 10 ----- .../core/agents/column_name_proposer_agent.py | 2 +- src/flyquery/core/agents/describe_agent.py | 2 +- .../core/agents/rename_detection_agent.py | 2 +- .../core/services/execution/ast_classifier.py | 4 +- .../core/services/execution/table_resolver.py | 21 ++++++++-- .../ingestion/readers/excel_reader.py | 13 +++---- .../core/services/ingestion/stages/profile.py | 4 +- .../core/services/retrieval/search_index.py | 2 +- .../web/controllers/query_controller.py | 38 +++---------------- 
tests/integration/test_hybrid_retrieval.py | 8 ++++ tests/unit/test_query_controller_sse.py | 8 +++- tests/unit/test_query_service.py | 5 ++- 13 files changed, 54 insertions(+), 65 deletions(-) diff --git a/src/flyquery/core/agents/builder.py b/src/flyquery/core/agents/builder.py index a15e048..3044bfb 100644 --- a/src/flyquery/core/agents/builder.py +++ b/src/flyquery/core/agents/builder.py @@ -48,7 +48,6 @@ def build_agent( instructions: str, settings: FlyquerySettings, max_output_tokens: int | None = None, - temperature: float | None = None, extra_settings: dict[str, Any] | None = None, ) -> Any: """Construct a :class:`FireflyAgent` with the standard knobs. @@ -68,13 +67,6 @@ def build_agent( requiring callers to plumb the env var themselves. max_output_tokens: Optional override for this specific call. ``None`` falls back to ``settings.agent_max_output_tokens``. - temperature: Optional sampling temperature folded into - ``model_settings``. ``None`` (the default) leaves the - provider default untouched -- grounding / generation / - critic rely on that diversity (generation samples N - candidates). Naming / description stages pass ``0.0`` so - near-identical re-ingests yield identical names and avoid - schema-change churn. extra_settings: Optional extra ``model_settings`` entries. Caller-provided keys WIN on conflict so a stage can cap below the global budget (e.g. a 1-token classifier). 
@@ -91,8 +83,6 @@ def build_agent( resolved_max = resolve_max_output_tokens(settings, override=max_output_tokens) model_settings: dict[str, Any] = {"max_tokens": resolved_max} - if temperature is not None: - model_settings["temperature"] = temperature if extra_settings: # Caller-provided settings win on conflict -- a stage can cap # itself below the default by passing ``max_tokens=128`` in diff --git a/src/flyquery/core/agents/column_name_proposer_agent.py b/src/flyquery/core/agents/column_name_proposer_agent.py index e317931..800e259 100644 --- a/src/flyquery/core/agents/column_name_proposer_agent.py +++ b/src/flyquery/core/agents/column_name_proposer_agent.py @@ -118,7 +118,7 @@ def build_column_name_proposer_agent(settings): settings=settings, # Deterministic naming: identical re-ingests must yield identical # column names, otherwise reconcile sees phantom schema churn. - temperature=0.0, + extra_settings={"temperature": 0.0}, ) diff --git a/src/flyquery/core/agents/describe_agent.py b/src/flyquery/core/agents/describe_agent.py index 4333e38..0adf064 100644 --- a/src/flyquery/core/agents/describe_agent.py +++ b/src/flyquery/core/agents/describe_agent.py @@ -142,5 +142,5 @@ def build_describe_agent(settings): settings=settings, # Deterministic descriptions/semantic types: identical columns on # re-ingest must produce identical metadata, no schema-change churn. - temperature=0.0, + extra_settings={"temperature": 0.0}, ) diff --git a/src/flyquery/core/agents/rename_detection_agent.py b/src/flyquery/core/agents/rename_detection_agent.py index dd951a3..b17791d 100644 --- a/src/flyquery/core/agents/rename_detection_agent.py +++ b/src/flyquery/core/agents/rename_detection_agent.py @@ -64,5 +64,5 @@ def build_rename_detection_agent(settings): max_output_tokens=2048, # Deterministic: the same removed/candidate pair must always # resolve the same way so re-ingests don't flip-flop renames. 
- temperature=0.0, + extra_settings={"temperature": 0.0}, ) diff --git a/src/flyquery/core/services/execution/ast_classifier.py b/src/flyquery/core/services/execution/ast_classifier.py index dfcd083..5402588 100644 --- a/src/flyquery/core/services/execution/ast_classifier.py +++ b/src/flyquery/core/services/execution/ast_classifier.py @@ -83,9 +83,7 @@ def classify(self, sql: str) -> AstClassification: # ``table_refs``; the downstream bad-tables guard then flags it # as a non-existent table and the (otherwise valid) query is # rejected — see QueryService bad-tables set-difference. - cte_names = { - cte.alias_or_name for cte in stmt.find_all(sqlglot.expressions.CTE) if cte.alias_or_name - } + cte_names = {cte.alias_or_name for cte in stmt.find_all(sqlglot.expressions.CTE) if cte.alias_or_name} tables = tuple( sorted( { diff --git a/src/flyquery/core/services/execution/table_resolver.py b/src/flyquery/core/services/execution/table_resolver.py index 0c18785..8650ba5 100644 --- a/src/flyquery/core/services/execution/table_resolver.py +++ b/src/flyquery/core/services/execution/table_resolver.py @@ -102,9 +102,24 @@ async def resolve( out[r["name"]] = f"{base}/{key}" return out - async def current_snapshots( - self, dataset_id: uuid.UUID, table_names: list[str] - ) -> dict[str, str]: + async def table_kinds_by_name(self, dataset_id: uuid.UUID, table_names: list[str]) -> dict[str, str]: + """Return ``{name: kind}`` for the active tables in the dataset. + + Used by the firewall/bad-tables guard. Lives here (service layer) + rather than in a controller so the raw SQL stays out of the web tier. 
+ """ + if not table_names: + return {} + rows = await self._session.execute( + sa.text(""" + SELECT name, kind FROM flyquery_tables + WHERE dataset_id = :ds AND name = ANY(:names) AND is_active = true + """), + {"ds": dataset_id, "names": list(table_names)}, + ) + return {r["name"]: r["kind"] for r in rows.mappings()} + + async def current_snapshots(self, dataset_id: uuid.UUID, table_names: list[str]) -> dict[str, str]: """Return ``{table_name: current_snapshot_id}`` for the given tables. Used to record THIS turn's snapshot pins so a later drill-down turn diff --git a/src/flyquery/core/services/ingestion/readers/excel_reader.py b/src/flyquery/core/services/ingestion/readers/excel_reader.py index c909fac..91c943c 100644 --- a/src/flyquery/core/services/ingestion/readers/excel_reader.py +++ b/src/flyquery/core/services/ingestion/readers/excel_reader.py @@ -138,14 +138,11 @@ def _as_number(cell: Any) -> float | None: return None numbers = [_as_number(c) for c in non_empty] - if any(num is None for num in numbers): - # Some populated cell is a non-numeric string -> not a spacer. - return False - - # Every populated cell is numeric and none is a text label. This covers - # both the all-numeric case and, as a strict subset, the contiguous - # ``1..k`` column-numbering run -- both are spacers, never label rows. - return True + # A spacer row has EVERY populated cell numeric and no text label. This + # covers both the all-numeric case and, as a strict subset, the contiguous + # ``1..k`` column-numbering run -- both are spacers, never label rows. Any + # non-numeric (None) populated cell means it is not a spacer. 
+ return all(num is not None for num in numbers) class ExcelReader: diff --git a/src/flyquery/core/services/ingestion/stages/profile.py b/src/flyquery/core/services/ingestion/stages/profile.py index 301c724..4938650 100644 --- a/src/flyquery/core/services/ingestion/stages/profile.py +++ b/src/flyquery/core/services/ingestion/stages/profile.py @@ -337,9 +337,7 @@ def _profile_column_sync( # to the subtotal drops detail). No measure join is available # here, so this stays purely structural/textual. if not _is_numeric(data_type) and not _is_temporal(data_type): - subtotal_values = [ - str(r[0]) for r in top_rows if _looks_like_subtotal(str(r[0])) - ] + subtotal_values = [str(r[0]) for r in top_rows if _looks_like_subtotal(str(r[0]))] if subtotal_values: profile["subtotal_values"] = subtotal_values diff --git a/src/flyquery/core/services/retrieval/search_index.py b/src/flyquery/core/services/retrieval/search_index.py index 836eed0..beb828b 100644 --- a/src/flyquery/core/services/retrieval/search_index.py +++ b/src/flyquery/core/services/retrieval/search_index.py @@ -131,7 +131,7 @@ def value_fingerprint( return ref_note.strip(" |") -def _column_hit(r, score: float) -> "Hit": +def _column_hit(r, score: float) -> Hit: """Build a ranked schema-object Hit, enriched with a value fingerprint. 
Used by the BM25 + vector column searches so the ranked "Top-ranked diff --git a/src/flyquery/web/controllers/query_controller.py b/src/flyquery/web/controllers/query_controller.py index 0ddac2b..532cfbf 100644 --- a/src/flyquery/web/controllers/query_controller.py +++ b/src/flyquery/web/controllers/query_controller.py @@ -389,9 +389,7 @@ async def explain( grounding_agent = build_grounding_agent(self._settings) grounded_run = await grounding_agent.run( - _render_grounding_prompt( - question=body.question, bundle=bundle, starting_point_sql=None - ) + _render_grounding_prompt(question=body.question, bundle=bundle, starting_point_sql=None) ) grounded = getattr(grounded_run, "output", grounded_run) @@ -458,9 +456,7 @@ async def validate( grounding_agent = build_grounding_agent(self._settings) grounded_run = await grounding_agent.run( - _render_grounding_prompt( - question=body.question, bundle=bundle, starting_point_sql=None - ) + _render_grounding_prompt(question=body.question, bundle=bundle, starting_point_sql=None) ) grounded = getattr(grounded_run, "output", grounded_run) @@ -576,9 +572,7 @@ async def _stream_events( grounding_agent = build_grounding_agent(self._settings) grounded_run = await grounding_agent.run( - _render_grounding_prompt( - question=request.question, bundle=bundle, starting_point_sql=None - ) + _render_grounding_prompt(question=request.question, bundle=bundle, starting_point_sql=None) ) grounded = getattr(grounded_run, "output", grounded_run) @@ -635,7 +629,7 @@ async def _stream_events( # isolation or run SQL against a non-existent/cross-dataset table. 
ast = self._ast_classifier.classify(chosen_sql) table_resolver = TableResolver(session=db_session, settings=self._settings) - table_kinds = await _table_kinds_by_name(db_session, list(ast.table_refs), request.dataset_id) + table_kinds = await table_resolver.table_kinds_by_name(request.dataset_id, list(ast.table_refs)) exec_result = await self._guarded_execute( chosen_sql, ast, table_kinds, table_resolver, request.dataset_id, bundle ) @@ -655,8 +649,8 @@ async def _stream_events( refined = getattr(refined_run, "output", refined_run) chosen_sql = refined.sql ast = self._ast_classifier.classify(chosen_sql) - table_kinds = await _table_kinds_by_name( - db_session, list(ast.table_refs), request.dataset_id + table_kinds = await table_resolver.table_kinds_by_name( + request.dataset_id, list(ast.table_refs) ) exec_result = await self._guarded_execute( chosen_sql, ast, table_kinds, table_resolver, request.dataset_id, bundle @@ -790,26 +784,6 @@ async def _stream_events( yield _sse_frame("final", answer.model_dump(mode="json")) -async def _table_kinds_by_name( - session: AsyncSession, table_names: list[str], dataset_id: uuid.UUID -) -> dict[str, str]: - """Return {name: kind} for the active tables in the dataset (firewall input).""" - if not table_names: - return {} - import sqlalchemy as sa - - rows = await session.execute( - sa.text( - """ - SELECT name, kind FROM flyquery_tables - WHERE dataset_id = :ds AND name = ANY(:names) AND is_active = true - """ - ), - {"ds": dataset_id, "names": list(table_names)}, - ) - return {r["name"]: r["kind"] for r in rows.mappings()} - - def _now_ms() -> int: """Return current monotonic time in milliseconds.""" import time diff --git a/tests/integration/test_hybrid_retrieval.py b/tests/integration/test_hybrid_retrieval.py index dc9abd3..54bbdb6 100644 --- a/tests/integration/test_hybrid_retrieval.py +++ b/tests/integration/test_hybrid_retrieval.py @@ -78,6 +78,14 @@ async def _seed_table_with_column( ), {"id": snap_id, "t": tenant, "ws": 
ws_id, "ds": ds_id, "tbl": tbl_id, "hash": "testhash"}, ) + # Publish: point the table at this snapshot. Retrieval is scoped to + # ``current_snapshot_id`` (so re-ingests don't return stale/duplicate + # columns), so a seeded table must have its current snapshot set -- exactly + # as the publish stage does in real ingestion. + await s.execute( + sa.text("UPDATE flyquery_tables SET current_snapshot_id = :snap WHERE id = :tbl"), + {"snap": snap_id, "tbl": tbl_id}, + ) if embedding is not None: vec = str(embedding) await s.execute( diff --git a/tests/unit/test_query_controller_sse.py b/tests/unit/test_query_controller_sse.py index 9f137c0..c786e09 100644 --- a/tests/unit/test_query_controller_sse.py +++ b/tests/unit/test_query_controller_sse.py @@ -89,7 +89,13 @@ class _FakeTableResolver: def __init__(self): self._session = _FakeSession() - async def resolve(self, dataset_id, table_names, object_store_base=None): + async def resolve(self, dataset_id, table_names, object_store_base=None, pins=None): + return {} + + async def table_kinds_by_name(self, dataset_id, table_names): + return {} + + async def current_snapshots(self, dataset_id, table_names): return {} diff --git a/tests/unit/test_query_service.py b/tests/unit/test_query_service.py index 94667aa..6cd6167 100644 --- a/tests/unit/test_query_service.py +++ b/tests/unit/test_query_service.py @@ -88,7 +88,10 @@ class _FakeTableResolver: def __init__(self): self._session = _FakeSession() - async def resolve(self, dataset_id, table_names, object_store_base=None): + async def resolve(self, dataset_id, table_names, object_store_base=None, pins=None): + return {} + + async def current_snapshots(self, dataset_id, table_names): return {} From 7847a81770a3cbccd60405868c03e1738225d949 Mon Sep 17 00:00:00 2001 From: Jose Agustin Puente Date: Tue, 2 Jun 2026 13:06:03 +0200 Subject: [PATCH 3/4] feat(ingestion): reconstruct multi-row period headers in XLSX sections Financial exports (Orbis/BvD) place one date-header row above a 
stack of sub-sections. Section-splitting orphaned that header above each sub-section's title, dropping the year columns and making the detailed financial tables (Ratios de Rentabilidad, Memo lineas, etc.) unqueryable by period. A data-first section now inherits the nearest recent period header as its column row, encoded as a backward-compatible 3-index section path ([header:data_start:end]); contiguous sections keep the legacy 2-index form. Guards against misfires: - period values are full-match only (INV-2024-0007 / Form 2020 are not years) - only a single leftmost label column may precede the inherited years - a label-headed table closes the period band (no stale-header bleed) - phantom null columns are compacted; header-row index is bounds-checked Adds unit coverage for inheritance, band-closing/no-false-inherit, path round-trip and legacy-path parsing. --- .../ingestion/readers/excel_reader.py | 189 +++++++++++++++--- .../test_excel_period_header_inheritance.py | 115 +++++++++++ 2 files changed, 271 insertions(+), 33 deletions(-) create mode 100644 tests/unit/test_excel_period_header_inheritance.py diff --git a/src/flyquery/core/services/ingestion/readers/excel_reader.py b/src/flyquery/core/services/ingestion/readers/excel_reader.py index 91c943c..d2fca8f 100644 --- a/src/flyquery/core/services/ingestion/readers/excel_reader.py +++ b/src/flyquery/core/services/ingestion/readers/excel_reader.py @@ -55,6 +55,7 @@ from __future__ import annotations import asyncio +import datetime import re import tempfile from pathlib import Path @@ -68,7 +69,18 @@ ) _NAME_SAFE_RE = re.compile(r"[^A-Za-z0-9_]+") -_SECTION_RE = re.compile(r"^(?P.+)#section\[(?P\d+):(?P\d+)\]$") +# ``#section[
<header>:<end>]`` (header row + data on the next rows, contiguous) +# OR ``<sheet>#section[<header>
::]`` when the header row is NOT +# contiguous with the data (an inherited period header -- see below). +_SECTION_RE = re.compile(r"^(?P.+)#section\[(?P\d+):(?P\d+)(?::(?P\d+))?\]$") +# A period *value* must be the WHOLE cell (a bare year, an ISO date with an +# optional time, or an FY/H/Q-prefixed year) -- NOT merely a string that +# happens to contain a 4-digit year (which would match invoice/reference codes +# like ``INV-2024-0007`` or ``Form 2020`` and wrongly flag a row as a header). +_PERIOD_RE = re.compile( + r"^(?:fy|h[12]|q[1-4])?[\s/-]?(?:19|20)\d{2}(?:-\d{2}-\d{2})?(?:[ t]\d{2}:\d{2}(?::\d{2})?)?$", + re.IGNORECASE, +) # Heuristic constants _MIN_HEADER_NON_EMPTY = 2 # a row needs >= 2 non-empty cells to be a header @@ -76,6 +88,9 @@ # When the first header candidate looks like a numeric/spacer row, look this # many rows ahead for a clearly-better string-label row before settling. _HEADER_LOOKAHEAD = 3 +# Max rows a data-first section may look BACK to inherit a period/date header +# (financial reports repeat a date header above each block of sub-sections). +_PERIOD_HEADER_MAX_DISTANCE = 60 def _normalise_leading_blanks(rows: list[list[Any]]) -> list[list[Any]]: @@ -145,6 +160,37 @@ def _as_number(cell: Any) -> float | None: return all(num is not None for num in numbers) +def _is_period_value(cell: Any) -> bool: + """True when a cell IS an accounting period / date / year label. + + Full-match (not substring) so reference/invoice codes that merely embed a + year (``INV-2024-0007``, ``Form 2020``) are not misread as period headers. + """ + if isinstance(cell, (datetime.date, datetime.datetime)): + return True + if isinstance(cell, str): + return bool(_PERIOD_RE.match(cell.strip())) + return False + + +def _is_period_header_row(row: list[Any]) -> bool: + """True when a row is a period/date header (the bulk of its cells are years/dates). 
+ + Financial-report exports place a single date header (``2024-12-31 ...``) + above a block of sub-sections; we use this to let a following data-first + section inherit those column labels instead of guessing names. + """ + non_empty = [c for c in row if c not in ("", None)] + if len(non_empty) < _MIN_HEADER_NON_EMPTY: + return False + period = sum(1 for c in non_empty if _is_period_value(c)) + return period >= 2 and period * 5 >= len(non_empty) * 3 # >= 60% period-like + + +def _populated_cols(row: list[Any]) -> set[int]: + return {c for c, v in enumerate(row) if v not in ("", None)} + + class ExcelReader: formats = ("xlsx", "xls", "ods") @@ -200,6 +246,7 @@ def _extract_sections(rows: list[list[Any]]) -> list[dict[str, Any]]: sections: list[dict[str, Any]] = [] i = 0 pending_label: str | None = None + last_period_header_idx: int | None = None n = len(rows) while i < n: row = rows[i] @@ -229,7 +276,7 @@ def _extract_sections(rows: list[list[Any]]) -> list[dict[str, Any]]: # a better row exists -- genuinely numeric / period / date headers # (no string-label row just below) are left untouched, as are # single-row sheets. - header_idx = i + own_header_idx = i if _is_numeric_spacer_row(row) and not _looks_like_label_row(row): look_end = min(i + 1 + _HEADER_LOOKAHEAD, n) for la in range(i + 1, look_end): @@ -241,9 +288,45 @@ def _extract_sections(rows: list[list[Any]]) -> list[dict[str, Any]]: if len(la_ne) >= _MIN_HEADER_NON_EMPTY and _looks_like_label_row(rows[la]): # Found a better string-label header just below; treat # the skipped numeric/title rows as pre-header. - header_idx = la + own_header_idx = la break - j = header_idx + 1 + + # Period-header inheritance. Financial-report exports (Orbis/BvD, + # etc.) place ONE date header (``2024-12-31 2023-12-31 ...``) above + # a block of sub-sections (P&L, ratios, ...), each introduced by its + # own title. 
Section-splitting starts each sub-section at its first + # DATA row, orphaning that shared header above the title -- so the + # value columns get opaque/guessed names instead of the years. When + # a section starts directly with data (its own header row is neither + # label-like nor a period header) and a recent period header covers + # its value columns, adopt that period header as this section's + # column header and treat the section's own first row as data. + own_is_period = _is_period_header_row(rows[own_header_idx]) + own_is_label = _looks_like_label_row(rows[own_header_idx]) + header_row_idx = own_header_idx + data_start = own_header_idx + 1 + inherited = False + if ( + last_period_header_idx is not None + and not own_is_label + and not own_is_period + and own_header_idx - last_period_header_idx <= _PERIOD_HEADER_MAX_DISTANCE + ): + sec_cols = _populated_cols(rows[own_header_idx]) + ph_cols = _populated_cols(rows[last_period_header_idx]) + extra = sec_cols - ph_cols + # The period header must cover the section's value columns. The + # ONLY column it may legitimately not cover is a row-label column + # to the LEFT of the period columns -- never a trailing value + # column (that would shift the inherited year labels by one). + if len(ph_cols & sec_cols) >= _MIN_HEADER_NON_EMPTY and ( + not extra or (len(extra) == 1 and min(extra) < min(ph_cols)) + ): + header_row_idx = last_period_header_idx + data_start = own_header_idx + inherited = True + + j = data_start consecutive_empty = 0 data_end = j # exclusive while j < n: @@ -262,22 +345,30 @@ def _extract_sections(rows: list[list[Any]]) -> list[dict[str, Any]]: data_end = j + 1 j += 1 - n_data_rows = data_end - (header_idx + 1) + # Maintain the active period header. 
A genuine period/date header + # opens (or renews) a band that the following data-first sub-sections + # inherit; a real label-headed table CLOSES the band so a stale date + # header can't bleed into an unrelated (positionally-overlapping) + # table further down. + if own_is_period: + last_period_header_idx = own_header_idx + elif own_is_label: + last_period_header_idx = None + + n_data_rows = data_end - data_start if n_data_rows >= 1: - # Compute the union of populated column indices across - # the entire section. This is the real column count after - # compaction (we drop empty/merged-padding columns at - # materialise time). - populated_cols: set[int] = set() - for k in range(header_idx, data_end): - for col_idx, cell in enumerate(rows[k]): - if cell not in ("", None): - populated_cols.add(col_idx) + # Populated columns. For a contiguous section the header row is + # part of the table, so include it. For an INHERITED header we + # count only DATA columns, so a period the sub-section does not + # report does not become an all-NULL column. 
+ populated_cols: set[int] = set() if inherited else set(_populated_cols(rows[header_row_idx])) + for k in range(data_start, data_end): + populated_cols |= _populated_cols(rows[k]) sections.append( { "label": pending_label or f"section_{len(sections):02d}", - "header_row_idx": header_idx, - "data_start_idx": header_idx + 1, + "header_row_idx": header_row_idx, + "data_start_idx": data_start, "data_end_idx": data_end, "n_cols": len(populated_cols), "n_data_rows": n_data_rows, @@ -328,9 +419,7 @@ def _enumerate_sync(source_path: str, rules: TableExtractionRules) -> list[Propo out.append( ProposedTable( name=ExcelReader._sanitise(sheet_name), - sheet_or_json_path=( - f"{sheet_name}#section[{s['header_row_idx']}:{s['data_end_idx']}]" - ), + sheet_or_json_path=ExcelReader._section_path(sheet_name, s), n_columns=s["n_cols"], n_rows_estimate=s["n_data_rows"], ) @@ -350,9 +439,7 @@ def _enumerate_sync(source_path: str, rules: TableExtractionRules) -> list[Propo out.append( ProposedTable( name=final_name, - sheet_or_json_path=( - f"{sheet_name}#section[{s['header_row_idx']}:{s['data_end_idx']}]" - ), + sheet_or_json_path=ExcelReader._section_path(sheet_name, s), n_columns=s["n_cols"], n_rows_estimate=s["n_data_rows"], ) @@ -360,16 +447,30 @@ def _enumerate_sync(source_path: str, rules: TableExtractionRules) -> list[Propo return out @staticmethod - def _parse_section_path(path: str) -> tuple[str, int | None, int | None]: - """Split ``#section[:]`` -> (sheet, start, end). + def _section_path(sheet_name: str, s: dict[str, Any]) -> str: + """Encode a section span. ``[h:e]`` when header+data are contiguous, + ``[h:ds:e]`` when the header row is inherited (not adjacent to data).""" + h, ds, e = s["header_row_idx"], s["data_start_idx"], s["data_end_idx"] + return f"{sheet_name}#section[{h}:{e}]" if ds == h + 1 else f"{sheet_name}#section[{h}:{ds}:{e}]" - For backward compat, a plain sheet name (no ``#section[...]``) - returns ``(sheet, None, None)``. 
+ @staticmethod + def _parse_section_path(path: str) -> tuple[str, int | None, int | None, int | None]: + """Split a section path -> ``(sheet, header_idx, data_start, data_end)``. + + Accepts both ``<sheet>#section[<header>:<end>]`` (contiguous; data starts + at ``header+1``) and ``<sheet>#section[<header>
::]`` (an + inherited period header that is NOT adjacent to its data). For backward + compat, a plain sheet name (no ``#section[...]``) returns all ``None``. """ m = _SECTION_RE.match(path or "") if not m: - return path or "", None, None - return m.group("sheet"), int(m.group("start")), int(m.group("end")) + return path or "", None, None, None + header_idx = int(m.group("a")) + if m.group("c") is not None: + data_start, data_end = int(m.group("b")), int(m.group("c")) + else: + data_start, data_end = header_idx + 1, int(m.group("b")) + return m.group("sheet"), header_idx, data_start, data_end @staticmethod def _materialise_sync( @@ -388,7 +489,7 @@ def _materialise_sync( from python_calamine import CalamineWorkbook # pyright: ignore[reportMissingImports] Path(target_parquet_key).parent.mkdir(parents=True, exist_ok=True) - sheet_name, sec_start, sec_end = ExcelReader._parse_section_path( + sheet_name, header_idx, data_start, data_end = ExcelReader._parse_section_path( table.sheet_or_json_path or table.name ) wb = CalamineWorkbook.from_path(source_path) @@ -397,9 +498,27 @@ def _materialise_sync( # the stored section indices slice the intended rows (HDR-UNSTABLE). rows = _normalise_leading_blanks(sheet.to_python(skip_empty_area=True)) - if sec_start is not None and sec_end is not None: - # Section-encoded path -- slice precisely. - section_rows = rows[sec_start:sec_end] + inherited_header = False + if header_idx is not None and data_end is not None: + # Section-encoded path. The header row + the data rows, which may be + # NON-contiguous when the header was inherited from a period header + # above the section's title (financial-report layout). For the + # common contiguous case (data_start == header_idx + 1) this is + # exactly ``rows[header_idx:data_end]``. + # + # ``rows[header_idx]`` is a scalar index, so guard it: if the sheet + # changed between enumerate and materialise (re-ingest drift -- the + # raison d'être of HDR-UNSTABLE) the index can fall past the end. 
The + # old slice degraded silently; we raise a contextual ValueError + # instead of an opaque IndexError. + if header_idx >= len(rows): + raise ValueError( + f"sheet {sheet_name!r}: header row {header_idx} out of range " + f"(rows={len(rows)}) for table {table.name!r} " + f"(path={table.sheet_or_json_path!r}); sheet changed since enumerate?" + ) + inherited_header = data_start != header_idx + 1 + section_rows = [rows[header_idx]] + rows[data_start:data_end] else: # Legacy / no-section path -- apply the merged-cell title heuristic. body_start = 0 @@ -429,8 +548,12 @@ def _materialise_sync( # Compacting to ONLY the populated indices yields rows that # match the visual "5 yearly columns + label" view a human # sees in Excel. + # For an inherited (non-contiguous) header, compact over the DATA rows + # only -- a period the sub-section doesn't report must not survive as an + # all-NULL year column just because the shared header names it. + compact_rows = section_rows[1:] if inherited_header and len(section_rows) > 1 else section_rows populated: set[int] = set() - for r in section_rows: + for r in compact_rows: for col_idx, cell in enumerate(r): if cell not in ("", None): populated.add(col_idx) diff --git a/tests/unit/test_excel_period_header_inheritance.py b/tests/unit/test_excel_period_header_inheritance.py new file mode 100644 index 0000000..6e7fc34 --- /dev/null +++ b/tests/unit/test_excel_period_header_inheritance.py @@ -0,0 +1,115 @@ +# Copyright 2024-2026 Firefly Software Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for period-header inheritance in the XLSX section extractor. + +Financial-report exports (Orbis/BvD) place ONE date header above a block of +sub-sections; section-splitting orphans that header above each sub-section's +title. These tests pin the inheritance behavior + its guards (no false +inherit, leftmost-label-only, legacy path round-trip). +""" + +from __future__ import annotations + +import datetime + +from flyquery.core.services.ingestion.readers.excel_reader import ( + ExcelReader, + _is_period_header_row, + _is_period_value, +) + + +def test_is_period_value_full_match_rejects_embedded_year_codes() -> None: + for ok in ("2024", "2024-12-31", "FY2024", "Q1 2024", "H1 2023"): + assert _is_period_value(ok), ok + assert _is_period_value(datetime.datetime(2024, 12, 31)) + # Codes / addresses that merely *contain* a year must NOT be period values. + for bad in ("INV-2024-0007", "Form 2020", "2001 Main St", "REG-2019-X", "86.21", "Andalucia"): + assert not _is_period_value(bad), bad + + +def test_is_period_header_row_requires_majority_period_cells() -> None: + assert _is_period_header_row(["", "2024", "2023", "2022"]) + assert _is_period_header_row([datetime.date(2024, 12, 31), datetime.date(2023, 12, 31)]) + # A single year among labels is not a header. + assert not _is_period_header_row(["Founded 2024", "CEO", "Revenue"]) + + +def _orbis_like_rows() -> list[list[object]]: + """A miniature Orbis-style sheet: + + one date header governs a data-first sub-section introduced by its own + title; then a label-headed table closes the band; then another data-first + section that must NOT inherit the (now stale) date header. 
+ """ + return [ + ["Financial data", "", "", ""], # 0 title + ["", "2024", "2023", "2022"], # 1 period header (cols 1-3) + ["Profit & Loss", "", "", ""], # 2 title + ["Revenue", 100, 90, 80], # 3 data-first -> inherits row 1 + ["Costs", 40, 30, 20], # 4 data + ["", "", "", ""], # 5 blank + ["", "", "", ""], # 6 blank (section break) + ["Board", "", "", ""], # 7 title + ["Name", "Role", "", ""], # 8 label header (closes the period band) + ["Alice", "CEO", "", ""], # 9 data + ["", "", "", ""], # 10 blank + ["", "", "", ""], # 11 blank + ["Extra metrics", "", "", ""], # 12 title + ["Metric A", 1, 2, 3], # 13 data-first, columns OVERLAP the date header + ["Metric B", 4, 5, 6], # 14 data + ] + + +def test_data_first_section_inherits_period_header() -> None: + secs = ExcelReader._extract_sections(_orbis_like_rows()) + pnl = next(s for s in secs if s["label"] == "Profit & Loss") + # Header is the date row (1), data is non-contiguous (starts at 3). + assert pnl["header_row_idx"] == 1 + assert pnl["data_start_idx"] == 3 + assert pnl["data_end_idx"] == 5 + + +def test_label_headed_table_does_not_inherit_and_closes_the_band() -> None: + secs = ExcelReader._extract_sections(_orbis_like_rows()) + board = next(s for s in secs if s["label"] == "Board") + # Board has its OWN label header -> contiguous, no inheritance. + assert board["data_start_idx"] == board["header_row_idx"] + 1 + + # The later "Extra metrics" section shares column positions with the date + # header, but the band was CLOSED by the Board table -> it must NOT inherit. + extra = next(s for s in secs if s["label"] == "Extra metrics") + assert extra["data_start_idx"] == extra["header_row_idx"] + 1 + + +def test_section_path_round_trip_contiguous_and_inherited() -> None: + # Contiguous -> 2-index form (byte-identical to legacy). 
+ s_contig = {"header_row_idx": 5, "data_start_idx": 6, "data_end_idx": 9} + p = ExcelReader._section_path("Sheet1", s_contig) + assert p == "Sheet1#section[5:9]" + assert ExcelReader._parse_section_path(p) == ("Sheet1", 5, 6, 9) + + # Inherited (non-contiguous) -> 3-index form. + s_inh = {"header_row_idx": 1, "data_start_idx": 3, "data_end_idx": 5} + p2 = ExcelReader._section_path("Sheet1", s_inh) + assert p2 == "Sheet1#section[1:3:5]" + assert ExcelReader._parse_section_path(p2) == ("Sheet1", 1, 3, 5) + + +def test_parse_section_path_legacy_and_plain() -> None: + # Legacy 2-index path still parses (already-stored tables). + assert ExcelReader._parse_section_path("S#section[2:7]") == ("S", 2, 3, 7) + # Plain sheet name -> all None. + assert ExcelReader._parse_section_path("JustASheet") == ("JustASheet", None, None, None) From c135bcc35a5340025ebac05022d34a7163e1099c Mon Sep 17 00:00:00 2001 From: Jose Agustin Puente Date: Tue, 30 Jun 2026 13:08:07 +0200 Subject: [PATCH 4/4] =?UTF-8?q?feat(query+ingest):=20unify=20value-anchori?= =?UTF-8?q?ng=20superset=20onto=20the=20NL=E2=86=92SQL=20branch?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reconciles the two parallel lines of the NL→SQL quality work into a single branch that carries every refinement from both, with no capability lost. Kept from this branch (execution-layer robustness): - TableResolver / AstClassifier / profile hardening and the query_controller surface, plus their unit/integration tests. - The XLSX multi-row period-header reconstruction (embedded-year rejection, band-closing, 3-index section paths) and its test. - PII-safe value rendering in the embed corpus. Brought in (the value-anchoring superset): - value_anchoring.py: the dataset-agnostic backbone (per-column value catalogue, value→column + group resolution, entity disambiguation by live row-match count + hierarchy intent, degenerate/zero-row detection, signed (mixed-sign) measure handling incl. 
the SUM(CASE WHEN cost THEN -m) shape, CTE-aware guards, function firewall). - query_service orchestration: candidate execution-selection, a non-destructive repair loop (never replaces a correct answer with a worse one), zero-row and signed-measure repair, group-coverage advisory. - Deterministic period/date column naming at ingest (year_YYYY / period_YYYY_MM_DD), original-header preservation, and the value-anchoring prompt rules for grounding / generation / critic / explainer. - LexicalReranker as the dependency-free default (a strict win over the no-op). - temperature=0 for the deterministic ingest agents (describe / column-name / rename / relation), applied via the existing extra_settings hook. Cross-line graft: the column hierarchy hint (references_column, computed by the profile stage) is now surfaced in the retrieval column fingerprint, so a self-referencing manager/owner column tells the agent to filter THAT column to find a person's reports instead of their own row. Gate notes: - builder.py stays byte-identical to canon (lockstep) -- temperature=0 rides the existing extra_settings path, no new parameter. - ruff check + format clean; lockstep 21/21. - test_query_service: the _FakeSettings stub gains the value-anchoring knobs (features off, so the orchestration tests stay focused on the core answer() flow), and the retry-on-error test now uses a realistic non-degenerate refined SQL with a table-aware resolver so the critic round is actually exercised. 
--- src/flyquery/config.py | 17 + .../core/agents/relation_proposer_agent.py | 2 + .../core/services/ingestion/reader.py | 5 + .../core/services/ingestion/stages/parse.py | 67 + .../services/ingestion/stages/reconcile.py | 48 +- .../core/services/query/query_service.py | 1168 +++++++++++------ .../core/services/query/value_anchoring.py | 795 +++++++++++ .../services/retrieval/hybrid_retriever.py | 6 + .../core/services/retrieval/reranker.py | 67 +- .../core/services/retrieval/search_index.py | 276 ++-- src/flyquery/resources/prompts/critic.yaml | 30 +- src/flyquery/resources/prompts/explainer.yaml | 14 +- .../resources/prompts/generation.yaml | 34 + src/flyquery/resources/prompts/grounding.yaml | 29 +- tests/integration/test_reranker.py | 16 +- tests/unit/test_query_service.py | 63 +- 16 files changed, 2044 insertions(+), 593 deletions(-) create mode 100644 src/flyquery/core/services/query/value_anchoring.py diff --git a/src/flyquery/config.py b/src/flyquery/config.py index e8b27ad..ab3cc00 100644 --- a/src/flyquery/config.py +++ b/src/flyquery/config.py @@ -181,6 +181,23 @@ class FlyquerySettings(BaseSettings): reranker_top_n: int = 30 query_expansion_enabled: bool = False + # Value-anchoring / grounding-quality knobs. The ingest pipeline already + # computes a per-column value catalogue (profile_json.top_values, min/max, + # distinct_estimate) + a semantic_type; these control how it is surfaced to + # the SQL writer at query time. All dataset-agnostic. 
+ value_catalog_enabled: bool = True # render real column values into prompts + value_catalog_max_values: int = 25 # distinct values shown per column + value_catalog_char_budget: int = 320 # char cap on the value list per column + value_catalog_max_columns: int = 80 # cap columns that get a value line (prompt budget) + entity_resolution_enabled: bool = True # map question literals -> owning column + entity_resolution_max_literals: int = 8 # cap live value-scan probes per query + zero_row_repair_enabled: bool = True # repair queries that run but return 0 rows + candidate_exec_selection: bool = True # execute top candidates, pick a non-empty/non-degenerate one + synthesis_function_firewall: bool = True # block read_csv_auto/pg_read_file/... in generated SQL + group_resolution_enabled: bool = True # term -> full set of catalogued values it umbrellas + signed_measure_repair_enabled: bool = True # observed-sign probe on subtractions over a signed measure + group_coverage_repair_enabled: bool = True # advise when an IN-list under-covers a detected value group + # PII pii_scanner: Literal["regex", "presidio", "disabled"] = "regex" pii_policy_samples: Literal["warn", "redact", "reject"] = "redact" diff --git a/src/flyquery/core/agents/relation_proposer_agent.py b/src/flyquery/core/agents/relation_proposer_agent.py index 769df77..7405c03 100644 --- a/src/flyquery/core/agents/relation_proposer_agent.py +++ b/src/flyquery/core/agents/relation_proposer_agent.py @@ -56,4 +56,6 @@ def build_relation_proposer_agent(settings): output_type=ProposedRelations, instructions=prompt.instructions, settings=settings, + # Deterministic relation proposals across re-ingests. 
+ extra_settings={"temperature": 0.0}, ) diff --git a/src/flyquery/core/services/ingestion/reader.py b/src/flyquery/core/services/ingestion/reader.py index f631d0b..0f213dd 100644 --- a/src/flyquery/core/services/ingestion/reader.py +++ b/src/flyquery/core/services/ingestion/reader.py @@ -51,6 +51,11 @@ class ColumnSchema: data_type: str is_nullable: bool position: int + # The source's ORIGINAL header before any rename (e.g. an Excel year header + # '2024' that the column-name proposer collapsed to ``year_1``). Preserved so + # the query layer can recover what a renamed column actually meant, instead of + # relying on a fixed ordinal convention. None when the name was not renamed. + original_name: str | None = None @dataclass(frozen=True) diff --git a/src/flyquery/core/services/ingestion/stages/parse.py b/src/flyquery/core/services/ingestion/stages/parse.py index fea6eba..362b51c 100644 --- a/src/flyquery/core/services/ingestion/stages/parse.py +++ b/src/flyquery/core/services/ingestion/stages/parse.py @@ -56,6 +56,52 @@ def _sanitise_name(name: str) -> str: return _SAFE_NAME_RE.sub("_", name).strip("_") or "table" +def _period_name_from_header(header: str | None) -> str | None: + """Deterministic column name for a header that IS a calendar period. + + ``'2024'`` / ``'FY2024'`` / ``'2024f'`` -> ``year_2024``; + ``'2024-12-31'`` / ``'31/12/2024'`` -> ``period_2024_12_31``. + Returns None for anything that is not clearly a year/date header (e.g. + ``'Operating revenue'``), so only genuine period columns are stabilised. + Generic — recognises the period from the source's own header, no domain rule. 
+ """ + if not header: + return None + s = str(header).strip() + if not s or len(s) > 40: + return None + for pat, idx in ( + (r"^(\d{4})[-/.](\d{1,2})[-/.](\d{1,2})$", (1, 2, 3)), + (r"^(\d{1,2})[-/.](\d{1,2})[-/.](\d{4})$", (3, 2, 1)), + ): + m = re.match(pat, s) + if m: + y, mo, d = int(m.group(idx[0])), int(m.group(idx[1])), int(m.group(idx[2])) + if 1900 <= y <= 2100 and 1 <= mo <= 12 and 1 <= d <= 31: + return f"period_{y}_{mo:02d}_{d:02d}" + m = re.match(r"(?i)^(?:fy[\s_]*)?(\d{4})f?$", s) + if m: + y = int(m.group(1)) + if 1900 <= y <= 2100: + return f"year_{y}" + return None + + +def _apply_deterministic_period_names(originals: list[str], proposed: list[str]) -> list[str]: + """Override the LLM proposer for columns whose ORIGINAL header is a year/date, + so period columns are named stably + correctly across re-ingests (round-2).""" + out = list(proposed) + changed = False + for i, orig in enumerate(originals): + nm = _period_name_from_header(orig) + if nm and out[i] != nm: + out[i] = nm + changed = True + if changed: + out = _dedupe_names(out, fallback_prefix="period") + return out + + def _sanitise_proposed_name(name: str) -> str: """Best-effort enforcement of snake_case identifier rules on agent output.""" cleaned = _SAFE_NAME_RE.sub("_", name).strip("_").lower() @@ -161,6 +207,22 @@ async def _propose_meaningful_column_names( at least scoped to the section. """ current_names = [c.name for c in mat_result.columns] + + # Deterministically clean year/date headers FIRST -- before the + # needs_proposal gate, which skips columns that already look "meaningful" + # (a date column does). This turns '2024-12-31' -> period_2024_12_31 and + # '2024' -> year_2024 stably across re-ingests, and renames the Parquet so + # the stored column matches. 
+ period_named = _apply_deterministic_period_names(current_names, list(current_names)) + if period_named != current_names: + await _rename_parquet_columns( + parquet_path=parquet_path, + current_columns=current_names, + proposed_columns=period_named, + ) + mat_result = _rebuild_mat_result(mat_result, period_named) + current_names = period_named + if not needs_proposal(current_names): return mat_result @@ -224,6 +286,9 @@ async def _propose_meaningful_column_names( ) proposed = fallback + # Safety net: clean any remaining year/date names the proposer produced. + proposed = _apply_deterministic_period_names(current_names, proposed) + if proposed == current_names: return mat_result @@ -256,6 +321,8 @@ def _rebuild_mat_result( data_type=c.data_type, is_nullable=c.is_nullable, position=c.position, + # remember the original header iff it actually changed + original_name=(c.name if new_names[i] != c.name else None), ) for i, c in enumerate(mat_result.columns) ), diff --git a/src/flyquery/core/services/ingestion/stages/reconcile.py b/src/flyquery/core/services/ingestion/stages/reconcile.py index 54dd9a3..def9e4f 100644 --- a/src/flyquery/core/services/ingestion/stages/reconcile.py +++ b/src/flyquery/core/services/ingestion/stages/reconcile.py @@ -27,6 +27,7 @@ import hashlib import json import logging +import re import uuid from dataclasses import dataclass from typing import Any, cast @@ -40,6 +41,47 @@ normalize_synonyms_json, ) +_SYNTHETIC_HEADER = re.compile(r"(?i)^(column|col|unnamed|field|sheet)[\s_:.\-]*\d*$") + + +def _plausible_original_header(name: str | None) -> str | None: + """Return a real source header worth preserving, else None. + + Keeps meaningful headers (incl. plain years like '2024') but drops synthetic + placeholders (column5, Unnamed: 0) and obvious null tokens, so the surfaced + 'original header' is signal, not noise. Fully generic. 
+ """ + if not name: + return None + s = str(name).strip() + if not s or len(s) > 60: + return None + if _SYNTHETIC_HEADER.match(s): + return None + if s.lower() in {"none", "null", "nan", "n/a", "na", "-"}: + return None + # A header that is just a number is almost certainly a mis-detected data cell; + # skip it -- UNLESS it is a plausible 4-digit year (a real period header). + try: + float(s.replace(",", "")) + if not (s.isdigit() and len(s) == 4 and 1900 <= int(s) <= 2100): + return None + except ValueError: + pass # non-numeric header -> keep + return s + + +def _governance_with_header(annotation: dict, col: Any) -> str: + """Serialise governance_json, merging the column's plausible original header.""" + gov = normalize_governance_json(annotation.get("governance_json")) + if not isinstance(gov, dict): + gov = {} + oh = _plausible_original_header(getattr(col, "original_name", None)) + if oh and not gov.get("original_header"): + gov = {**gov, "original_header": oh} + return json.dumps(gov) + + # Rename auto-confirm threshold (passed from settings when available) _DEFAULT_AUTO_CONFIRM_THRESHOLD = 0.8 @@ -281,9 +323,9 @@ async def run_reconcile( "pii_tag": annotation.get("pii_tag"), "pii_source": annotation.get("pii_source"), "business_owner": annotation.get("business_owner"), - "governance_json": json.dumps( - normalize_governance_json(annotation.get("governance_json")) - ), + # round-2 #3: preserve the source's original header (pre-rename) + # so the query layer knows what e.g. year_1 actually was. 
+ "governance_json": _governance_with_header(annotation, col), }, ) diff --git a/src/flyquery/core/services/query/query_service.py b/src/flyquery/core/services/query/query_service.py index ffbd8b4..99f5e6c 100644 --- a/src/flyquery/core/services/query/query_service.py +++ b/src/flyquery/core/services/query/query_service.py @@ -44,6 +44,7 @@ from flyquery.core.services.execution.ast_classifier import AstClassifier from flyquery.core.services.execution.duckdb_executor import ExecutionError, ExecutionResult from flyquery.core.services.execution.scope_guard import ScopeGuard, ScopeGuardError +from flyquery.core.services.query import value_anchoring from flyquery.core.services.semantic.compiler import SemanticCompiler logger = logging.getLogger(__name__) @@ -128,33 +129,6 @@ def _render_grounding_prompt( out.append(f"- `{qn}` :: {text}") out.append("") - # 2b. The "Column value catalogue" lists EVERY column with its real - # values (distinct set for low-cardinality columns; numeric - # range otherwise). This is the ground truth the agent must copy - # filter/CASE literals from -- it prevents guessing wrong - # literals (`Year IN (2023)` when the values are `FY23`), maps a - # question entity to the column whose values contain it - # (`DAPA` lives in a column's values, not a column name), and - # reveals tall/EAV layouts (P&L line items are VALUES of a single - # column) and scaled-duplicate measures (`FY` vs `FY (Real)`). - inv_columns = [h for h in inventory if (getattr(h, "metadata", {}) or {}).get("kind") == "COLUMN"] - cols_with_values = [h for h in inv_columns if (getattr(h, "metadata", {}) or {}).get("values")] - if cols_with_values: - out.append(f"# Column value catalogue ({len(cols_with_values)} columns)") - out.append( - "Real values per column. When the question names an entity (a brand, " - "year, market, category, P&L line, team…) that is NOT a column name, " - "find the column whose values contain it and filter THAT column. 
Copy " - "filter/CASE literals VERBATIM from these values (values may be encoded, " - "e.g. a year shown as `FY23`). If several columns share members, prefer " - "the one whose values match the question most precisely." - ) - for h in cols_with_values: - md = getattr(h, "metadata", None) or {} - qn = md.get("qualified_name") or "?" - out.append(f"- `{qn}` :: {md.get('values')}") - out.append("") - examples = bundle.get("examples", []) or [] if examples: out.append(f"# Approved Q→SQL examples ({len(examples)})") @@ -209,23 +183,47 @@ def _render_grounding_prompt( return "\n".join(out) +def _lookup_col(qn: str, cat_by_qn: dict) -> dict | None: + if qn in cat_by_qn: + return cat_by_qn[qn] + # tolerate grounding returning a shorter/unqualified column name + tail = qn.rsplit(".", 1)[-1].lower() + for k, v in cat_by_qn.items(): + if k.rsplit(".", 1)[-1].lower() == tail: + return v + return None + + def _render_generation_prompt( question: str, grounded: Any, starting_point_sql: str | None, *, schema_inventory: list[Any] | None = None, + col_catalog: list[dict] | None = None, + resolved_entities: list[dict] | None = None, + resolved_groups: list[dict] | None = None, + hierarchy_intent: bool = False, + examples: list[Any] | None = None, + n_candidates: int = 3, + max_columns: int = 80, + char_budget: int = 320, ) -> str: - """Pack the grounded context into a SQL-generation prompt. - - The ``schema_inventory`` fallback is appended even when grounding - returned a non-empty tables list -- the inventory acts as a - safety net the generation agent can fall back on if it judges - the grounded set incomplete (e.g. needs a join to a table the - grounding agent missed). Without this fallback, generation - hallucinates plausible-but-nonexistent tables like - ``balance_sheet`` whenever grounding under-selects. + """Pack the grounded context + the per-column VALUE CATALOGUE into a prompt. 
+ + The generation prompt now shows, for each in-scope column, its data_type, + semantic role (measure/dimension/time), and the actual distinct VALUES (or + numeric range) the column holds -- so the SQL writer copies real filter + literals instead of inventing them. Question entities that were located in + the data are listed explicitly. All sourced from the dataset's own + ingest-time profiling; nothing is hardcoded. """ + from collections import defaultdict + + cat_by_qn = {c["qualified_name"]: c for c in (col_catalog or [])} + cols_by_table: dict[str, list[dict]] = defaultdict(list) + for c in col_catalog or []: + cols_by_table[c["qualified_name"].rsplit(".", 1)[0]].append(c) out: list[str] = [] out.append("# User question") @@ -239,15 +237,62 @@ def _render_generation_prompt( out.append("```") out.append("") + if resolved_entities: + out.append("# Resolved entities (literals located in the data)") + out.append( + "These exact question literals were FOUND as stored values. Filter on " + "the stated column with the stated value verbatim. When a literal was " + "found in MORE THAN ONE column, the number of rows it matches is shown:" + ) + for e in resolved_entities[:24]: + cnt = e.get("match_count") + tag = "" + if cnt is not None: + role = "repeats → grouping/parent key" if cnt > 1 else "appears once → identity" + tag = f" — matches {cnt} row(s), {role}" + out.append(f"- '{e['literal']}' → column `{e['column']}` (stored value: {e['value']!r}){tag}") + if hierarchy_intent: + out.append( + "NOTE: this is a hierarchy / team / reporting question. Filter the column where the " + "entity REPEATS across rows (the parent/manager key), NOT where it appears once (its " + "own identity row). Prefer the table with the MOST matching rows over a stale one." 
+ ) + top = max( + (e for e in resolved_entities if (e.get("match_count") or 0) > 1), + key=lambda e: e.get("match_count") or 0, + default=None, + ) + if top is not None: + out.append( + f" >> Best binding: filter on `{top['column']}` (the entity matches " + f"{top['match_count']} rows there) — query THAT column in THAT table; ignore " + f"tables/columns where the name appears only 0–1 times." + ) + out.append("") + + if resolved_groups: + out.append("# Resolved value groups (a term that spans several stored values)") + out.append( + "Each question term below is an umbrella over MULTIPLE values of one column — it " + "means the WHOLE set, never a subset:" + ) + for g in resolved_groups[:12]: + col = g["column"].rsplit(".", 1)[-1] + vlist = ", ".join(repr(v) for v in g["values"]) + if g.get("truncated"): + out.append( + f"- '{g['literal']}' spans column `{col}` (the catalogued list is INCOMPLETE) — " + f"use `{col} LIKE '{g['literal']}%'` to capture every member." + ) + else: + out.append( + f"- '{g['literal']}' spans these values in `{col}`: [{vlist}] — " + f"use IN(all of them) or `{col} LIKE '{g['literal']}%'`; never a subset." + ) + out.append("") + inv = schema_inventory or [] inv_tables = [h for h in inv if (getattr(h, "metadata", {}) or {}).get("kind") == "TABLE"] - # Map qualified_name -> value fingerprint so the generator copies - # filter/CASE literals verbatim from real values rather than guessing. 
- value_index: dict[str, str] = {} - for h in inv: - md = getattr(h, "metadata", None) or {} - if md.get("kind") == "COLUMN" and md.get("values"): - value_index[md.get("qualified_name")] = md.get("values") if inv_tables: out.append(f"# Complete dataset catalogue ({len(inv_tables)} tables)") out.append( @@ -276,13 +321,54 @@ def _render_generation_prompt( for t in g_tables: out.append(f"- `{getattr(t, 'table_qualified_name', t)}`") out.append("") - if g_columns: - out.append("## Columns in scope (with real values — copy literals verbatim)") + + # Full, value-anchored column listing for every in-scope table. This both + # raises recall (the right column may have ranked below top-12) and gives + # the SQL writer the real literal vocabulary for each column. + rendered = 0 + table_qns = [getattr(t, "table_qualified_name", str(t)) for t in g_tables] + if not table_qns: + # No grounded tables: fall back to every table that has columns. + table_qns = list(cols_by_table.keys()) + if table_qns: + out.append("## Columns with values (use these literals verbatim)") + for tqn in table_qns: + cols = cols_by_table.get(tqn) or [] + if not cols: + # match by table-name tail when grounding used a short qn + tail = tqn.rsplit(".", 1)[-1].lower() + for k, v in cols_by_table.items(): + if k.rsplit(".", 1)[-1].lower() == tail: + cols = v + break + if not cols: + continue + out.append(f"### `{tqn}`") + for c in cols: + if rendered >= max_columns: + out.append("- … (column list truncated)") + break + out.append( + "- " + + value_anchoring.render_catalog_from_meta( + c["qualified_name"], c, char_budget=char_budget + ) + ) + rendered += 1 + out.append("") + if rendered >= max_columns: + break + elif g_columns: + out.append("## Columns in scope") for c in g_columns: - cqn = getattr(c, "column_qualified_name", c) - vals = value_index.get(cqn) - out.append(f"- `{cqn}`" + (f" :: {vals}" if vals else "")) + qn = getattr(c, "column_qualified_name", str(c)) + meta = _lookup_col(qn, cat_by_qn) + if 
meta: + out.append("- " + value_anchoring.render_catalog_from_meta(qn, meta, char_budget=char_budget)) + else: + out.append(f"- `{qn}`") out.append("") + if g_joins: out.append("## Approved joins") for j in g_joins: @@ -293,62 +379,34 @@ def _render_generation_prompt( ) out.append("") - # Full column-value catalogue -- the grounding agent may under-select - # columns, so expose every column's real values here too. This is the - # source of truth for WHERE / CASE literals. - if value_index: - out.append(f"# Column value catalogue ({len(value_index)} columns)") - out.append( - "Real values per column. Copy filter/CASE literals VERBATIM from these. " - "If the question names an entity that is not a column name, filter the " - "column whose values contain it." - ) - for qn, vals in value_index.items(): - out.append(f"- `{qn}` :: {vals}") + if examples: + out.append(f"# Worked examples (approved Q→SQL, {len(examples)})") + for h in examples[:5]: + md = getattr(h, "metadata", None) or {} + out.append(f"- Q: {md.get('question', '')}\n SQL: `{md.get('generated_sql', '')}`") out.append("") out.append("# Task") out.append( - "Generate up to N candidate DuckDB SQL queries that answer the question, " + f"Generate {n_candidates} candidate DuckDB SQL queries that answer the question, " "ordered by confidence (highest first). Each candidate must:\n" - "- Use only the tables and columns listed above. " - "Reference each table by its **unqualified name** (the last segment of the " - "qualified name shown above), e.g. write `FROM IVI_MALAGA_SL__Activos` " - "rather than `FROM orbis_companies.IVI_MALAGA_SL__Activos`.\n" - "- Quote any column name that isn't a plain identifier (e.g. 
date-shaped " - 'names like `2024-12-31` must be `"2024-12-31"`).\n' - "- Copy every WHERE / CASE / IN literal VERBATIM from the column value " - "catalogue above -- never invent or reformat a value (a year is `FY23`, " - "not `2023`; a market may be `Brazil` or `44000BR Brazil` -- use exactly " - "what is listed).\n" - "- When the metric the user names is not a column but appears among a " - "column's listed values, filter that column (tall/EAV layout): e.g. P&L " - "line items like `Total Revenue` / `Manpower` are VALUES of a single " - "category column, selected with `CASE WHEN \"\" = 'Total Revenue' …`.\n" - "- When two numeric columns are near-duplicates whose ranges differ by a " - "constant factor (~10^k), they are the same measure at different scales -- " - "prefer the larger-magnitude one for monetary sums.\n" - "- Do NOT add a WHERE filter on a dimension the question did not ask to " - "slice by -- aggregate across ALL of its values, and do NOT drop a row " - "just because a category value's name contains 'total'/'all' (those are " - "usually legitimate, often 'unallocated', buckets). Exclude a value only " - "if you can confirm it is literally the sum of the other rows.\n" - "- Return what is ASKED FOR: if the question asks for names, a list, " - "'who', 'which', or 'dame los nombres/quiénes', SELECT the identifying " - "column(s) (e.g. the name) and return the matching ROWS -- do NOT collapse " - "to a COUNT. Use COUNT/aggregates only when a count or total is requested. " - "If BOTH a count and the names are asked, return the names (the count is " - "derivable from the row count).\n" - "- HIERARCHY questions: when the question asks about a person's TEAM, " - "direct reports, the people 'at their charge' / under them, their org, or " - "movements in THEIR structure, it is a self-referencing hierarchy. A " - "column marked 'HIERARCHY: holds entities/people from column X' holds each " - "row's manager/owner. 
The person's team = the ROWS where such a column " - "equals that person (filter it with case-insensitive LIKE '%name%'), NOT " - "the person's own row. Try EVERY hierarchy column (a person may appear in " - "more than one), and match names tolerantly (accents/spacing).\n" - "- Be a SINGLE statement (no multi-statement; no DDL).\n" - "- Be a SELECT (DuckDB-flavored)." + "- Use only the tables and columns listed above. Reference each table by its " + "**unqualified name** (the last segment of the qualified name).\n" + "- For every WHERE / GROUP BY / JOIN literal, COPY a value shown in the " + "'values:' list of that column VERBATIM (exact case + spelling). Do NOT invent " + "or translate filter literals. If the question's entity matches a value under a " + "DIFFERENT column than you expected, filter on the column that actually holds it " + "(see 'Resolved entities').\n" + "- For a year/period/time column whose values are strings (e.g. 'FY23'), filter " + "with the STRING form shown -- never an integer like 2023, and never use a numeric " + "measure column as the year axis.\n" + "- Aggregate measure columns; filter/group on dimension/time columns.\n" + "- NEVER emit a no-op query (e.g. `WHERE 1=0`, `SUM(CASE WHEN ... THEN 0 ELSE 0 END)`, " + "or a constant SELECT). If you cannot map a needed filter to a real value, widen or " + "omit that filter and lower your confidence rather than returning a placeholder.\n" + '- Quote any column name that isn\'t a plain identifier (e.g. `"2024-12-31"`, ' + '`"P&L Line"`).\n' + "- Be a SINGLE SELECT statement (no multi-statement; no DDL)." 
) return "\n".join(out) @@ -360,20 +418,30 @@ def _render_critic_prompt( error_message: str, grounded: Any, schema_inventory: list[Any] | None = None, + value_hints: list[str] | None = None, ) -> str: out: list[str] = [] out.append("# User question") out.append(question.strip()) out.append("") - out.append("# Failing SQL") + out.append("# Previous SQL (needs repair)") out.append("```sql") out.append(failing_sql.strip()) out.append("```") out.append("") - out.append("# DuckDB execution error") + out.append("# Problem") out.append(error_message.strip()) out.append("") + # The actual stored values of the columns the failing query filtered on -- + # so the critic replaces wrong literals with real ones instead of guessing + # again (this is what fixes the silent 0-row failures). + if value_hints: + out.append("# Column value catalogue (verify EVERY filter literal against these)") + for line in value_hints: + out.append(f"- {line}") + out.append("") + # Full dataset catalogue -- the critic's most common failure mode # is rewriting one hallucinated table name into another (e.g. # ``balance_sheet`` -> ``financials.balance_sheet`` rather than @@ -610,26 +678,28 @@ async def answer( grounded = getattr(grounded_run, "output", grounded_run) # ------------------------------------------------------------------ - # 4. SQL generation (semantic-layer fast path OR synthesis) + # 4. Value catalogue + entity / group resolution (G1/G3 + round-2 #1/#2) + # ------------------------------------------------------------------ + col_catalog = bundle.get("column_catalog") or [] + resolution = await self._resolve_entities(dataset_id, question, col_catalog, bundle) + resolved_entities = resolution["entities"] + resolved_groups = resolution["groups"] + hierarchy_intent = resolution["hierarchy_intent"] + + # ------------------------------------------------------------------ + # 5. 
SQL generation (semantic-layer fast path OR synthesis) # ------------------------------------------------------------------ chosen_sql: str candidates_json: list = [] + candidate_sqls: list[str] = [] if grounded.path == "SEMANTIC_LAYER" and grounded.metrics: - # Fetch + bind the published metric's compiled SQL. When found, the - # bound SQL goes straight to the AST firewall + executor — the - # GenerationAgent is NOT invoked — and the metric version is pinned - # into the persisted query record for reproducibility. metric_name = grounded.metrics[0].metric_name compiled, metric_version = await self._compiled_metric_sql( - metric_name, - dataset_id, - tenant_id=tenant_id, - workspace_id=workspace_id, - extra_filter=getattr(grounded.metrics[0], "extra_filter", None), + metric_name, dataset_id, tenant_id=tenant_id, workspace_id=workspace_id ) if compiled: - chosen_sql = compiled + candidate_sqls = [compiled] candidates_json = [ { "sql": compiled, @@ -639,185 +709,177 @@ async def answer( "metric_version": metric_version, } ] - else: - # Fall through to synthesis if no compiled SQL found - gen_prompt = _render_generation_prompt( - question, - grounded, - starting_point_sql, - schema_inventory=bundle.get("schema_inventory"), - ) - gen_run = await self._generation_agent.run(gen_prompt) - gen_out = getattr(gen_run, "output", gen_run) - candidates_json = [c.model_dump() for c in gen_out.candidates] - chosen_sql = gen_out.candidates[0].sql - else: + if not candidate_sqls: gen_prompt = _render_generation_prompt( question, grounded, starting_point_sql, schema_inventory=bundle.get("schema_inventory"), + col_catalog=col_catalog, + resolved_entities=resolved_entities, + resolved_groups=resolved_groups, + hierarchy_intent=hierarchy_intent, + examples=bundle.get("examples"), + n_candidates=self._settings.generation_candidates, + max_columns=self._settings.value_catalog_max_columns, + char_budget=self._settings.value_catalog_char_budget, ) gen_run = await 
self._generation_agent.run(gen_prompt) gen_out = getattr(gen_run, "output", gen_run) candidates_json = [c.model_dump() for c in gen_out.candidates] - # Don't blindly take the highest-confidence candidate: probe them - # (DuckDB only, no extra LLM) and prefer one that passes the firewall - # and returns non-empty, non-degenerate rows. Empty candidate list - # degrades to "" (FAILED) instead of raising an IndexError. - chosen_sql = await self._select_best_candidate( - [c.sql for c in gen_out.candidates if c.sql], - dataset_id=dataset_id, - scopes=scopes, - dataset_allowlist=dataset_allowlist, - pins=prior_snapshot_pins, - ) + candidate_sqls = [c.sql for c in gen_out.candidates if getattr(c, "sql", None)] + if not candidate_sqls: + candidate_sqls = [""] # ------------------------------------------------------------------ - # 5. AST classify + scope guard + # 6. Candidate selection by EXECUTION (G5): drop degenerate, prefer a + # candidate that runs AND returns rows. Each candidate passes the + # scope guard + synthesis function firewall (G8) + table guard (G6). 
# ------------------------------------------------------------------ - ast = self._ast_classifier.classify(chosen_sql) - table_kinds = await self._table_kinds_by_name(list(ast.table_refs), dataset_id) - dataset_of_table = await self._dataset_of_tables(list(ast.table_refs), dataset_id) + ordered = [s for s in candidate_sqls if not value_anchoring.is_degenerate_sql(s)] or candidate_sqls + select_pool = ordered if self._settings.candidate_exec_selection else ordered[:1] + + best: tuple[str, Any, Any] | None = None + scope_reject: ScopeGuardError | None = None + firewall_reject: str | None = None + for cand in select_pool: + run = await self._run_sql_once(cand, dataset_id, scopes, dataset_allowlist, bundle) + if run["status"] == "scope": + scope_reject = run["error"] + continue + if run["status"] == "firewall": + firewall_reject = run["error"] + continue + if best is None: + best = (cand, run["result"], run["ast"]) + res = run["result"] + if ( + isinstance(res, ExecutionResult) + and res.row_count > 0 + and not value_anchoring.is_degenerate_sql(cand) + ): + best = (cand, res, run["ast"]) + break - try: - self._scope_guard.check( - classification=ast, - scopes=scopes, - table_kinds_by_name=table_kinds, - dataset_allowlist=dataset_allowlist, - dataset_of_table=dataset_of_table, - ) - except ScopeGuardError as exc: - elapsed = (time.monotonic_ns() // 1_000_000) - start_ms - query_id = await self._query_repo.create_query( - tenant_id=tenant_id, - workspace_id=workspace_id, - dataset_id=dataset_id, - question=question, - semantic_path_taken=grounded.path, - candidates_json=candidates_json, - chosen_candidate_index=0, - executed_sql=chosen_sql, - ast_classification=ast.classification, - execution_status="REJECTED_BY_FIREWALL", - retries=0, - row_count=None, - elapsed_ms=elapsed, - clarification_emitted=False, - clarification_json=None, - pii_findings_json=None, - error_json={"scope_error": str(exc)}, + if best is None: + # Nothing executed. 
If the only blockers were scope/firewall, reject. + if scope_reject is not None or firewall_reject is not None: + reason = str(scope_reject) if scope_reject is not None else (firewall_reject or "") + chosen_sql = ordered[0] + ast = self._ast_classifier.classify(chosen_sql) + elapsed = (time.monotonic_ns() // 1_000_000) - start_ms + query_id = await self._query_repo.create_query( + tenant_id=tenant_id, + workspace_id=workspace_id, + dataset_id=dataset_id, + question=question, + semantic_path_taken=grounded.path, + candidates_json=candidates_json, + chosen_candidate_index=0, + executed_sql=chosen_sql, + ast_classification=ast.classification, + execution_status="REJECTED_BY_FIREWALL", + retries=0, + row_count=None, + elapsed_ms=elapsed, + clarification_emitted=False, + clarification_json=None, + pii_findings_json=None, + error_json={"firewall_error": reason}, + ) + return AnswerResult( + query_id=query_id, + sql=chosen_sql, + execution_status="REJECTED_BY_FIREWALL", + preview=None, + row_count=None, + truncated=False, + elapsed_ms=elapsed, + chart_hint=None, + explanation=None, + clarification=self._clarification(grounded), + grounded_summary=self._grounded_summary(grounded), + ) + chosen_sql = ordered[0] + best = ( + chosen_sql, + ExecutionError(message="generation produced no usable SQL"), + self._ast_classifier.classify(chosen_sql), ) - return AnswerResult( - query_id=query_id, - sql=chosen_sql, - execution_status="REJECTED_BY_FIREWALL", - preview=None, - row_count=None, - truncated=False, - elapsed_ms=elapsed, - chart_hint=None, - explanation=None, - clarification=self._clarification(grounded), - grounded_summary=self._grounded_summary(grounded), + + chosen_sql, result, ast = best + chosen_index = candidate_sqls.index(chosen_sql) if chosen_sql in candidate_sqls else 0 + retries = 0 + + def _is_good(r, s) -> bool: + return ( + isinstance(r, ExecutionResult) + and r.row_count > 0 + and not value_anchoring.is_degenerate_sql(s) ) + # round-2 #4: remember the best 
correct result so an ADVISORY repair can + # never replace a correct answer with a worse one. + best_ok = (chosen_sql, result, ast) if _is_good(result, chosen_sql) else None + + # round-2 #5B/#7: when the result is already OK, an advisory check may still + # ask the critic to reconsider (signed-measure double-subtraction; an IN-list + # that under-covers a value group). It rides the non-destructive loop below, + # so it can only ever improve the answer or be ignored. + adv_msg = None + if _is_good(result, chosen_sql): + adv_msg = await self._advisory_repair(chosen_sql, ast, col_catalog, question, dataset_id, bundle) + # ------------------------------------------------------------------ - # 6. Resolve parquet paths + execute (with critic loop) + # 7. Repair loop: HARD repair (error / degenerate / 0-row-with-filter) OR an + # ADVISORY reconsideration. The critic receives the REAL stored values of + # the filtered columns so it replaces wrong literals/columns/signs instead + # of guessing again, and MAY keep the original SQL if it is already correct. # ------------------------------------------------------------------ - # Pre-execution guard: detect SQL that references a table not in - # the dataset. ``_table_kinds_by_name`` returns ONLY matching - # rows -- a missing table simply has no key in the dict. So - # the bad-tables check is set-difference against - # ``ast.table_refs``, not a None-value sweep. - # - # The LLM occasionally hallucinates ``balance_sheet`` / - # ``income_statement`` / ``financials.*`` despite the prompt - # rules; rather than waiting for DuckDB to emit a generic - # ``Catalog Error: Table does not exist``, we synthesise a - # sharp ExecutionError that the existing critic loop picks - # up. The critic prompt receives the full dataset catalogue - # so the rewrite has the real table names in scope. 
- ref_set = {t for t in ast.table_refs if t} - bad_tables = sorted(ref_set - set(table_kinds.keys())) - if bad_tables: - real_tables = sorted(table_kinds.keys()) + [ - (getattr(h, "metadata", {}) or {}).get("qualified_name", "").rsplit(".", 1)[-1] - for h in (bundle.get("schema_inventory") or []) - if (getattr(h, "metadata", {}) or {}).get("kind") == "TABLE" - ] - real_tables = [t for t in dict.fromkeys(real_tables) if t] - result: ExecutionResult | ExecutionError = ExecutionError( - message=( - f"Table(s) {bad_tables!r} do not exist in this dataset. " - f"Pick ONLY from this catalogue (and translate as needed: " - f"Spanish `Activos` = Assets, `Cuenta de Pérdidas y Ganancias` " - f"= Profit & Loss): {real_tables[:80]!r}." - ) - ) - else: - attached = await self._table_resolver.resolve( - dataset_id, - list(ast.table_refs), - pins=prior_snapshot_pins, - ) - result = await self._executor.execute(chosen_sql, attached) - retries = 0 - - while isinstance(result, ExecutionError) and retries < self._settings.max_refine_retries: + while ( + self._needs_repair(result, chosen_sql) or adv_msg + ) and retries < self._settings.max_refine_retries: + if self._needs_repair(result, chosen_sql): + critic_msg = self._repair_message(result, chosen_sql, ast, col_catalog) + else: + critic_msg = adv_msg + value_hints = self._build_value_hints(chosen_sql, col_catalog, ast) critic_prompt = _render_critic_prompt( question=question, failing_sql=chosen_sql, - error_message=result.message, + error_message=critic_msg, grounded=grounded, schema_inventory=bundle.get("schema_inventory"), + value_hints=value_hints, ) refined_run = await self._critic_agent.run(critic_prompt) refined = getattr(refined_run, "output", refined_run) - chosen_sql = refined.sql - ast = self._ast_classifier.classify(chosen_sql) - table_kinds = await self._table_kinds_by_name(list(ast.table_refs), dataset_id) - dataset_of_table = await self._dataset_of_tables(list(ast.table_refs), dataset_id) - try: - 
self._scope_guard.check( - classification=ast, - scopes=scopes, - table_kinds_by_name=table_kinds, - dataset_allowlist=dataset_allowlist, - dataset_of_table=dataset_of_table, - ) - except ScopeGuardError: + new_sql = getattr(refined, "sql", None) or chosen_sql + if new_sql.strip() == chosen_sql.strip(): + break # critic kept the SQL -> accept it as already correct + if value_anchoring.is_degenerate_sql(new_sql): + break # never adopt a placeholder; best_ok (if any) is restored below + chosen_sql = new_sql + run = await self._run_sql_once(chosen_sql, dataset_id, scopes, dataset_allowlist, bundle) + if run["status"] in ("scope", "firewall"): break - # Re-apply the unknown-table guard on the refined SQL too -- - # otherwise the critic could hallucinate a different - # non-existent table and DuckDB would catch it generically. - ref_set = {t for t in ast.table_refs if t} - bad_tables = sorted(ref_set - set(table_kinds.keys())) - if bad_tables: - real_tables = sorted(table_kinds.keys()) + [ - (getattr(h, "metadata", {}) or {}).get("qualified_name", "").rsplit(".", 1)[-1] - for h in (bundle.get("schema_inventory") or []) - if (getattr(h, "metadata", {}) or {}).get("kind") == "TABLE" - ] - real_tables = [t for t in dict.fromkeys(real_tables) if t] - result = ExecutionError( - message=( - f"Refined SQL still references missing table(s) " - f"{bad_tables!r}. The dataset only contains: " - f"{real_tables[:80]!r}." - ) - ) - retries += 1 - continue - attached = await self._table_resolver.resolve( - dataset_id, list(ast.table_refs), pins=prior_snapshot_pins - ) - result = await self._executor.execute(chosen_sql, attached) + result = run["result"] + ast = run["ast"] retries += 1 + if _is_good(result, chosen_sql): + best_ok = (chosen_sql, result, ast) + adv_msg = await self._advisory_repair( + chosen_sql, ast, col_catalog, question, dataset_id, bundle + ) + else: + adv_msg = None + + # round-2 #4: if repair degraded a previously-correct answer, restore it. 
+ if best_ok is not None and not _is_good(result, chosen_sql): + chosen_sql, result, ast = best_ok # ------------------------------------------------------------------ - # 7. Determine execution status + # 8. Determine execution status # ------------------------------------------------------------------ if isinstance(result, ExecutionResult) and retries == 0: execution_status: str = "OK" @@ -841,27 +903,10 @@ async def answer( explanation_obj = getattr(explanation_run, "output", explanation_run) # ------------------------------------------------------------------ - # 9. Clarification frame (emitted alongside answer when confidence is low) + # 9. Clarification frame (low confidence, OR confident-but-still-empty) # ------------------------------------------------------------------ - clarification = self._clarification(grounded) - # A syntactically-valid query that returns 0 rows (or a single - # all-NULL/zero aggregate) is suspicious: the usual cause is a - # filter literal that doesn't match how the data is encoded. Rather - # than report it as a confident empty answer, surface a clarification - # and downgrade confidence so the caller knows to verify. - suspicious_empty = isinstance(result, ExecutionResult) and self._is_suspicious_empty(result) - if clarification is None and suspicious_empty: - from flyquery.interfaces.query import ClarificationFrame - - clarification = ClarificationFrame( - questions=[ - "The query executed successfully but returned no matching data " - "(0 rows / empty result). The filter values may not match how the " - "data is encoded -- please verify the exact column values (e.g. " - "category labels or period format) or rephrase the question." 
- ], - reasons=[], - ) + final_row_count = result.row_count if isinstance(result, ExecutionResult) else None + clarification = self._clarification(grounded, row_count=final_row_count) clarification_emitted = clarification is not None # ------------------------------------------------------------------ @@ -879,7 +924,7 @@ async def answer( question=question, semantic_path_taken=grounded.path, candidates_json=candidates_json, - chosen_candidate_index=0, + chosen_candidate_index=chosen_index, executed_sql=chosen_sql, ast_classification=ast.classification, execution_status=execution_status, @@ -911,12 +956,7 @@ async def answer( # ------------------------------------------------------------------ # 12. Auto-learn (only on first-shot OK + no PII + no clarification) # ------------------------------------------------------------------ - if ( - execution_status == "OK" - and isinstance(result, ExecutionResult) - and result.row_count > 0 - and not clarification_emitted - ): + if execution_status == "OK" and isinstance(result, ExecutionResult) and not clarification_emitted: await self._auto_learner.maybe_propose( tenant_id=tenant_id, workspace_id=workspace_id, @@ -926,21 +966,12 @@ async def answer( retries=retries, pii_findings=[], query_id=query_id, + row_count=result.row_count, ) # ------------------------------------------------------------------ # 13. Persist conversation turn (Phase E drill-down) # ------------------------------------------------------------------ - # THIS turn's snapshot pins: the snapshot each resolved table was - # answered against. Tables already pinned by an earlier turn keep - # their pin (prior wins); newly-referenced tables pin to current. - # Persisting THIS turn's pins (not the prior turn's) is what makes - # drill-down reproducible across a mid-conversation re-ingest. 
- this_turn_pins: dict[str, str] = { - **(await self._table_resolver.current_snapshots(dataset_id, list(ast.table_refs))), - **prior_snapshot_pins, - } - if ( conversation_id is not None and self._conversation_service is not None @@ -954,7 +985,7 @@ async def answer( executed_sql=chosen_sql, summary=explanation_obj.summary if explanation_obj else None, table_qnames_json=list(ast.table_refs), - snapshot_pins_json=this_turn_pins, + snapshot_pins_json=prior_snapshot_pins, elapsed_ms=elapsed, ) @@ -969,85 +1000,13 @@ async def answer( chart_hint=explanation_obj.chart_hint if explanation_obj else None, explanation=explanation_obj.summary if explanation_obj else None, clarification=clarification, - grounded_summary=self._grounded_summary( - grounded, confidence_cap=0.4 if suspicious_empty else None - ), - snapshot_pins=this_turn_pins, + grounded_summary=self._grounded_summary(grounded), ) # ------------------------------------------------------------------ # Internal helpers # ------------------------------------------------------------------ - @staticmethod - def _is_suspicious_empty(result) -> bool: - """True when an executed result is empty/degenerate enough to doubt. - - Catches the canonical wrong-literal symptom: a valid query that - matched nothing (0 rows), or a single-row single-column aggregate - whose only value is NULL / 0 / 0.0 (e.g. a SUM/CASE pivot where - every branch missed). - """ - if result.row_count == 0: - return True - rows = getattr(result, "rows", None) or [] - if result.row_count == 1 and len(rows) == 1 and isinstance(rows[0], dict) and len(rows[0]) == 1: - (only_value,) = rows[0].values() - return only_value is None or only_value == 0 - return False - - async def _select_best_candidate( - self, - candidate_sqls: list[str], - *, - dataset_id: uuid.UUID, - scopes: set[str], - dataset_allowlist: set[uuid.UUID] | None, - pins: dict[str, str], - ) -> str: - """Pick the candidate SQL that best answers the question. 
- - Generation emits N candidates ranked by self-reported confidence, but - the top one sometimes filters on the wrong column (or under-searches a - set of hierarchy columns) and returns 0 rows while a lower-ranked - candidate is correct. So we probe the candidates and prefer the first - that (a) passes the firewall, (b) executes, and (c) returns non-empty, - non-degenerate rows. This is DuckDB-only -- NO extra LLM calls -- and - general: it just prefers a candidate that actually returns data. - - Falls back to the first candidate that executed at all, else the first - candidate (so the existing scope/critic handling downstream is - unchanged when nothing is clearly better). - """ - if len(candidate_sqls) <= 1: - return candidate_sqls[0] if candidate_sqls else "" - - first_executed: str | None = None - for sql in candidate_sqls: - ast = self._ast_classifier.classify(sql) - table_kinds = await self._table_kinds_by_name(list(ast.table_refs), dataset_id) - dataset_of_table = await self._dataset_of_tables(list(ast.table_refs), dataset_id) - try: - self._scope_guard.check( - classification=ast, - scopes=scopes, - table_kinds_by_name=table_kinds, - dataset_allowlist=dataset_allowlist, - dataset_of_table=dataset_of_table, - ) - except ScopeGuardError: - continue # unsafe candidate -- skip - if sorted({t for t in ast.table_refs if t} - set(table_kinds.keys())): - continue # references a table not in the dataset -- skip - attached = await self._table_resolver.resolve(dataset_id, list(ast.table_refs), pins=pins) - result = await self._executor.execute(sql, attached) - if isinstance(result, ExecutionResult): - if not self._is_suspicious_empty(result): - return sql # passes firewall + returns real rows -- best - if first_executed is None: - first_executed = sql # remember first successful-but-empty - return first_executed or candidate_sqls[0] - async def _table_kinds_by_name( self, table_names: list[str], @@ -1088,19 +1047,11 @@ async def _compiled_metric_sql( *, tenant_id: 
str, workspace_id: uuid.UUID, - extra_filter: str | None = None, ) -> tuple[str | None, int | None]: """Fetch + bind the compiled SQL for a PUBLISHED metric. Returns ``(bound_sql, current_version)`` so the version can be pinned in the query record, or ``(None, None)`` when no usable metric is found. - - ``extra_filter`` is the per-question slice the grounding agent derived - (e.g. ``Market = 'Brazil' AND Year = 'FY24'``) to be appended to the - metric's WHERE via the compiler's ``{extra_filter_clause}`` slot. It is - an LLM-supplied predicate, so it is re-run through the publish-time - firewall before binding; an unsafe filter is dropped (the metric still - returns its unfiltered value) rather than executed. """ if self._semantic_repo is None: return None, None @@ -1113,55 +1064,438 @@ async def _compiled_metric_sql( return None, None if not row or not row.get("compiled_sql_template"): return None, None - - safe_filter = self._firewall_extra_filter(row["compiled_sql_template"], extra_filter) - bound = SemanticCompiler.bind(row["compiled_sql_template"], extra_filter=safe_filter) + bound = SemanticCompiler.bind(row["compiled_sql_template"]) return bound, row.get("current_version") - @staticmethod - def _firewall_extra_filter(template: str, extra_filter: str | None) -> str | None: - """Validate an LLM-supplied metric filter via the publish-time firewall. - - Returns the filter when the bound SQL passes ``assert_safe_template``, - else ``None`` (filter dropped). Defensive: any firewall/parse failure - also drops the filter rather than risking an unsafe predicate. 
- """ - if not extra_filter: - return None - try: - from flyquery.core.services.semantic.firewall import assert_safe_template - - probe = SemanticCompiler.bind(template, extra_filter=extra_filter) - assert_safe_template(probe) - return extra_filter - except Exception as exc: # noqa: BLE001 -- any failure → drop the filter - logger.warning("dropping unsafe semantic extra_filter %r: %s", extra_filter, exc) - return None - - def _clarification(self, grounded) -> Any: - """Build a ClarificationFrame if grounding confidence is low.""" + def _clarification(self, grounded, *, row_count: int | None = None) -> Any: + """Build a ClarificationFrame if grounding confidence is low or the + (confident) query still returned 0 rows (G7 -- couple clarification to + the observed empty result instead of confidence alone).""" from flyquery.interfaces.query import ClarificationFrame if grounded.confidence < self._settings.grounding_min_confidence and grounded.missing_info: + return ClarificationFrame(questions=grounded.missing_info, reasons=[]) + if row_count == 0: return ClarificationFrame( - questions=grounded.missing_info, + questions=[ + "The query executed but matched 0 rows -- a filter value, period, " + "or entity may not match how it is stored in the data. Please confirm " + "the exact value you mean." + ], reasons=[], ) return None - def _grounded_summary(self, grounded, confidence_cap: float | None = None) -> dict: - """Convert GroundedContext to a summary dict for the response. - - ``confidence_cap`` lets the caller lower the reported confidence when - the executed result is suspicious (e.g. 0 rows from a wrong literal), - so a confidently-wrong empty answer is not surfaced at high confidence. 
- """ - confidence = grounded.confidence - if confidence_cap is not None: - confidence = min(confidence, confidence_cap) + def _grounded_summary(self, grounded) -> dict: + """Convert GroundedContext to a summary dict for the response.""" return { "path": grounded.path, - "confidence": confidence, + "confidence": grounded.confidence, "table_count": len(grounded.tables), "missing_info": grounded.missing_info, } + + # ------------------------------------------------------------------ + # Value-anchoring / repair helpers (G2/G3/G5/G6/G8) + # ------------------------------------------------------------------ + + async def _run_sql_once( + self, + sql: str, + dataset_id: uuid.UUID, + scopes: set[str], + dataset_allowlist: set[str] | None, + bundle: dict, + ) -> dict: + """Classify, guard, and execute one candidate SQL. + + Returns ``{status, result, ast, error}`` where status is one of + ``ok`` / ``empty`` / ``error`` / ``scope`` / ``firewall``. + """ + ast = self._ast_classifier.classify(sql) + + # G8: block filesystem/exfiltration table-functions on the synthesis path. + if self._settings.synthesis_function_firewall: + dangerous = value_anchoring.find_dangerous_functions(sql) + if dangerous: + return { + "status": "firewall", + "result": None, + "ast": ast, + "error": f"disallowed function(s) in generated SQL: {sorted(dangerous)}", + } + + table_kinds = await self._table_kinds_by_name(list(ast.table_refs), dataset_id) + dataset_of_table = await self._dataset_of_tables(list(ast.table_refs), dataset_id) + try: + self._scope_guard.check( + classification=ast, + scopes=scopes, + table_kinds_by_name=table_kinds, + dataset_allowlist=dataset_allowlist, + dataset_of_table=dataset_of_table, + ) + except ScopeGuardError as exc: + return {"status": "scope", "result": None, "ast": ast, "error": exc} + + # G6: CTE-aware unknown-table guard. CTE / derived-table aliases are NOT + # real tables, so subtract them before the set-difference -- otherwise a + # valid `WITH x AS (...) 
SELECT ... FROM x` is wrongly rejected. + synthetic = {n.lower() for n in value_anchoring.cte_and_derived_names(sql)} + ref_set = {t for t in ast.table_refs if t and t.lower() not in synthetic} + bad_tables = sorted(ref_set - set(table_kinds.keys())) + if bad_tables: + real_tables = sorted(table_kinds.keys()) + [ + (getattr(h, "metadata", {}) or {}).get("qualified_name", "").rsplit(".", 1)[-1] + for h in (bundle.get("schema_inventory") or []) + if (getattr(h, "metadata", {}) or {}).get("kind") == "TABLE" + ] + real_tables = [t for t in dict.fromkeys(real_tables) if t] + return { + "status": "error", + "result": ExecutionError( + message=( + f"Table(s) {bad_tables!r} do not exist in this dataset. " + f"Use ONLY these tables: {real_tables[:80]!r}." + ) + ), + "ast": ast, + "error": None, + } + + attached = await self._table_resolver.resolve(dataset_id, list(ast.table_refs)) + result = await self._executor.execute(sql, attached) + if isinstance(result, ExecutionResult): + status = "ok" if result.row_count > 0 else "empty" + else: + status = "error" + return {"status": status, "result": result, "ast": ast, "error": None} + + def _needs_repair(self, result, sql: str) -> bool: + """A query needs repair if it errored, is degenerate, or ran to 0 rows + while filtering on equality/IN literals (G2/G5).""" + if isinstance(result, ExecutionError): + return True + if value_anchoring.is_degenerate_sql(sql): + return True + return bool( + self._settings.zero_row_repair_enabled + and isinstance(result, ExecutionResult) + and result.row_count == 0 + and value_anchoring.equality_predicate_columns(sql) + ) + + def _repair_message(self, result, sql: str, ast, col_catalog: list[dict]) -> str: + unknown = self._unknown_columns(ast, col_catalog, sql) + suffix = "" + if unknown: + names = sorted({c.get("qualified_name", "").rsplit(".", 1)[-1] for c in col_catalog}) + suffix = ( + f" Also: column(s) {unknown!r} are not real columns -- use only these " + f"columns: {names[:80]!r}." 
+ ) + if isinstance(result, ExecutionError): + return result.message + suffix + if value_anchoring.is_degenerate_sql(sql): + return ( + "The previous SQL is a no-op (constant/placeholder, e.g. WHERE 1=0 or " + "SUM(CASE..THEN 0 ELSE 0)). Rewrite it to actually compute the answer " + "using real columns and literal values from the catalogue below." + suffix + ) + return ( + "The previous SQL executed but returned 0 ROWS. One or more filter literals " + "or filtered columns is wrong. Replace each WHERE/IN literal with a value that " + "actually appears in that column (see the value catalogue). If the entity belongs " + "to a different column, filter THAT column instead." + suffix + ) + + def _build_value_hints(self, sql: str, col_catalog: list[dict], ast) -> list[str]: + """Render value-catalogue lines for the columns the failing SQL touched.""" + wanted: set[str] = {c.lower() for c in value_anchoring.equality_predicate_columns(sql)} + wanted |= {c.lower() for c in (getattr(ast, "column_refs", ()) or ())} + hints: list[str] = [] + seen: set[str] = set() + for c in col_catalog: + tail = c["qualified_name"].rsplit(".", 1)[-1].lower() + if tail in wanted and tail not in seen: + seen.add(tail) + hints.append(value_anchoring.render_catalog_from_meta(c["qualified_name"], c)) + return hints[:40] + + def _unknown_columns(self, ast, col_catalog: list[dict], sql: str) -> list[str]: + if not col_catalog: + return [] + known = {c["qualified_name"].rsplit(".", 1)[-1].lower() for c in col_catalog} + aliases = value_anchoring.select_aliases(sql) + out: list[str] = [] + for c in getattr(ast, "column_refs", ()) or (): + cl = c.lower() + if cl not in known and cl not in aliases and cl != "*": + out.append(c) + return out[:20] + + async def _advisory_repair(self, sql, ast, col_catalog, question, dataset_id, bundle) -> str | None: + """Non-error reasons to ask the critic to reconsider an already-OK result. 
+ + Rides the non-destructive loop (round-2 #4): under-covered value groups (#7) + and signed-measure double-subtraction (#5B). Returns a combined message or None. + """ + reasons: list[str] = [] + if self._settings.group_coverage_repair_enabled: + lits = value_anchoring.extract_question_literals(question) + for g in value_anchoring.group_coverage_gaps(sql, col_catalog, lits): + col = g["column"].rsplit(".", 1)[-1] + reasons.append( + f"The filter on `{col}` lists only SOME of the values the question's group " + f"covers; it is MISSING {g['missing']!r}. Include ALL of them (IN-list) or use " + f"an anchored LIKE; keep the original only if the exclusion is genuinely intended." + ) + if self._settings.signed_measure_repair_enabled: + hint = await self._signed_measure_hint(sql, ast, col_catalog, dataset_id) + if hint: + reasons.append(hint) + return " ".join(reasons) if reasons else None + + async def _signed_measure_hint(self, sql, ast, col_catalog, dataset_id) -> str | None: + """Probe whether the SQL subtracts a term whose SIGNED measure is already + stored negative (so subtracting double-counts the sign). 
Best-effort (#5B).""" + try: + mixed = { + c["qualified_name"].rsplit(".", 1)[-1].lower() + for c in col_catalog + if c.get("mixed_sign") and value_anchoring.semantic_role(c.get("semantic_type")) == "measure" + } + if not mixed: + return None + terms = value_anchoring.signed_subtraction_terms(sql) + subtracted = [ + t + for t in terms + if t["sign"] < 0 and t["measure_col"].lower() in mixed and t["dim_col"] and t["literals"] + ] + if not subtracted: + return None + tables = [t for t in ast.table_refs if t] + if not tables: + return None + attached = await self._table_resolver.resolve(dataset_id, tables) + if not attached: + return None + import asyncio + + offenders = await asyncio.to_thread(_probe_signed_terms, attached, subtracted) + if not offenders: + return None + parts = "; ".join(f"{o['dim']} IN {o['literals']} sums to {o['sum']:.0f}" for o in offenders[:6]) + # build the corrected single-signed-sum pattern from ALL the formula's + # terms (the generic fix: the data already carries the sign). + all_terms = [ + t for t in terms if t["measure_col"].lower() in mixed and t["dim_col"] and t["literals"] + ] + allowed = sorted({lit for t in all_terms for lit in t["literals"]}) + dim = all_terms[0]["dim_col"] + meas = all_terms[0]["measure_col"] + lits_sql = ", ".join("'" + lit.replace("'", "''") + "'" for lit in allowed) + return ( + f"SIGN ERROR (must fix): this formula SUBTRACTS terms whose measure is ALREADY STORED " + f"NEGATIVE ({parts}). Subtracting an already-negative value double-counts the sign and " + f"inflates the result above total revenue. The data already carries the economic sign, so " + f"REWRITE the whole +/- expression as ONE signed sum over all its line-items: " + f'SUM(CASE WHEN "{dim}" IN ({lits_sql}) THEN "{meas}" ELSE 0 END) -- keep the original ' + f"WHERE filters and any GROUP BY. Do NOT keep the subtraction chain." 
+ ) + except Exception as exc: # noqa: BLE001 - best-effort + logger.debug("signed-measure probe failed: %s", exc) + return None + + async def _resolve_entities( + self, + dataset_id: uuid.UUID, + question: str, + col_catalog: list[dict], + bundle: dict, + ) -> dict: + """Map question literals to the columns that actually store them (G3). + + First via the low-cardinality value catalogue (cheap), then -- for + unresolved entity-looking literals (e.g. person names in a high-card + column) -- via a bounded live DuckDB scan of the dataset's tables. + """ + empty = {"entities": [], "groups": [], "hierarchy_intent": False} + if not self._settings.entity_resolution_enabled: + return empty + literals = value_anchoring.extract_question_literals(question) + if not literals: + return empty + + hierarchy = value_anchoring.relationship_intent(question) + # round-2 #1: a term that umbrellas >=2 values of one column -> a group. + groups: list[dict] = [] + if self._settings.group_resolution_enabled: + groups = value_anchoring.resolve_value_groups(literals, col_catalog) + group_keys = {(g["literal"].lower(), g["column"]) for g in groups} + group_lits = {g["literal"].lower() for g in groups} + + resolved = value_anchoring.resolve_from_catalog(literals, col_catalog) + # a single value superseded by a group is dropped (the group is the truth) + resolved = [r for r in resolved if (r["literal"].lower(), r["column"]) not in group_keys] + done = {r["literal"].lower() for r in resolved} | group_lits + remaining = [lit for lit in literals if lit.lower() not in done][ + : self._settings.entity_resolution_max_literals + ] + if remaining: + try: + resolved += await self._scan_columns_for_values(dataset_id, remaining, bundle) + except Exception as exc: # noqa: BLE001 - resolution is best-effort + logger.warning("live entity resolution failed: %s", exc) + + # de-dup entities by (literal, column), keeping the highest match_count + merged: dict[tuple[str, str], dict] = {} + for r in resolved: + 
key = (r["literal"].lower(), r["column"]) + cur = merged.get(key) + if cur is None or (r.get("match_count") or 0) > (cur.get("match_count") or 0): + merged[key] = r + # de-dup groups by (literal, column) + gseen: set[tuple[str, str]] = set() + gout: list[dict] = [] + for g in groups: + k = (g["literal"].lower(), g["column"]) + if k not in gseen: + gseen.add(k) + gout.append(g) + entities = sorted(merged.values(), key=lambda r: r.get("match_count") or 0, reverse=True)[:25] + return {"entities": entities, "groups": gout[:12], "hierarchy_intent": hierarchy} + + async def _scan_columns_for_values( + self, + dataset_id: uuid.UUID, + literals: list[str], + bundle: dict, + ) -> list[dict]: + table_names = [ + (getattr(h, "metadata", {}) or {}).get("qualified_name", "").rsplit(".", 1)[-1] + for h in (bundle.get("schema_inventory") or []) + if (getattr(h, "metadata", {}) or {}).get("kind") == "TABLE" + ] + table_names = [t for t in dict.fromkeys(table_names) if t] + if not table_names: + return [] + attached = await self._table_resolver.resolve(dataset_id, table_names) + if not attached: + return [] + import asyncio + + return await asyncio.to_thread(_scan_sync, attached, literals) + + +def _probe_signed_terms(attached_tables: dict[str, str], terms: list[dict]) -> list[dict]: + """Sum each subtracted term's measure over its line-items; report the ones that + sum NEGATIVE (already stored negative → subtracting double-counts). 
Best-effort.""" + try: + import duckdb + except ImportError: # pragma: no cover + return [] + paths = list(attached_tables.values()) + if not paths: + return [] + offenders: list[dict] = [] + conn = None + try: + conn = duckdb.connect(":memory:") + conn.execute("SET threads=2") + for t in terms: + mq = t["measure_col"].replace(chr(34), chr(34) * 2) + dq = t["dim_col"].replace(chr(34), chr(34) * 2) + in_list = ", ".join("'" + str(lit).replace("'", "''") + "'" for lit in t["literals"]) + total = None + for path in paths: + try: + cols = { + c[0].lower() + for c in conn.execute(f"DESCRIBE SELECT * FROM read_parquet('{path}')").fetchall() + } + if t["measure_col"].lower() not in cols or t["dim_col"].lower() not in cols: + continue + s = conn.execute( + f"SELECT SUM(TRY_CAST(\"{mq}\" AS DOUBLE)) FROM read_parquet('{path}') " + f'WHERE CAST("{dq}" AS VARCHAR) IN ({in_list})' + ).fetchone() + if s and s[0] is not None: + total = (total or 0.0) + float(s[0]) + except Exception: # noqa: BLE001 + continue + if total is not None and total < 0: + offenders.append({"dim": t["dim_col"], "literals": t["literals"], "sum": total}) + except Exception: # noqa: BLE001 + return offenders + finally: + if conn is not None: + conn.close() + return offenders + + +def _scan_sync(attached_tables: dict[str, str], literals: list[str]) -> list[dict]: + """Find which column stores each question literal: one bounded scan per table. + + For each table we run a single pass that, per VARCHAR column, returns the + first stored value equal (case-insensitively) to any of the literals. This + resolves high-cardinality entities (person/brand names) the low-cardinality + value catalogue can't carry. Best-effort + dataset-agnostic. 
+ """ + try: + import duckdb + except ImportError: # pragma: no cover + return [] + lits = [s.lower() for s in literals if s] + if not lits: + return [] + in_list = ", ".join("'" + s.replace("'", "''") + "'" for s in lits) + found: list[dict] = [] + seen: set[tuple[str, str]] = set() + for tname, path in attached_tables.items(): + conn = None + try: + conn = duckdb.connect(":memory:") + conn.execute("SET threads=2") + desc = conn.execute(f"DESCRIBE SELECT * FROM read_parquet('{path}')").fetchall() + vcols = [ + d[0] for d in desc if str(d[1]).upper().startswith(("VARCHAR", "TEXT", "STRING", "CHAR")) + ][:80] + if not vcols: + continue + cols_sql = [] + for i, c in enumerate(vcols): + cq = c.replace(chr(34), chr(34) * 2) + pred = f'lower(CAST("{cq}" AS VARCHAR)) IN ({in_list})' + cols_sql.append(f'MAX(CASE WHEN {pred} THEN CAST("{cq}" AS VARCHAR) END) AS m{i}') + # match_count: how many rows this column has the literal in. A value + # that REPEATS is a grouping/parent key; one that appears once is an + # identity. Same single scan -- no extra cost. 
+ cols_sql.append(f"SUM(CASE WHEN {pred} THEN 1 ELSE 0 END) AS c{i}") + row = conn.execute(f"SELECT {', '.join(cols_sql)} FROM read_parquet('{path}')").fetchone() + if not row: + continue + for i, c in enumerate(vcols): + val = row[2 * i] + cnt = row[2 * i + 1] + if val is None: + continue + key = (str(val).lower(), f"{tname}.{c}") + if key in seen: + continue + seen.add(key) + found.append( + { + "literal": str(val), + "column": f"{tname}.{c}", + "value": val, + "match_count": int(cnt or 0), + } + ) + except Exception: # noqa: BLE001 - resolution is best-effort + continue + finally: + if conn is not None: + conn.close() + return found diff --git a/src/flyquery/core/services/query/value_anchoring.py b/src/flyquery/core/services/query/value_anchoring.py new file mode 100644 index 0000000..1b8c0b1 --- /dev/null +++ b/src/flyquery/core/services/query/value_anchoring.py @@ -0,0 +1,795 @@ +# Copyright 2024-2026 Firefly Software Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Value-anchoring helpers for the NL->SQL pipeline (dataset-agnostic). + +These pure functions are the backbone of the grounding/generation quality +improvements. 
They turn the per-column artifacts the ingest pipeline already +computes (``profile_json.top_values`` / ``min`` / ``max`` / +``distinct_estimate``, ``sample_values_json``, ``governance_json.semantic_type``) +into prompt text the SQL writer can copy literals from, resolve question +entities to the columns that store them, and detect degenerate / unsafe SQL. + +Nothing here hardcodes a column name, label, or domain assumption -- every +function operates on whatever the dataset's own profiling produced. +""" + +from __future__ import annotations + +import json +import re +from typing import Any + +import sqlglot +from sqlglot import exp + +# ---------------------------------------------------------------------- +# Semantic role (measure vs dimension vs time) +# ---------------------------------------------------------------------- + +_MEASURE_TYPES = { + "amount", + "count", + "decimal", + "ratio", + "percentage", + "percent", + "rate", + "score", + "currency", + "measure", + "quantity", + "number", + "money", + "metric", +} +_TIME_TYPES = {"year", "quarter", "month", "week", "day", "date", "datetime", "period", "fiscal_year"} +_DIMENSION_TYPES = { + "enum", + "tag", + "name", + "code", + "category", + "identifier", + "id", + "boolean", + "label", + "text", + "status", + "geo", + "country", + "region", +} + + +def semantic_role(semantic_type: str | None) -> str | None: + """Map an ingest-computed semantic_type to measure / time / dimension.""" + st = (semantic_type or "").strip().lower() + if not st: + return None + if st in _TIME_TYPES: + return "time" + if st in _MEASURE_TYPES: + return "measure" + if st in _DIMENSION_TYPES: + return "dimension" + return None + + +def _as_obj(v: Any) -> Any: + """JSONB columns usually arrive decoded; tolerate a str just in case.""" + if isinstance(v, str): + try: + return json.loads(v) + except Exception: # noqa: BLE001 + return None + return v + + +def coerce_profile( + profile_json: Any, + sample_values_json: Any, + *, + 
max_values: int = 25, +) -> dict: + """Return ``{distinct, values, min, max}`` from the ingest artifacts. + + ``values`` is the de-duplicated, capped list of distinct/top values for a + low-cardinality column (drawn from ``profile_json.top_values`` first, then + ``sample_values_json``). ``min``/``max`` come from the numeric profile. + """ + profile = _as_obj(profile_json) or {} + samples = _as_obj(sample_values_json) or [] + distinct = profile.get("distinct_estimate") if isinstance(profile, dict) else None + mn = profile.get("min") if isinstance(profile, dict) else None + mx = profile.get("max") if isinstance(profile, dict) else None + + raw: list = [] + tv = profile.get("top_values") if isinstance(profile, dict) else None + if isinstance(tv, list): + for item in tv: + if isinstance(item, dict) and "value" in item: + raw.append(item["value"]) + else: + raw.append(item) + if not raw and isinstance(samples, list): + raw = list(samples) + + seen: set[str] = set() + values: list[str] = [] + for v in raw: + s = "NULL" if v is None else str(v) + if s not in seen: + seen.add(s) + values.append(s) + if len(values) >= max_values: + break + # mixed_sign: the column stores BOTH negative and positive numbers (so some + # categories/line-items are already stored as negatives). Derived purely from + # the min/max the profiler already computed -- no domain assumption. + mixed = False + try: + if mn is not None and mx is not None and float(mn) < 0 < float(mx): + mixed = True + except (TypeError, ValueError): + pass + return {"distinct": distinct, "values": values, "min": mn, "max": mx, "mixed_sign": mixed} + + +def render_column_catalog_line( + qualified_name: str, + data_type: str | None, + semantic_type: str | None, + profile_json: Any, + sample_values_json: Any, + *, + max_values: int = 25, + char_budget: int = 280, +) -> str: + """Render one column's value catalogue line for a prompt. 
+ + Examples:: + + `t.Year` (VARCHAR, time, 4 distinct) values: 'FY23','FY24','FY25','FY26' + `t.P&L Line` (VARCHAR, dimension, 20 distinct) values: 'Total Revenue','Manpower',... + `t.FY` (DOUBLE, measure) range: -184.90 .. 361.49 + """ + col = qualified_name.rsplit(".", 1)[-1] + prof = coerce_profile(profile_json, sample_values_json, max_values=max_values) + role = semantic_role(semantic_type) + head_bits = [data_type or "?"] + if role: + head_bits.append(role) + if prof["distinct"] is not None: + head_bits.append(f"{prof['distinct']} distinct") + head = f"`{col}` ({', '.join(head_bits)})" + + if prof["values"]: + vtext = ", ".join(_q(v) for v in prof["values"]) + if len(vtext) > char_budget: + vtext = vtext[:char_budget].rsplit(",", 1)[0] + ", …" + return f"{head} values: {vtext}" + if prof["min"] is not None or prof["max"] is not None: + return f"{head} range: {_num(prof['min'])} .. {_num(prof['max'])}" + return head + + +def render_catalog_from_meta(qualified_name: str, meta: dict, *, char_budget: int = 320) -> str: + """Render a catalogue line from an already-coerced col_catalog item. + + ``meta`` keys: data_type, semantic_type, distinct, values, min, max. + """ + col = qualified_name.rsplit(".", 1)[-1] + bits = [meta.get("data_type") or "?"] + role = semantic_role(meta.get("semantic_type")) + if role: + bits.append(role) + if meta.get("distinct") is not None: + bits.append(f"{meta['distinct']} distinct") + # Echo the source's ORIGINAL header (pre-rename) when the ingest renamed the + # column (e.g. an Excel year header collapsed to year_1). Lets the model know + # what year_1 actually was, instead of relying on a fixed ordinal convention. 
+ oh = meta.get("original_header") + if oh: + bits.append(f"original header: {_q(str(oh))}") + head = f"`{col}` ({', '.join(bits)})" + vals = meta.get("values") or [] + if vals: + vt = ", ".join(_q(str(v)) for v in vals) + if len(vt) > char_budget: + vt = vt[:char_budget].rsplit(",", 1)[0] + ", …" + return f"{head} values: {vt}" + if meta.get("min") is not None or meta.get("max") is not None: + line = f"{head} range: {_num(meta.get('min'))} .. {_num(meta.get('max'))}" + if role == "measure" and meta.get("mixed_sign"): + line += ( + " — SIGNED (stores both negatives and positives; some line-items are already stored negative)" + ) + return line + return head + + +def _q(v: str) -> str: + if v == "NULL": + return "NULL" + return "'" + v.replace("'", "''") + "'" + + +def _num(v: Any) -> str: + try: + return f"{float(v):.2f}" + except (TypeError, ValueError): + return str(v) + + +# ---------------------------------------------------------------------- +# Entity extraction + value->column resolution (G3) +# ---------------------------------------------------------------------- + +_STOP = { + "the", + "and", + "for", + "with", + "que", + "los", + "las", + "del", + "una", + "uno", + "what", + "which", + "how", + "many", + "cual", + "cuales", + "cuantos", + "cuanto", + "team", + "equipo", + "average", + "promedio", + "total", + "year", + "ratio", +} + + +def extract_question_literals(question: str) -> list[str]: + """Pull candidate entity literals out of a NL question (heuristic, generic). + + Captures quoted phrases, ALL-CAPS / Capitalized multi-word runs (proper + nouns, person names, product/brand/business-unit names) and standalone + upper-case codes. Over-extraction is fine -- the resolver only keeps + literals that actually match a stored column value. + """ + out: list[str] = [] + # quoted phrases + for m in re.findall(r"['\"]([^'\"]{2,})['\"]", question): + out.append(m.strip()) + # runs of capitalized / all-caps words (>=1 word), e.g. "FRANCISCO JAVIER ..." 
+ for m in re.findall(r"\b([A-ZÀ-Ý][\wÀ-ý&/.-]*(?:\s+[A-ZÀ-Ý][\wÀ-ý&/.-]*)*)\b", question): + tok = m.strip() + if len(tok) >= 2 and tok.lower() not in _STOP: + out.append(tok) + # standalone all-caps tokens (OBU, DAPA, CVRM, RIV) + for m in re.findall(r"\b([A-Z][A-Z0-9&/]{1,})\b", question): + if m.lower() not in _STOP: + out.append(m) + # de-dup preserving order, drop pure substrings already covered + seen: set[str] = set() + uniq: list[str] = [] + for t in out: + key = t.lower() + if key not in seen: + seen.add(key) + uniq.append(t) + return uniq + + +def resolve_from_catalog( + literals: list[str], + col_catalog: list[dict], +) -> list[dict]: + """Map literals to columns using already-fetched low-cardinality values. + + ``col_catalog`` items: ``{qualified_name, values: [str]}``. Returns + ``[{literal, column, value}]`` for case-insensitive exact matches. + """ + found: list[dict] = [] + seen: set[tuple[str, str]] = set() + for lit in literals: + ll = lit.lower() + for c in col_catalog: + for v in c.get("values", []) or []: + vl = str(v).lower() + if vl == ll or (len(ll) >= 4 and ll in vl) or (len(vl) >= 4 and vl in ll): + key = (lit.lower(), c["qualified_name"]) + if key not in seen: + seen.add(key) + found.append({"literal": lit, "column": c["qualified_name"], "value": v}) + break + return found + + +# Words that signal a hierarchy / "reports-to" question, where the answer is the +# set of rows whose PARENT column equals the entity (the value that REPEATS), not +# the single identity row. Generic NL cues (EN + ES); the load-bearing signal is +# row-match cardinality (this only nudges the choice). Never reads the schema. 
+_HIERARCHY_INTENT = ( + "team of", + "team for", + "reports to", + "report to", + "direct reports", + "reporting to", + "under ", + "manages", + "managed by", + "supervises", + "supervised by", + "led by", + "in charge of", + "members of", + "people in", + "headcount of", + "size of the team", + "roster of", + "equipo de", + "a cargo de", + "a su cargo", + "reportan a", + "reporta a", + "bajo ", + "dirige", + "dirigido por", + "supervisa", + "supervisado por", + "depende de", + "dependen de", + "miembros de", + "integrantes de", + "dimensionamiento", + "personas en", +) + + +def relationship_intent(question: str) -> bool: + """True if the question is about a hierarchy / team / reporting relationship.""" + q = " " + (question or "").lower() + " " + return any(kw in q for kw in _HIERARCHY_INTENT) + + +def _is_prefix_token(literal: str, value: str) -> bool: + """True if ``literal`` is an anchored token-prefix of ``value``. + + 'Field' prefixes 'Field Sales' / 'Field-Access' (boundary follows), but NOT + 'Fieldwork'. Case-insensitive. This is the structural definition of a term + that fans out to several sub-values. + """ + ll = literal.strip().lower() + vl = str(value).strip().lower() + if not ll or not vl or vl == ll or not vl.startswith(ll): + return False + nxt = vl[len(ll) : len(ll) + 1] + return nxt == "" or not nxt.isalnum() + + +def resolve_value_groups(literals: list[str], col_catalog: list[dict]) -> list[dict]: + """Map a question term to the FULL SET of catalogued values it umbrellas. + + When a literal is an anchored token-prefix of >=2 distinct values of the SAME + column ('Field' -> {'Field Sales','Field Medical','Field Access',...}), return + a group so the model filters the whole set, not a subset. Gated to 'pure + umbrellas': skip a literal that is itself an exact stored value. ``truncated`` + flags that the column has more distinct values than were catalogued (so an + IN-list would be incomplete and an anchored LIKE is safer). 
Dataset-agnostic: + driven only by the column's own distinct values. + """ + groups: list[dict] = [] + for lit in literals: + ll = lit.strip().lower() + if len(ll) < 2: + continue + for c in col_catalog: + vals = [str(v) for v in (c.get("values") or [])] + if not vals: + continue + if any(str(v).strip().lower() == ll for v in vals): + continue # the term IS a value -> plain entity, not an umbrella + matches = [v for v in vals if _is_prefix_token(lit, v)] + if len({m.lower() for m in matches}) >= 2: + distinct = c.get("distinct") + groups.append( + { + "literal": lit, + "column": c["qualified_name"], + "values": matches, + "truncated": bool(distinct is not None and distinct > len(vals)), + } + ) + return groups + + +# ---------------------------------------------------------------------- +# SQL analysis: degenerate detection, predicate literals, CTEs, functions +# ---------------------------------------------------------------------- + + +def _parse(sql: str): + try: + tree = sqlglot.parse_one(sql, read="duckdb") + except Exception: # noqa: BLE001 + return None + return tree + + +def is_degenerate_sql(sql: str) -> bool: + """True if the SQL is a no-op: constant-false WHERE or constant-only aggregates. + + Catches placeholder queries like ``... WHERE 1=0`` and + ``SUM(CASE WHEN ... THEN 0 ELSE 0 END)`` that execute cleanly but answer + nothing. Structural + schema-agnostic. 
+ """ + tree = _parse(sql) + if tree is None: + return False + + # 1) constant-false WHERE + for where in tree.find_all(exp.Where): + cond = where.this + if cond is None: + continue + try: + from sqlglot.optimizer.simplify import simplify + + simp = simplify(cond.copy()) + if isinstance(simp, exp.Boolean) and simp.this is False: + return True + except Exception: # noqa: BLE001 + pass + if ( + isinstance(cond, exp.EQ) + and isinstance(cond.this, exp.Literal) + and isinstance(cond.expression, exp.Literal) + and cond.this.name != cond.expression.name + ): + return True + + # 2) every projected expression is a constant aggregate / literal + selects = list(tree.find_all(exp.Select)) + if selects: + top = selects[0] + exprs = [e for e in top.expressions] + if exprs and all(_is_constant_projection(e) for e in exprs): + return True + return False + + +def _is_constant_projection(e: exp.Expression) -> bool: + node = e.this if isinstance(e, exp.Alias) else e + if isinstance(node, exp.Literal): + return True + if isinstance(node, exp.AggFunc): + arg = node.this + if isinstance(arg, exp.Literal): + return True + if isinstance(arg, exp.Case) and _case_all_same_literal(arg): + return True + return False + + +def _case_all_same_literal(case_expr: exp.Case) -> bool: + lits: list[exp.Expression] = [] + for ifexpr in case_expr.args.get("ifs", []) or []: + lits.append(ifexpr.args.get("true")) + default = case_expr.args.get("default") + if default is not None: + lits.append(default) + if not lits: + return False + names: set[str] = set() + for lit in lits: + if not isinstance(lit, exp.Literal): + return False + names.add(lit.name) + return len(names) <= 1 + + +def equality_predicate_columns(sql: str) -> list[str]: + """Column names used in ``=`` / ``IN`` / ``LIKE`` predicates (for repair). + + These are the columns whose stored values the critic should re-check when a + query executes but returns 0 rows. 
+ """ + tree = _parse(sql) + if tree is None: + return [] + cols: list[str] = [] + seen: set[str] = set() + for pred in tree.find_all(exp.EQ, exp.In, exp.Like, exp.ILike): + target = pred.this + col = target.find(exp.Column) if target is not None else None + if isinstance(col, exp.Column) and col.name and col.name.lower() not in seen: + seen.add(col.name.lower()) + cols.append(col.name) + return cols + + +def predicate_literals(sql: str) -> dict: + """Return ``{column_name(lower): set(literal strings)}`` for =/IN predicates.""" + tree = _parse(sql) + out: dict[str, set] = {} + if tree is None: + return out + for pred in tree.find_all(exp.EQ, exp.In): + target = pred.this + col = target.find(exp.Column) if target is not None else None + if not isinstance(col, exp.Column) or not col.name: + continue + lits: set = set() + if isinstance(pred, exp.EQ) and isinstance(pred.expression, exp.Literal): + lits.add(pred.expression.name) + elif isinstance(pred, exp.In): + for e in pred.expressions or []: + if isinstance(e, exp.Literal): + lits.add(e.name) + if lits: + out.setdefault(col.name.lower(), set()).update(lits) + return out + + +def group_coverage_gaps(sql: str, col_catalog: list[dict], literals: list[str]) -> list[dict]: + """IN/= predicates that are a STRICT SUBSET of a detected value group (round-2 #7). + + Returns ``[{column, missing}]`` when the SQL listed SOME members of a value + group the question umbrellas but not all. Advisory only. 
+ """ + groups = resolve_value_groups(literals, col_catalog) + if not groups: + return [] + preds = predicate_literals(sql) + gaps: list[dict] = [] + for g in groups: + colname = g["column"].rsplit(".", 1)[-1].lower() + listed = preds.get(colname) + if not listed: + continue + listed_l = {x.lower() for x in listed} + group_vals = {str(v) for v in g["values"]} + group_l = {v.lower() for v in group_vals} + missing = [v for v in group_vals if v.lower() not in listed_l] + if missing and (listed_l & group_l): # listed some-but-not-all + gaps.append({"column": g["column"], "missing": missing}) + return gaps + + +def cte_and_derived_names(sql: str) -> set[str]: + """Names that are CTE aliases or derived-table/subquery aliases (not real tables).""" + tree = _parse(sql) + if tree is None: + return set() + names: set[str] = set() + for cte in tree.find_all(exp.CTE): + if cte.alias: + names.add(cte.alias) + for sub in tree.find_all(exp.Subquery): + if sub.alias: + names.add(sub.alias) + return {n for n in names if n} + + +def referenced_columns(sql: str) -> list[str]: + """Distinct column identifiers referenced anywhere in the SQL.""" + tree = _parse(sql) + if tree is None: + return [] + out: list[str] = [] + seen: set[str] = set() + for c in tree.find_all(exp.Column): + if c.name and c.name.lower() not in seen: + seen.add(c.name.lower()) + out.append(c.name) + return out + + +def select_aliases(sql: str) -> set[str]: + """Aliases defined in the query (SELECT-list + CTE column outputs).""" + tree = _parse(sql) + if tree is None: + return set() + out: set[str] = set() + for a in tree.find_all(exp.Alias): + if a.alias: + out.add(a.alias.lower()) + return out + + +def _flatten_signed(expr, sign: int, out: list) -> None: + if isinstance(expr, exp.Paren): + _flatten_signed(expr.this, sign, out) + elif isinstance(expr, exp.Sub): + _flatten_signed(expr.this, sign, out) + _flatten_signed(expr.expression, -sign, out) + elif isinstance(expr, exp.Add): + _flatten_signed(expr.this, sign, 
out) + _flatten_signed(expr.expression, sign, out) + else: + out.append((sign, expr)) + + +def _cond_dim_literals(cond): + """From a ``dim = lit`` / ``dim IN (lits)`` condition -> (dim_col, [literals]).""" + if isinstance(cond, exp.EQ): + c = cond.this.find(exp.Column) if cond.this is not None else None + if isinstance(c, exp.Column) and isinstance(cond.expression, exp.Literal): + return c.name, [cond.expression.name] + elif isinstance(cond, exp.In): + c = cond.this.find(exp.Column) if cond.this is not None else None + lits = [e.name for e in (cond.expressions or []) if isinstance(e, exp.Literal)] + if isinstance(c, exp.Column) and lits: + return c.name, lits + return None, None + + +def _literal_value(x): + """Numeric value of a literal possibly wrapped in a unary minus, else None.""" + if isinstance(x, exp.Neg) and isinstance(x.this, exp.Literal): + try: + return -float(x.this.name) + except (TypeError, ValueError): + return None + if isinstance(x, exp.Literal): + try: + return float(x.name) + except (TypeError, ValueError): + return None + return None + + +def _measure_and_sign(expr): + """For a CASE THEN expression, return (sign, measure_col) when it is a measure + column or its negation (``-m``, ``m * -1``, ``(-m)``); else (None, None). + + This is what lets the detector see a CASE branch that NEGATES the measure + (e.g. ``WHEN cost THEN -"FY (Real)"``) -- the same sign double-count expressed + without a SUM(...)-SUM(...) chain. 
+ """ + e = expr + sign = 1 + for _ in range(6): # bounded unwrap + if isinstance(e, exp.Paren): + e = e.this + elif isinstance(e, exp.Neg): + sign = -sign + e = e.this + elif isinstance(e, exp.Mul): + lv = _literal_value(e.expression) + other = e.this + if lv is None: + lv = _literal_value(e.this) + other = e.expression + if lv is None: + break + if lv < 0: + sign = -sign + e = other + else: + break + col = e if isinstance(e, exp.Column) else (e.find(exp.Column) if e is not None else None) + if isinstance(col, exp.Column) and col.name: + return sign, col.name + return None, None + + +def _agg_case_branches(node) -> list[tuple]: + """AGG(CASE WHEN dim (=|IN) lits THEN [±]measure ...) -> list of + (branch_sign, measure_col, dim_col, [lits]) across ALL branches.""" + if isinstance(node, exp.Alias): + node = node.this + if not isinstance(node, exp.AggFunc): + return [] + arg = node.this + if not isinstance(arg, exp.Case): + return [] + out: list[tuple] = [] + for ifx in arg.args.get("ifs") or []: + bsign, mcol = _measure_and_sign(ifx.args.get("true")) + if mcol is None: + continue + dim_col, lits = _cond_dim_literals(ifx.this) + if not dim_col or not lits: + continue + out.append((bsign, mcol, dim_col, lits)) + return out + + +def signed_subtraction_terms(sql: str) -> list[dict]: + """Terms of an additive formula applied over a measure (round-2 #5B). + + Returns ``[{sign, measure_col, dim_col, literals}]`` for measure terms combined + by +/-, covering BOTH shapes the generator uses: + * a ``SUM(...) - SUM(...) - ...`` chain of per-line-item aggregates, and + * a single ``SUM(CASE WHEN rev THEN m WHEN cost THEN -m ...)`` whose branches + negate the measure. + The caller probes whether a SUBTRACTED term's measure is already stored negative + (so subtracting / negating double-counts the sign). 
+ """ + tree = _parse(sql) + if tree is None: + return [] + selects = list(tree.find_all(exp.Select)) + if not selects: + return [] + out: list[dict] = [] + for e in selects[0].expressions: + node = e.this if isinstance(e, exp.Alias) else e + if isinstance(node, (exp.Sub, exp.Add, exp.Paren)): + leaves: list = [] + _flatten_signed(node, 1, leaves) + for chain_sign, leaf in leaves: + for bsign, mcol, dim, lits in _agg_case_branches(leaf): + out.append( + {"sign": chain_sign * bsign, "measure_col": mcol, "dim_col": dim, "literals": lits} + ) + elif isinstance(node, exp.AggFunc): + for bsign, mcol, dim, lits in _agg_case_branches(node): + out.append({"sign": bsign, "measure_col": mcol, "dim_col": dim, "literals": lits}) + return out + + +# Generated SQL must never reach out to the filesystem / external sources. +# The executor only exposes the dataset's parquet tables as views; a model +# that emits read_csv_auto/read_parquet/pg_read_file would bypass that. +_DANGEROUS_FUNCS = { + "read_csv", + "read_csv_auto", + "read_parquet", + "parquet_scan", + "read_json", + "read_json_auto", + "read_ndjson", + "read_ndjson_auto", + "read_text", + "read_blob", + "glob", + "sniff_csv", + "pg_read_file", + "pg_read_binary_file", + "pg_ls_dir", + "load", + "install", + "csv_scan", +} + + +def find_dangerous_functions(sql: str) -> set[str]: + """Return any filesystem/exfiltration table-function names present in the SQL.""" + tree = _parse(sql) + if tree is None: + return set() + hits: set[str] = set() + for fn in tree.find_all(exp.Func, exp.Anonymous): + name = (getattr(fn, "name", "") or "").lower() + if not name and hasattr(fn, "sql_name"): + try: + name = fn.sql_name().lower() + except Exception: # noqa: BLE001 + name = "" + if name in _DANGEROUS_FUNCS: + hits.add(name) + # also catch raw COPY commands + if tree.find(exp.Command) is not None and re.search(r"\bcopy\b", sql, re.IGNORECASE): + hits.add("copy") + return hits diff --git 
a/src/flyquery/core/services/retrieval/hybrid_retriever.py b/src/flyquery/core/services/retrieval/hybrid_retriever.py index 46e9770..666856e 100644 --- a/src/flyquery/core/services/retrieval/hybrid_retriever.py +++ b/src/flyquery/core/services/retrieval/hybrid_retriever.py @@ -129,9 +129,15 @@ async def retrieve( # bypasses the top-K truncation. full_inventory = await self._index.all_schema_objects(dataset_id, limit=500) + # Per-column value catalogue (distinct/top values, range, semantic_type) + # for the whole dataset -- feeds value-anchored generation + entity + # resolution + the repair loop. Computed at ingest, just read here. + column_catalog = await self._index.column_value_catalog(dataset_id) + return { "schema_objects": schema_objects, "schema_inventory": full_inventory, + "column_catalog": column_catalog, "examples": examples[:top_k_examples], "metrics": metrics[:top_k_metrics], "glossary": glossary, diff --git a/src/flyquery/core/services/retrieval/reranker.py b/src/flyquery/core/services/retrieval/reranker.py index 6c117f7..32e2988 100644 --- a/src/flyquery/core/services/retrieval/reranker.py +++ b/src/flyquery/core/services/retrieval/reranker.py @@ -22,15 +22,22 @@ from __future__ import annotations import logging +import re from typing import Any, Protocol from flyquery.core.services.retrieval.search_index import Hit logger = logging.getLogger(__name__) -# Guard so the "reranking disabled" warning is emitted at most once per process. +# Guard so the cross-encoder-unavailable warning is emitted at most once per process. 
_warned_noop_fallback = False +_TOKEN = re.compile(r"[\wÀ-ý]+", re.UNICODE) + + +def _tokens(text: str) -> set[str]: + return {t.lower() for t in _TOKEN.findall(text or "") if len(t) >= 2} + class Reranker(Protocol): """Protocol for a reranking step in the retrieval pipeline.""" @@ -44,16 +51,40 @@ class NoopReranker: """Identity reranker that preserves the original order.""" async def rerank(self, query: str, hits: list[Hit], top_n: int) -> list[Hit]: # noqa: ARG002 - """Return the first ``top_n`` hits unchanged. - - :param query: NL query string (unused) - :param hits: candidate hits from the retriever - :param top_n: how many to return - :return: first ``top_n`` elements of ``hits`` - """ + """Return the first ``top_n`` hits unchanged.""" return hits[:top_n] +class LexicalReranker: + """Dependency-free reranker: token-overlap between the query and each hit. + + A real cross-encoder is better, but when ``sentence-transformers`` is not + installed this is a strict improvement over the identity ``NoopReranker``: + it boosts hits whose text (now including the column's indexed VALUES) shares + tokens with the question, blending the lexical score with the retriever's + RRF score so a token-exact value/name match surfaces the owning column. + Fully dataset-agnostic. + """ + + async def rerank(self, query: str, hits: list[Hit], top_n: int) -> list[Hit]: + if not hits: + return [] + qtok = _tokens(query) + if not qtok: + return hits[:top_n] + + def score(h: Hit) -> float: + htok = _tokens(h.text) + if not htok: + return 0.0 + overlap = len(qtok & htok) + lex = overlap / (len(qtok) ** 0.5) + return lex + 0.25 * float(getattr(h, "score", 0.0) or 0.0) + + order = sorted(range(len(hits)), key=lambda i: score(hits[i]), reverse=True) + return [hits[i] for i in order[:top_n]] + + class CrossEncoderReranker: """Reranker backed by a sentence-transformers CrossEncoder model. 
@@ -83,32 +114,36 @@ async def rerank(self, query: str, hits: list[Hit], top_n: int) -> list[Hit]: return [hits[i] for i in order[:top_n]] -def build_reranker(settings: Any) -> NoopReranker | CrossEncoderReranker: +def build_reranker(settings: Any) -> NoopReranker | LexicalReranker | CrossEncoderReranker: """Factory: return a ``CrossEncoderReranker`` when possible. - Falls back to ``NoopReranker`` when: + Falls back to the dependency-free ``LexicalReranker`` (NOT a no-op) when: - ``settings.reranker_model`` is empty / falsy - ``sentence-transformers`` is not installed - The specified model cannot be loaded (network error, etc.) + The previous default silently degraded to an identity pass-through, leaving + wide-table column precision unimproved; the lexical fallback is a strict win + on any dataset and is what makes value-aware retrieval reach the prompt. + :param settings: ``FlyquerySettings`` instance :return: a ready-to-use reranker """ global _warned_noop_fallback model_name = getattr(settings, "reranker_model", "") or "" if not model_name: - return NoopReranker() + return LexicalReranker() try: return CrossEncoderReranker(model_name) except Exception as exc: # noqa: BLE001 if not _warned_noop_fallback: _warned_noop_fallback = True logger.warning( - "reranker model=%s unavailable (%s) -- falling back to NoopReranker. " - "Relevance reranking is DISABLED; results are truncated by retrieval " - "order only (install sentence-transformers / make the cross-encoder " - "model loadable to enable it).", + "reranker model=%s unavailable (%s) -- falling back to the " + "dependency-free LexicalReranker (token-overlap). 
Cross-encoder " + "reranking is disabled; install sentence-transformers / make the " + "model loadable to re-enable it.", model_name, exc, ) - return NoopReranker() + return LexicalReranker() diff --git a/src/flyquery/core/services/retrieval/search_index.py b/src/flyquery/core/services/retrieval/search_index.py index beb828b..c94090a 100644 --- a/src/flyquery/core/services/retrieval/search_index.py +++ b/src/flyquery/core/services/retrieval/search_index.py @@ -30,6 +30,8 @@ import sqlalchemy as sa from sqlalchemy.ext.asyncio import AsyncSession +from flyquery.core.services.query.value_anchoring import coerce_profile, semantic_role + @dataclass(frozen=True) class Hit: @@ -42,112 +44,101 @@ class Hit: metadata: dict = field(default_factory=dict) -def value_fingerprint( - data_type: str | None, - sample_values_json: object, - profile_json: object, - *, - max_values: int = 40, - max_chars: int = 400, -) -> str: - """Compact, human-readable summary of a column's ACTUAL values. - - Surfaced into the grounding/generation prompts so the agents copy - WHERE / CASE literals verbatim from real values instead of guessing - -- e.g. a fiscal year stored as ``FY23`` (not ``2023``), the members - of a tall/EAV category column (``P&L Line System`` rows like - ``Total Revenue`` / ``Manpower``), or the magnitude gap between a - scaled-duplicate measure (``FY`` ~0.02 vs ``FY (Real)`` ~25748). - - Returns ``""`` when there is nothing useful to show. +def _semantic_type(governance_json) -> str | None: + g = governance_json + if isinstance(g, dict): + return g.get("semantic_type") + return None + + +def _original_header(governance_json) -> str | None: + g = governance_json + if not isinstance(g, dict): + return None + v = g.get("original_header") + if not v: + return None + s = str(v).strip() + # don't surface a header that is just a number (mis-detected data cell), + # unless it is a plausible 4-digit year. 
+ try: + float(s.replace(",", "")) + if not (s.isdigit() and len(s) == 4 and 1900 <= int(s) <= 2100): + return None + except ValueError: + pass + return s or None + + +def _col_value_suffix(profile_json, sample_values_json, semantic_type, *, max_values: int = 25) -> str: + """Compact ``[role, N distinct] values: a, b, ...`` / ``range: min..max`` suffix. + + Appended to a column hit's text so (a) the lexical reranker can match a + question token to a stored value, and (b) the grounding agent sees the real + vocabulary. Dataset-agnostic. """ - prof = profile_json if isinstance(profile_json, dict) else {} - - # NOTE: profiling stores ``subtotal_values`` (name-based candidates for - # pre-aggregated rows). We deliberately do NOT surface them as an - # exclusion directive: a "Total_*" value in a dimension is just as often a - # legitimate additive bucket (e.g. unallocated/corporate) as a true rollup, - # and any hint makes the agent wrongly drop it. Whether to exclude requires - # the structural test (does the value's aggregate == the sum of the others?) - # which is not available per-column at profile time. The general "aggregate - # across ALL values / don't drop a value" prompt rule handles this safely. - subtotal_note = "" - - # Self-referencing hierarchy hint: this column's values are entities from - # another (higher-cardinality) column -- e.g. a manager column whose values - # are people from the employee column. Surfaced even for high-cardinality - # columns that have no listable values, because that is exactly when the - # agent cannot otherwise tell who a person reports to. - ref_col = prof.get("references_column") + prof = coerce_profile(profile_json, sample_values_json, max_values=max_values) + # Self-referencing hierarchy hint: when the profile stage detected that this + # column's values are entities drawn from another (higher-cardinality) column + # -- e.g. 
a manager column whose values are people from the employee column -- + # surface it so the agent filters THIS column to find a person's reports/team + # instead of the person's own row. Surfaced even for high-cardinality columns + # with no listable values, because that is exactly when the agent otherwise + # cannot tell who reports to whom. + ref_col = profile_json.get("references_column") if isinstance(profile_json, dict) else None ref_note = ( - f" | HIERARCHY: holds entities/people from column '{ref_col}' (e.g. each row's " - f"manager/owner/parent). To get a given person's group/team/reports, filter THIS " - f"column to that person (case-insensitive LIKE), not the person's own row." + f" | HIERARCHY: holds entities from column '{ref_col}' (each row's " + f"manager/owner/parent); to get a person's group/team/reports filter THIS " + f"column to that person (case-insensitive LIKE), not their own row." if ref_col else "" ) - - # Categorical: the stored distinct value set (low-cardinality columns). - top_values = prof.get("top_values") or [] - if top_values: - seen: set[str] = set() - uniq: list[str] = [] - for tv in top_values: - v = tv.get("value") if isinstance(tv, dict) else tv - if v is None: - continue - s = str(v) - if s not in seen: - seen.add(s) - uniq.append(s) - shown = uniq[:max_values] - body = " | ".join(shown) - if len(body) > max_chars: - body = body[:max_chars].rsplit("|", 1)[0].strip() + " | …" - more = "" if len(uniq) <= len(shown) else f" (+{len(uniq) - len(shown)} more)" - return f"values: {body}{more}{subtotal_note}{ref_note}" if body else ref_note.strip(" |") - - # Numeric / temporal: range + cardinality (exposes scaled duplicates). - col_min, col_max = prof.get("min"), prof.get("max") - if col_min is not None or col_max is not None: - rng = f"range: {col_min} .. 
{col_max}" - dist = prof.get("distinct_estimate") - if dist is not None: - rng += f" (~{dist} distinct)" - return rng + ref_note - - # Fallback: a few raw sample values (high-cardinality columns). - samples = sample_values_json if isinstance(sample_values_json, list) else [] - if samples: - seen2: set[str] = set() - uniq2: list[str] = [] - for v in samples: - s = str(v) - if s not in seen2: - seen2.add(s) - uniq2.append(s) - if uniq2: - return "e.g.: " + " | ".join(uniq2[:8]) + ref_note - return ref_note.strip(" |") - - -def _column_hit(r, score: float) -> Hit: - """Build a ranked schema-object Hit, enriched with a value fingerprint. - - Used by the BM25 + vector column searches so the ranked "Top-ranked - column matches" the grounding agent sees carry real values, not just - name + description. - """ - fp = value_fingerprint(r.data_type, r.sample_values_json, r.profile_json) - text = f"{r.qualified_name}: {r.data_type}\n{r.description or ''}" - if fp: - text += f"\n{fp}" + bits = [] + role = semantic_role(semantic_type) + if role: + bits.append(role) + if prof["distinct"] is not None: + bits.append(f"{prof['distinct']} distinct") + head = f" [{', '.join(bits)}]" if bits else "" + if prof["values"]: + v = ", ".join(prof["values"]) + if len(v) > 320: + v = v[:320].rsplit(",", 1)[0] + ", …" + return f"{head} values: {v}{ref_note}" + if prof.get("min") is not None or prof.get("max") is not None: + line = f"{head} range: {prof.get('min')}..{prof.get('max')}" + if role == "measure" and prof.get("mixed_sign"): + line += " SIGNED(neg+pos)" + return line + ref_note + return head + ref_note + + +def _schema_hit(r, score: float) -> Hit: + """Build a column/table schema_object Hit with value + type metadata.""" + kind = r.get("kind", None) + semantic_type = _semantic_type(r["governance_json"]) if "governance_json" in r else None + suffix = "" + if kind == "COLUMN": + suffix = _col_value_suffix( + r.get("profile_json", None), + r.get("sample_values_json", None), + 
semantic_type, + ) + text = f"{r['qualified_name']}: {r['data_type'] or ''}\n{r['description'] or ''}{suffix}" return Hit( source_kind="schema_object", - id=r.id, + id=r["id"], text=text, score=score, - metadata={"qualified_name": r.qualified_name, "table_id": str(r.table_id), "values": fp}, + metadata={ + "qualified_name": r["qualified_name"], + "table_id": str(r["table_id"]), + "kind": kind, + "data_type": r["data_type"], + "semantic_type": semantic_type, + "profile_json": r.get("profile_json", None), + "sample_values_json": r.get("sample_values_json", None), + }, ) @@ -168,13 +159,12 @@ async def bm25_schema_objects(self, query: str, dataset_id: uuid.UUID, limit: in rows = await self._session.execute( sa.text( """ - SELECT o.id, o.qualified_name, o.description, o.data_type, o.table_id, - o.sample_values_json, o.profile_json, + SELECT o.id, o.qualified_name, o.description, o.data_type, o.table_id, o.kind, + o.profile_json, o.sample_values_json, o.governance_json, ts_rank(o.content_tsv, plainto_tsquery('english', :q)) AS score FROM flyquery_schema_objects o JOIN flyquery_tables t ON t.id = o.table_id WHERE t.dataset_id = :ds AND o.is_active = true - AND o.snapshot_id = t.current_snapshot_id AND o.content_tsv @@ plainto_tsquery('english', :q) ORDER BY score DESC LIMIT :lim @@ -182,7 +172,7 @@ async def bm25_schema_objects(self, query: str, dataset_id: uuid.UUID, limit: in ), {"q": query, "ds": dataset_id, "lim": limit}, ) - return [_column_hit(r, float(r.score)) for r in rows.mappings()] + return [_schema_hit(r, float(r["score"])) for r in rows.mappings()] async def vector_schema_objects( self, @@ -200,20 +190,19 @@ async def vector_schema_objects( rows = await self._session.execute( sa.text( """ - SELECT o.id, o.qualified_name, o.description, o.data_type, o.table_id, - o.sample_values_json, o.profile_json, + SELECT o.id, o.qualified_name, o.description, o.data_type, o.table_id, o.kind, + o.profile_json, o.sample_values_json, o.governance_json, 1 - (o.embedding <=> 
CAST(:emb AS vector)) AS score FROM flyquery_schema_objects o JOIN flyquery_tables t ON t.id = o.table_id - WHERE t.dataset_id = :ds AND o.is_active = true - AND o.snapshot_id = t.current_snapshot_id AND o.embedding IS NOT NULL + WHERE t.dataset_id = :ds AND o.is_active = true AND o.embedding IS NOT NULL ORDER BY o.embedding <=> CAST(:emb AS vector) LIMIT :lim """ ), {"emb": str(query_embedding), "ds": dataset_id, "lim": limit}, ) - return [_column_hit(r, float(r.score)) for r in rows.mappings()] + return [_schema_hit(r, float(r["score"])) for r in rows.mappings()] async def all_schema_objects( self, @@ -253,9 +242,7 @@ async def all_schema_objects( ON c.table_id = o.table_id AND c.kind = 'COLUMN' AND c.is_active = true - AND c.snapshot_id = t.current_snapshot_id WHERE t.dataset_id = :ds AND o.is_active = true AND o.kind = 'TABLE' - AND o.snapshot_id = t.current_snapshot_id GROUP BY o.id, o.qualified_name, o.description, o.table_id, o.kind ORDER BY o.qualified_name """ @@ -273,11 +260,10 @@ async def all_schema_objects( sa.text( """ SELECT o.id, o.qualified_name, o.description, o.data_type, o.table_id, o.kind, - o.sample_values_json, o.profile_json + o.profile_json, o.sample_values_json, o.governance_json FROM flyquery_schema_objects o JOIN flyquery_tables t ON t.id = o.table_id WHERE t.dataset_id = :ds AND o.is_active = true AND o.kind = 'COLUMN' - AND o.snapshot_id = t.current_snapshot_id ORDER BY o.qualified_name LIMIT :lim """ @@ -322,25 +308,57 @@ async def all_schema_objects( ) for r in column_rows: - fp = value_fingerprint(r["data_type"], r["sample_values_json"], r["profile_json"]) - text = f"{r['qualified_name']}: {r['data_type'] or ''}\n{r['description'] or ''}" - if fp: - text += f"\n{fp}" - hits.append( - Hit( - source_kind="schema_object", - id=r["id"], - text=text, - score=1.0, - metadata={ - "qualified_name": r["qualified_name"], - "table_id": str(r["table_id"]), - "kind": "COLUMN", - "values": fp, - }, + hits.append(_schema_hit(r, 1.0)) + return hits 
+ + async def column_value_catalog(self, dataset_id: uuid.UUID, *, limit: int = 2000) -> list[dict]: + """Return every column's value catalogue for the dataset. + + Used (a) to resolve question literals to the columns that store them + and (b) to render a complete, value-anchored column list per in-scope + table. Each item: ``{qualified_name, table_id, data_type, semantic_type, + distinct, values, min, max}``. Dataset-agnostic — reads only the + ingest-computed ``profile_json`` / ``sample_values_json`` / + ``governance_json``. + """ + rows = ( + ( + await self._session.execute( + sa.text( + """ + SELECT o.qualified_name, o.data_type, o.table_id, + o.profile_json, o.sample_values_json, o.governance_json + FROM flyquery_schema_objects o + JOIN flyquery_tables t ON t.id = o.table_id + WHERE t.dataset_id = :ds AND o.is_active = true AND o.kind = 'COLUMN' + ORDER BY o.qualified_name + LIMIT :lim + """ + ), + {"ds": dataset_id, "lim": limit}, ) ) - return hits + .mappings() + .all() + ) + out: list[dict] = [] + for r in rows: + prof = coerce_profile(r["profile_json"], r["sample_values_json"], max_values=40) + out.append( + { + "qualified_name": r["qualified_name"], + "table_id": str(r["table_id"]), + "data_type": r["data_type"], + "semantic_type": _semantic_type(r["governance_json"]), + "distinct": prof["distinct"], + "values": prof["values"], + "min": prof["min"], + "max": prof["max"], + "mixed_sign": prof["mixed_sign"], + "original_header": _original_header(r["governance_json"]), + } + ) + return out async def approved_examples( self, @@ -378,7 +396,7 @@ async def approved_examples( ELSE 0.5 END AS score FROM flyquery_examples - WHERE workspace_id = :workspace_id AND quality = 'APPROVED' {ds_filter} + WHERE workspace_id = :workspace_id AND quality IN ('APPROVED', 'PROPOSED') {ds_filter} ORDER BY score DESC LIMIT :lim """ @@ -392,7 +410,7 @@ async def approved_examples( SELECT id, question, generated_sql, 0.5 AS score FROM flyquery_examples - WHERE workspace_id = 
:workspace_id AND quality = 'APPROVED' {ds_filter} + WHERE workspace_id = :workspace_id AND quality IN ('APPROVED', 'PROPOSED') {ds_filter} ORDER BY created_at DESC LIMIT :lim """ diff --git a/src/flyquery/resources/prompts/critic.yaml b/src/flyquery/resources/prompts/critic.yaml index 10458bd..576346d 100644 --- a/src/flyquery/resources/prompts/critic.yaml +++ b/src/flyquery/resources/prompts/critic.yaml @@ -10,8 +10,36 @@ system: | error message, the complete dataset catalogue (every real table + its column fingerprint), and the GroundedContext. - Produce a corrected SQL. Common errors to fix: + The "Problem" may be a DuckDB execution error, OR the previous SQL + executed but returned 0 ROWS, OR it was a degenerate/no-op query, OR an + ADVISORY warning about an otherwise-runnable query (a SIGN warning, or an + IN-list that under-covers a value group). When a "Column value catalogue" + section is present, it lists the REAL stored values of the columns the + previous SQL filtered on -- use it. + IMPORTANT: if the Problem is a purely advisory note AND you have verified the + previous SQL is already correct, you MAY return it UNCHANGED -- do not rewrite + a correct query gratuitously. BUT a "SIGN ERROR (must fix)" or an under-covered + value group is a REAL defect: you MUST apply the fix exactly as instructed + (e.g. replace the subtraction chain with the single signed-sum given), not keep + the original. + + Produce a corrected SQL. Common issues to fix: + + - **SIGN warning (signed measure)** -- the formula subtracts a term whose + measure is already stored NEGATIVE, double-counting the sign. SUM the + signed line-items instead of subtracting an already-negative term. + - **Under-covered value group** -- a question term umbrellas several values + but the IN-list named only some; include ALL of them or use an anchored LIKE. 
+ - **Wrong filter literal (0 rows)** -- the value in a WHERE/IN does not + match how the data stores it (case, spelling, 'FY23' vs 2023, label vs + code). Replace it with a value COPIED VERBATIM from the value catalogue. + - **Wrong column for an entity (0 rows)** -- the entity is stored under a + different column. Move the filter to the column whose catalogue actually + lists that value. + - **Degenerate / no-op** -- if the previous SQL was a placeholder + (WHERE 1=0, SUM(CASE..THEN 0 ELSE 0), constant SELECT), rewrite it to + really compute the answer; never return another no-op. - **Missing table** -- the SQL references a table that doesn't exist (e.g. ``balance_sheet``, ``income_statement``, ``financials.*``). Look at the catalogue in the prompt and diff --git a/src/flyquery/resources/prompts/explainer.yaml b/src/flyquery/resources/prompts/explainer.yaml index 124cd93..d6eca21 100644 --- a/src/flyquery/resources/prompts/explainer.yaml +++ b/src/flyquery/resources/prompts/explainer.yaml @@ -25,10 +25,10 @@ system: | ---------------------------------- When the result contains columns named ``year_1``, ``year_2``, ..., ``year_N`` (or ``yr1``..``yrN``) with no explicit calendar - year, follow the Excel left-to-right convention: ``year_1`` is the - EARLIEST period and ``year_N`` is the MOST RECENT period. A - question about "the most recent year" maps to the highest-numbered - ``year_N``; "a few years ago" maps to a lower-numbered one. Growth - is ``year_N - year_1`` (NOT the other way around). If you are - unsure, say so explicitly — never assert a direction you don't - have data for. + year, do not assume a fixed direction: different exports order + years differently. If the executed SQL or column metadata shows an + original header (the real year) or aligns to a sibling table's real + calendar-year columns, use that. Only when there is NO such evidence + fall back to the Excel left-to-right convention (``year_1`` earliest, + ``year_N`` most recent). 
If you are unsure of the direction, say so + explicitly — never assert a chronology you don't have evidence for. diff --git a/src/flyquery/resources/prompts/generation.yaml b/src/flyquery/resources/prompts/generation.yaml index e49b97a..877968f 100644 --- a/src/flyquery/resources/prompts/generation.yaml +++ b/src/flyquery/resources/prompts/generation.yaml @@ -19,6 +19,40 @@ system: | Trust order: SEMANTIC_LAYER > UPLOADED_TABLE. + Value anchoring (CRITICAL) + -------------------------- + The prompt lists, for each in-scope column, its data type, semantic role + (measure / dimension / time) and the ACTUAL stored values (or numeric + range). You MUST anchor to these: + - Every literal in a WHERE / IN / GROUP BY / JOIN must be COPIED VERBATIM + from that column's "values:" list (exact case + spelling). Never invent, + translate, pluralise, or re-case a filter literal. + - If a question entity (a brand, region, business unit, person, status, ...) + appears under a column DIFFERENT from the one you'd expect, filter the + column that actually contains it. The "Resolved entities" section tells you + where located literals live -- trust it. + - For a year/period/time column whose values are strings (e.g. 'FY23'), + filter with the STRING form shown -- never an integer like 2023. Never use + a numeric MEASURE column as the year/time axis. If a column shows an + "original header" (e.g. `year_1 (... original header: '2024')`), THAT is its + real meaning -- use it to decide which column is the most recent or a given year. + - A term shown under "Resolved value groups" is an UMBRELLA over several stored + values of one column -- filter the WHOLE set (IN of every listed value, or an + anchored LIKE when the list is incomplete), never a subset. + - A measure annotated SIGNED stores both positive and negative values (some + line-items are already stored negative). 
To apply a user additive formula + like Total = (+)A (-)B (-)C over a SIGNED measure, SUM the signed line-items + -- do NOT subtract a term whose stored values are already negative, which + double-counts the sign. Only negate a term confirmed to be stored positive. + - Aggregate measure columns; filter and group on dimension/time columns. + + Never produce a no-op + --------------------- + A placeholder query (WHERE 1=0, SUM(CASE WHEN ... THEN 0 ELSE 0 END), a + constant SELECT) is NEVER a valid candidate. If you cannot map a required + filter to a real stored value, widen or drop that filter and lower the + candidate's confidence -- do not emit a constant/empty query. + If the GroundedContext provides a starting_point_sql, your candidates should be deltas (added WHERE clause, swapped column, etc.) -- do not rewrite from scratch unless necessary. diff --git a/src/flyquery/resources/prompts/grounding.yaml b/src/flyquery/resources/prompts/grounding.yaml index e2000b0..9bdf778 100644 --- a/src/flyquery/resources/prompts/grounding.yaml +++ b/src/flyquery/resources/prompts/grounding.yaml @@ -64,14 +64,27 @@ system: | 4. **Never invent.** Never output a qualified_name that doesn't appear in the catalogue or top-ranked columns. If nothing matches, set ``confidence`` low and explain in ``missing_info``. - 5. **Ordinal-year chronology.** When a table has columns named - ``year_1``, ``year_2``, ..., ``year_N`` (or ``yr1``..``yrN``) - with no explicit calendar year, follow the Excel left-to-right - convention: ``year_1`` is the EARLIEST period and ``year_N`` - is the MOST RECENT period. A question about "the most recent - year" picks the highest-numbered column; "a few years ago" - picks a lower-numbered one. Do not swap this direction in - either the SELECT projection or any WHERE/ORDER BY clauses. + 5. 
**Ordinal-year chronology -- consult evidence FIRST.** When a table has + columns like ``year_1``..``year_N`` (or ``yr1``..``yrN``) with no explicit + calendar year, decide which is most-recent in this order: + (a) if a column shows an "original header" (e.g. + ``year_1 (... original header: '2024')``), that header IS its real period -- + use it and never contradict it; + (b) else if a SIBLING table in the same dataset carries real calendar-year + columns (e.g. ``year_2024``..``year_2015``), align the ordinal columns to + that ordering; + (c) ONLY if neither exists, fall back as a LAST RESORT to the Excel + left-to-right convention (``year_1`` = earliest, ``year_N`` = most recent) + and lower your confidence. Different exports order years differently, so + never assume the direction when evidence is available. + 6. **Use real values + roles.** Column entries may carry their stored + VALUES (e.g. ``values: 'FY23','FY24'``) and a role tag + (``measure`` / ``dimension`` / ``time``). When the question names an + entity, prefer the column whose listed values actually contain it + (it may not be the obviously-named column). Treat measures as the + things to aggregate and dimensions/time as the things to filter and + group by. Include in ``columns`` every column needed to filter the + question's entities, not just the projected measure. 
Drill-down ---------- diff --git a/tests/integration/test_reranker.py b/tests/integration/test_reranker.py index 1b29c9e..d4160c3 100644 --- a/tests/integration/test_reranker.py +++ b/tests/integration/test_reranker.py @@ -20,7 +20,7 @@ import pytest -from flyquery.core.services.retrieval.reranker import NoopReranker, build_reranker +from flyquery.core.services.retrieval.reranker import LexicalReranker, NoopReranker, build_reranker from flyquery.core.services.retrieval.search_index import Hit @@ -54,23 +54,25 @@ async def test_noop_reranker_handles_empty() -> None: @pytest.mark.asyncio -async def test_build_reranker_returns_noop_when_no_model() -> None: +async def test_build_reranker_returns_lexical_when_no_model() -> None: class FakeSettings: reranker_model = "" r = build_reranker(FakeSettings()) - assert isinstance(r, NoopReranker) + # No model configured -> the dependency-free LexicalReranker (a strict win + # over the identity NoopReranker, which left wide-table precision unimproved). + assert isinstance(r, LexicalReranker) @pytest.mark.asyncio -async def test_build_reranker_falls_back_on_bad_model() -> None: +async def test_build_reranker_falls_back_to_lexical_on_bad_model() -> None: class FakeSettings: reranker_model = "nonexistent/model-that-does-not-exist" r = build_reranker(FakeSettings()) - # Should fall back silently to Noop when sentence-transformers isn't installed - # or the model can't be loaded. - assert isinstance(r, NoopReranker) + # Falls back to LexicalReranker when sentence-transformers isn't installed or + # the cross-encoder model can't be loaded (token-overlap still beats no-op). 
+ assert isinstance(r, LexicalReranker) @pytest.mark.asyncio diff --git a/tests/unit/test_query_service.py b/tests/unit/test_query_service.py index 6cd6167..f036884 100644 --- a/tests/unit/test_query_service.py +++ b/tests/unit/test_query_service.py @@ -116,6 +116,28 @@ def __iter__(self): return iter(self._rows) +class _TableAwareSession: + """Session whose table-kind lookup reports one TABLE so a realistic + ``... FROM `` query clears the scope + unknown-table guards.""" + + def __init__(self, table_name: str): + self._table_name = table_name + + async def execute(self, stmt, params=None): + return _FakeMappingResult([{"name": self._table_name, "kind": "TABLE"}]) + + +class _TableAwareResolver: + def __init__(self, table_name: str = "t"): + self._session = _TableAwareSession(table_name) + + async def resolve(self, dataset_id, table_names, object_store_base=None, pins=None): + return {} + + async def current_snapshots(self, dataset_id, table_names): + return {} + + class _FakeQueryRepo: """In-memory query repo stub.""" @@ -161,6 +183,21 @@ class _FakeSettings: top_k_metrics = 8 max_refine_retries = 2 grounding_min_confidence = 0.55 + generation_candidates = 3 + # Value-anchoring knobs: disabled here so these orchestration tests exercise + # the core answer() flow (call order / persistence / auto-learn). The + # value-anchoring paths need a real dataset value-scan and are covered by the + # value_anchoring unit tests + integration/e2e. 
+ value_catalog_char_budget = 320 + value_catalog_max_columns = 80 + entity_resolution_enabled = False + entity_resolution_max_literals = 8 + zero_row_repair_enabled = False + candidate_exec_selection = False + synthesis_function_firewall = False + group_resolution_enabled = False + signed_measure_repair_enabled = False + group_coverage_repair_enabled = False def _make_grounded(confidence: float = 0.9, path: str = "SYNTHESIS") -> GroundedContext: @@ -219,6 +256,8 @@ def _make_service( auto_learner=None, semantic_repo=None, generation_agent=None, + critic_agent=None, + table_resolver=None, ): if grounded is None: grounded = _make_grounded() @@ -236,6 +275,10 @@ def _make_service( auto_learner = _FakeAutoLearner() if generation_agent is None: generation_agent = _FakeAgent(candidates) + if critic_agent is None: + critic_agent = _FakeAgent(RefinedSql(sql="SELECT 2", reasoning="fixed", confidence=0.8)) + if table_resolver is None: + table_resolver = _FakeTableResolver() explanation = ResultExplanation(summary="The answer is 1.", chart_hint="none") @@ -244,11 +287,11 @@ def _make_service( reranker=_FakeReranker(), grounding_agent=_FakeAgent(grounded), generation_agent=generation_agent, - critic_agent=_FakeAgent(RefinedSql(sql="SELECT 2", reasoning="fixed", confidence=0.8)), + critic_agent=critic_agent, explainer_agent=_FakeAgent(explanation), ast_classifier=AstClassifier(), scope_guard=ScopeGuard(), - table_resolver=_FakeTableResolver(), + table_resolver=table_resolver, executor=executor, query_repo=query_repo, settings=_FakeSettings(), @@ -340,8 +383,10 @@ async def test_answer_calls_auto_learner_on_first_shot(): @pytest.mark.asyncio async def test_answer_retries_on_execution_error(): - """On ExecutionError, the CriticAgent is called and the result is REFINED_OK or FAILED.""" - # First call fails; critic returns "SELECT 2"; second call also fails → FAILED + """On ExecutionError the CriticAgent is called, its refined (non-degenerate) + SQL is re-executed, and the retry 
is counted (REFINED_OK or FAILED).""" + # First execution fails; the critic returns a different, runnable SQL; the + # second execution succeeds -> one counted retry, REFINED_OK. call_count = [0] class _AlternatingExecutor: @@ -352,7 +397,15 @@ async def execute(self, sql, attached_tables): return ExecutionResult(rows=[{"v": 2}], columns=["v"], row_count=1, truncated=False) repo = _FakeQueryRepo() - svc = _make_service(executor=_AlternatingExecutor(), query_repo=repo) + svc = _make_service( + executor=_AlternatingExecutor(), + query_repo=repo, + candidates=_make_candidates("SELECT v FROM t"), + critic_agent=_FakeAgent( + RefinedSql(sql="SELECT v FROM t WHERE v > 0", reasoning="fixed", confidence=0.8) + ), + table_resolver=_TableAwareResolver(), + ) result = await svc.answer( tenant_id="ten-a",