Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions skill/SKILL.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,15 @@ single DML statement), then **rolls back** — nothing changes. Preview first,
show the human the affected count/sample, then re-run with `--commit` to apply.
`--commit` is the only flag that writes; you can't forget to preview.

The statement executes **exactly once** under both flags. A dry run is a real
trial run, so a statement that *can't* succeed (unique/constraint violation,
type error) **fails the preview** — exit `10` with `ROLLED BACK — nothing
changed`, not a clean `DRY RUN` report. That's a feature: the preview tells you
the write would have failed before you reach `--commit`. Likewise a `--commit`
whose statement aborts is reported honestly — exit `10`, `ROLLED BACK`, never a
false `COMMITTED` — including deferred-constraint violations that only surface at
commit time.

Writes that hide inside a query are caught, not just bare `INSERT`/`UPDATE`/
`DELETE`: writable CTEs (`WITH d AS (DELETE … RETURNING *) SELECT * FROM d`),
leading-CTE DML, `SELECT … INTO new_table`, and `EXPLAIN ANALYZE <write>` (which
Expand Down
150 changes: 116 additions & 34 deletions src/commands/sql_cmd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,18 +247,31 @@ async fn run_write(
.await
.context("begin transaction")?;

// From here on, ensure we always close the transaction even on error.
// From here on, ensure we always close the transaction even on error. Any
// error inside the transaction must roll back and propagate loudly — a
// swallowed abort that still reports success is the worst failure for a
// trust tool (see PGC-104).
let outcome = execute_in_tx(client, sql, parsed, opts).await;

let (results, rows_affected) = match outcome {
Ok(v) => v,
Err(e) => {
let _ = client.batch_execute("ROLLBACK").await;
return Err(e);
return Err(rolled_back(e));
}
};

if opts.commit {
// Force any DEFERRED constraint checks to run now, while we can still
// observe a violation as an error. Without this, a deferred violation
// surfaces only at COMMIT, where Postgres silently downgrades COMMIT to
// ROLLBACK and returns the `ROLLBACK` tag with no protocol error — the
// exact shape that let an aborted transaction report COMMITTED.
if let Err(e) = client.simple_query("SET CONSTRAINTS ALL IMMEDIATE").await {
let _ = client.batch_execute("ROLLBACK").await;
return Err(rolled_back(anyhow::Error::new(e)));
}

client
.batch_execute("COMMIT")
.await
Expand All @@ -275,62 +288,131 @@ async fn run_write(
Ok(())
}

/// Run the write inside the open transaction, gathering command tags and (when
/// safe) a sample of affected rows via a RETURNING wrapper.
/// Wrap a transaction error in the explicit, human-facing "ROLLED BACK —
/// nothing changed" framing. The chained source carries the original DB error;
/// `main` propagates this to exit code 10 for both human and JSON output.
fn rolled_back(e: anyhow::Error) -> anyhow::Error {
e.context("ROLLED BACK — nothing changed")
}

/// Run the write inside the open transaction **exactly once**, gathering the
/// affected count and (when safe) a sample of affected rows.
///
/// The single-DML-without-RETURNING case is the dangerous one: previously it
/// ran the bare statement for the count *and* a RETURNING-wrapped copy for the
/// sample, double-applying every effect in the same transaction (PGC-104). Now
/// the wrap is the *only* execution — it yields both count and sample in one
/// pass — so each user statement executes once.
///
/// Returns the collected results together with the total affected-row count.
/// Any error from executing the SQL is returned to the caller with the
/// "execute SQL" context attached.
async fn execute_in_tx(
client: &Client,
sql: &str,
parsed: &ParsedSql,
opts: &SqlOptions,
) -> Result<(Vec<SqlResult>, u64)> {
// NOTE(review): this span comes from a diff view without +/- markers. The
// statements from here down to the `sample_affected` branch look like the
// pre-change implementation (execute, then re-execute wrapped for a sample)
// shown next to its replacement below — confirm against the merged file.
let messages = client.simple_query(sql).await.context("execute SQL")?;
let mut results = collect_results(messages, usize::MAX);
let rows_affected = total_affected(&results);

// Sample rows only when it's safe and worthwhile: a single DML statement
// without its own RETURNING. We wrap it so the rows would-be-affected can
// be shown. The wrapper still rolls back with the outer transaction.
if parsed.single_dml_no_returning && rows_affected > 0 {
if let Some(sample) = sample_affected(client, sql, opts).await? {
results.push(sample);
}
// Single DML without its own RETURNING: one execution that returns both the
// true affected count and a bounded sample of affected rows.
if parsed.single_dml_no_returning {
return execute_dml_with_sample(client, sql, opts).await;
}

// Everything else (DML with RETURNING, DDL, multi-statement, writable CTEs):
// one plain execution. No wrapping, no second pass.
let messages = client.simple_query(sql).await.context("execute SQL")?;
// `usize::MAX` is passed as the second argument to `collect_results` —
// presumably a row cap, i.e. "keep everything"; verify against its definition.
let results = collect_results(messages, usize::MAX);
let rows_affected = total_affected(&results);
Ok((results, rows_affected))
}

// NOTE(review): this span comes from a diff view without +/- markers. The
// `sample_affected` signature, the `Result<Option<SqlResult>>` return type, the
// `__pgcrate_preview` SQL and the `match client.simple_query(..)` block look
// like the removed best-effort sampler shown next to its replacement,
// `execute_dml_with_sample` — confirm against the merged file.
/// Wrap a single DML statement to capture a few affected rows for preview.
/// Returns `None` (silently) if wrapping fails — sampling is best-effort.
async fn sample_affected(
/// Execute a single INSERT/UPDATE/DELETE (no RETURNING) exactly once, deriving
/// the affected count and a bounded sample from one CTE-wrapped query.
///
/// `WITH x AS (<stmt> RETURNING *) SELECT (SELECT count(*) FROM x) AS affected,
/// (SELECT json_agg(t) FROM (SELECT * FROM x LIMIT N) t) AS samples`
///
/// The statement runs once inside the writable CTE; `count(*)` over `x` gives
/// the exact affected count and `json_agg` over a bounded slice gives the
/// sample. Any DB error (e.g. a unique violation) propagates from this one
/// call, so an aborted transaction can never be reported as COMMITTED.
///
/// Returns a `CommandComplete` result carrying the affected count, followed by
/// a sample result when one could be built, together with the count itself.
async fn execute_dml_with_sample(
client: &Client,
sql: &str,
opts: &SqlOptions,
) -> Result<Option<SqlResult>> {
) -> Result<(Vec<SqlResult>, u64)> {
// Drop surrounding whitespace and trailing semicolons so the statement can be
// embedded inside the CTE's parentheses, directly before `RETURNING *`.
let stmt = sql.trim().trim_end_matches(';');
// The sample size never exceeds DRY_RUN_SAMPLE_LIMIT, whatever was requested.
let limit = match opts.limit {
Some(0) => DRY_RUN_SAMPLE_LIMIT, // uncapped reads still bound the preview
Some(n) => n.min(DRY_RUN_SAMPLE_LIMIT),
None => DRY_RUN_SAMPLE_LIMIT,
};
// `row_to_json` (text `json`, not `jsonb`) preserves the statement's column
// order; `jsonb` would re-sort keys and scramble the preview. `json_agg`
// collects the slice; the keys array carries the column order explicitly so
// Rust never has to infer it from a sorted map.
let wrapped = format!(
"WITH __pgcrate_preview AS ({stmt} RETURNING *) \
SELECT * FROM __pgcrate_preview LIMIT {limit}"
"WITH __pgcrate_w AS ({stmt} RETURNING *), \
__pgcrate_s AS (SELECT * FROM __pgcrate_w LIMIT {limit}), \
__pgcrate_first AS (SELECT * FROM __pgcrate_s LIMIT 1) \
SELECT (SELECT count(*) FROM __pgcrate_w) AS __pgcrate_affected, \
(SELECT json_agg(row_to_json(__s)) FROM __pgcrate_s __s) AS __pgcrate_samples, \
(SELECT array_agg(k) FROM ( \
SELECT json_object_keys(row_to_json(__f)) AS k \
FROM __pgcrate_first __f \
) __k) AS __pgcrate_columns"
);

match client.simple_query(&wrapped).await {
Ok(messages) => {
let collected = collect_results(messages, limit);
// The wrapped statement produces exactly one query result.
let found = collected.into_iter().find_map(|r| match r {
SqlResult::Query { columns, rows, .. } => Some((columns, rows)),
_ => None,
});
Ok(found.map(|(columns, rows)| SqlResult::Sample { columns, rows }))
}
// Wrapping isn't always valid (e.g. statement already errors, or a
// construct RETURNING can't express). Fall back to count-only.
Err(_) => Ok(None),
// Single execution of the wrapped statement; a failure here is returned to
// the caller with the "execute SQL" context rather than being swallowed.
let row = client
.query_one(&wrapped, &[])
.await
.context("execute SQL")?;

// Read the count as i64 and clamp any negative value to zero before the
// conversion to u64.
let affected: i64 = row.get("__pgcrate_affected");
let rows_affected = affected.max(0) as u64;

// Read as an Option, so a NULL column list is tolerated here.
let columns: Option<Vec<String>> = row.get("__pgcrate_columns");

// The count entry always comes first; the sample is appended only when
// `sample_from_json` can build one.
let mut results = vec![SqlResult::CommandComplete {
rows: rows_affected,
}];
if let Some(sample) = sample_from_json(row.get("__pgcrate_samples"), columns) {
results.push(sample);
}
Ok((results, rows_affected))
}

/// Turn the `json_agg(row_to_json(...))` sample column into a
/// `SqlResult::Sample`.
///
/// * `samples` — the aggregated JSON value. `None`, or any value that is not a
///   JSON array, yields `None`.
/// * `columns` — the column names in the order the caller wants them shown
///   (the SQL side supplies them via `array_agg` over `json_object_keys`). When
///   this is `None` or empty, the keys of the first sampled row are used
///   instead, in whatever order the JSON map iterates them.
///
/// Returns `None` when there is no sample array, or when no column list can be
/// determined (no usable `columns` and no first element that is a JSON object).
/// Array elements that are not JSON objects are skipped.
///
/// Each cell is rendered as text: a missing key or JSON `null` becomes `None`,
/// a JSON string is used verbatim (no surrounding quotes), and every other
/// value uses its JSON serialization.
fn sample_from_json(
    samples: Option<serde_json::Value>,
    columns: Option<Vec<String>>,
) -> Option<SqlResult> {
    // `samples` is owned, so take the array out of it directly instead of
    // borrowing it and deep-cloning every sampled row.
    let rows = match samples? {
        serde_json::Value::Array(rows) => rows,
        _ => return None,
    };

    // Prefer the SQL-provided order; fall back to the first row's keys only if
    // it's somehow absent (e.g. all-NULL column array).
    let columns: Vec<String> = match columns.filter(|names| !names.is_empty()) {
        Some(names) => names,
        None => rows.first()?.as_object()?.keys().cloned().collect(),
    };

    let table_rows: Vec<Vec<Option<String>>> = rows
        .iter()
        .filter_map(|row| row.as_object())
        .map(|obj| {
            columns
                .iter()
                .map(|col| match obj.get(col) {
                    None | Some(serde_json::Value::Null) => None,
                    Some(serde_json::Value::String(s)) => Some(s.clone()),
                    Some(other) => Some(other.to_string()),
                })
                .collect()
        })
        .collect();

    Some(SqlResult::Sample {
        columns,
        rows: table_rows,
    })
}

/// EXPLAIN the input and return the estimated total cost (relative to the
Expand Down
Loading
Loading