From 265f85052b18c6e1355c46d98869657d1d57a447 Mon Sep 17 00:00:00 2001 From: Apoorva Verma Date: Thu, 2 Jul 2026 03:19:13 +0530 Subject: [PATCH] Preserve CHECK constraints when transforming a table .transform() rebuilds from PRAGMA metadata, which doesn't expose CHECK constraints, so they were silently dropped. Pull them from the schema and re-add them to the rebuilt table. --- docs/python-api.rst | 2 + sqlite_utils/db.py | 162 +++++++++++++++++++++++++++++++++++++++- tests/test_transform.py | 142 ++++++++++++++++++++++++++++++++++- 3 files changed, 304 insertions(+), 2 deletions(-) diff --git a/docs/python-api.rst b/docs/python-api.rst index eab858ee7..8ced292dc 100644 --- a/docs/python-api.rst +++ b/docs/python-api.rst @@ -1505,6 +1505,8 @@ To keep the original table around instead of dropping it, pass the ``keep_table= table.transform(types={"age": int}, keep_table="original_table") +``CHECK`` constraints are preserved across a transform. If a column referenced by a constraint is renamed the constraint is updated to match, and if such a column is dropped the constraint is dropped along with it. + This method raises a ``sqlite_utils.db.TransformError`` exception if the table cannot be transformed, usually because there are existing constraints or indexes that are incompatible with modifications to the columns. .. _python_api_transform_alter_column_types: diff --git a/sqlite_utils/db.py b/sqlite_utils/db.py index ae99322b5..962b72e8f 100644 --- a/sqlite_utils/db.py +++ b/sqlite_utils/db.py @@ -176,6 +176,148 @@ class TransformError(Exception): pass +def _tokenize_sql(sql: str) -> List[Tuple[str, str]]: + # Split SQL into (kind, text) tokens. Enough to walk a CREATE TABLE + # statement while respecting string literals, quoted identifiers, comments + # and nesting - not a full parser. + tokens = [] + i, n = 0, len(sql) + while i < n: + c = sql[i] + if c in " \t\r\n": + j = i + 1 + while j < n and sql[j] in " \t\r\n": + j += 1 + tokens.append(("ws", sql[i:j])) + elif sql[i : i + 2] == "--": + j = sql.find("\n", i) + j = n if j == -1 else j + tokens.append(("comment", sql[i:j])) + elif sql[i : i + 2] == "/*": + j = sql.find("*/", i + 2) + j = n if j == -1 else j + 2 + tokens.append(("comment", sql[i:j])) + elif c in "'\"`": + j = i + 1 + while j < n: + if sql[j] == c: + if sql[j : j + 2] == c + c: + j += 2 + continue + j += 1 + break + j += 1 + tokens.append(("string" if c == "'" else "quoted", sql[i:j])) + elif c == "[": + j = sql.find("]", i + 1) + j = n if j == -1 else j + 1 + tokens.append(("quoted", sql[i:j])) + elif c.isalnum() or c in "_$": + j = i + 1 + while j < n and (sql[j].isalnum() or sql[j] in "_$"): + j += 1 + tokens.append(("word", sql[i:j])) + else: + j = i + 1 + tokens.append(("punct", c)) + i = j + return tokens + + +def _capture_paren_inner( + tokens: List[Tuple[str, str]], open_index: int +) -> Tuple[str, int]: + # tokens[open_index] is the opening "(" - return the text inside the + # matching parentheses and the index just past the closing ")". + depth = 0 + parts = [] + i, n = open_index, len(tokens) + while i < n: + kind, text = tokens[i] + if kind == "punct" and text == "(": + depth += 1 + if depth > 1: + parts.append(text) + elif kind == "punct" and text == ")": + depth -= 1 + if depth == 0: + return "".join(parts), i + 1 + parts.append(text) + else: + parts.append(text) + i += 1 + return "".join(parts), i + + +def _extract_check_constraints(create_table_sql: str) -> List[str]: + # CHECK constraints (column-level and table-level) live only in the stored + # CREATE TABLE SQL, not in any PRAGMA. Return the expression inside each one. + # Every CHECK keyword sits at the top level of the table body regardless of + # whether it is attached to a column or the table. + tokens = _tokenize_sql(create_table_sql) + checks = [] + depth = 0 + started = False + i, n = 0, len(tokens) + while i < n: + kind, text = tokens[i] + if kind == "punct" and text == "(": + depth += 1 + started = True + elif kind == "punct" and text == ")": + depth -= 1 + if started and depth == 0: + break + elif started and depth == 1 and kind == "word" and text.upper() == "CHECK": + j = i + 1 + while j < n and tokens[j][0] in ("ws", "comment"): + j += 1 + if j < n and tokens[j] == ("punct", "("): + inner, after = _capture_paren_inner(tokens, j) + checks.append(inner.strip()) + i = after + continue + i += 1 + return checks + + +def _unquote_identifier(text: str) -> str: + if len(text) >= 2: + if text[0] == '"' and text[-1] == '"': + return text[1:-1].replace('""', '"') + if text[0] == "`" and text[-1] == "`": + return text[1:-1].replace("``", "`") + if text[0] == "[" and text[-1] == "]": + return text[1:-1] + return text + + +def _rewrite_check_expression( + expression: str, rename: Dict[str, str], drop: Set[str] +) -> Optional[str]: + # Apply column renames to identifiers referenced by a CHECK expression. + # Returns None if the expression references a dropped column, in which case + # the constraint can no longer be enforced and should be discarded rather + # than producing a table that fails to build. + rename_lower = {k.lower(): v for k, v in rename.items()} + drop_lower = {d.lower() for d in drop} + out = [] + for kind, text in _tokenize_sql(expression): + key = None + if kind == "word": + key = text.lower() + elif kind == "quoted": + key = _unquote_identifier(text).lower() + if key is not None: + if key in drop_lower: + return None + if key in rename_lower: + out.append(quote_identifier(rename_lower[key])) + continue + out.append(text) + return "".join(out) + + ForeignKeyIndicator = Union[ str, ForeignKey, @@ -977,6 +1119,7 @@ def create_table_sql( extracts: Optional[Union[Dict[str, str], List[str]]] = None, if_not_exists: bool = False, strict: bool = False, + check_constraints: Optional[List[str]] = None, ) -> str: """ Returns the SQL ``CREATE TABLE`` statement for creating the specified table. @@ -993,6 +1136,7 @@ def create_table_sql( :param extracts: List or dictionary of columns to be extracted during inserts, see :ref:`python_api_extracts` :param if_not_exists: Use ``CREATE TABLE IF NOT EXISTS`` :param strict: Apply STRICT mode to table + :param check_constraints: List of ``CHECK`` constraint expressions to add as table-level constraints, for example ``["age >= 0"]`` """ if hash_id_columns and (hash_id is None): hash_id = "id" @@ -1094,15 +1238,21 @@ def sort_key(p): extra_pk = ",\n PRIMARY KEY ({pks})".format( pks=", ".join([quote_identifier(p) for p in pk]) ) + extra_checks = "" + if check_constraints: + extra_checks = "".join( + ",\n CHECK ({})".format(check) for check in check_constraints + ) columns_sql = ",\n".join(column_defs) sql = """CREATE TABLE {if_not_exists}{table} ( -{columns_sql}{extra_pk} +{columns_sql}{extra_pk}{extra_checks} ){strict}; """.format( if_not_exists="IF NOT EXISTS " if if_not_exists else "", table=quote_identifier(name), columns_sql=columns_sql, extra_pk=extra_pk, + extra_checks=extra_checks, strict=" STRICT" if strict and self.supports_strict else "", ) return sql @@ -2136,6 +2286,15 @@ def transform_sql( if column_order is not None: column_order = [rename.get(col) or col for col in column_order] + # CHECK constraints are not exposed by any PRAGMA, so pull them out of the + # stored schema and carry them across, applying any renames and dropping + # constraints that reference a removed column. + create_table_checks = [] + for check in _extract_check_constraints(self.schema): + rewritten = _rewrite_check_expression(check, rename, set(drop)) + if rewritten is not None: + create_table_checks.append(rewritten) + sqls = [] sqls.append( self.db.create_table_sql( @@ -2147,6 +2306,7 @@ def transform_sql( foreign_keys=create_table_foreign_keys, column_order=column_order, strict=self.strict, + check_constraints=create_table_checks or None, ).strip() ) diff --git a/tests/test_transform.py b/tests/test_transform.py index 5eb501db5..792f8c4e2 100644 --- a/tests/test_transform.py +++ b/tests/test_transform.py @@ -1,4 +1,9 @@ -from sqlite_utils.db import ForeignKey, TransformError +from sqlite_utils.db import ( + ForeignKey, + TransformError, + _extract_check_constraints, + _rewrite_check_expression, +) from sqlite_utils.utils import OperationalError import pytest @@ -659,3 +664,138 @@ def test_transform_with_unique_constraint_implicit_index(fresh_db): "You must manually drop this index prior to running this transformation and manually recreate the new index after running this transformation." in str(excinfo.value) ) + + +def test_transform_preserves_check_constraints(fresh_db): + fresh_db.execute( + "CREATE TABLE dogs (\n" + " id integer primary key,\n" + " age integer CHECK (age >= 0),\n" + " name text CHECK (length(name) > 0),\n" + " CHECK (age < 200)\n" + ")" + ) + dogs = fresh_db["dogs"] + assert dogs.transform_sql(tmp_suffix="suffix")[0] == ( + 'CREATE TABLE "dogs_new_suffix" (\n' + ' "id" INTEGER PRIMARY KEY,\n' + ' "age" INTEGER,\n' + ' "name" TEXT,\n' + " CHECK (age >= 0),\n" + " CHECK (length(name) > 0),\n" + " CHECK (age < 200)\n" + ");" + ) + dogs.transform() + # Constraints must still be enforced after the transform + with pytest.raises(Exception): + dogs.insert({"id": 1, "age": -1, "name": "Cleo"}) + with pytest.raises(Exception): + dogs.insert({"id": 2, "age": 5, "name": ""}) + dogs.insert({"id": 3, "age": 5, "name": "Cleo"}) + assert dogs.count == 1 + + +def test_transform_rewrites_check_constraint_on_rename(fresh_db): + fresh_db.execute( + "CREATE TABLE dogs (id integer primary key, age integer CHECK (age >= 0))" + ) + dogs = fresh_db["dogs"] + dogs.transform(rename={"age": "years"}) + assert 'CHECK ("years" >= 0)' in dogs.schema + # The renamed constraint is still enforced, on the new column name + with pytest.raises(Exception): + dogs.insert({"id": 1, "years": -1}) + dogs.insert({"id": 2, "years": 4}) + assert dogs.count == 1 + + +def test_transform_rewrites_check_constraint_to_name_needing_quotes(fresh_db): + # Renaming a checked column to a reserved word or a name with a space must + # still produce valid SQL, just like a plain rename does + fresh_db.execute( + "CREATE TABLE dogs (id integer primary key, age integer CHECK (age >= 0))" + ) + dogs = fresh_db["dogs"] + dogs.transform(rename={"age": "order by"}) + assert 'CHECK ("order by" >= 0)' in dogs.schema + with pytest.raises(Exception): + dogs.insert({"id": 1, "order by": -1}) + dogs.insert({"id": 2, "order by": 4}) + assert dogs.count == 1 + + +def test_transform_drops_check_constraint_for_dropped_column(fresh_db): + fresh_db.execute( + "CREATE TABLE dogs (id integer primary key, age integer CHECK (age >= 0), " + "name text)" + ) + dogs = fresh_db["dogs"] + # Dropping a column referenced by a CHECK drops the constraint rather than + # producing a table that fails to build + dogs.transform(drop=["age"]) + assert "CHECK" not in dogs.schema + assert dogs.columns_dict == {"id": int, "name": str} + + +def test_transform_check_constraints_are_idempotent(fresh_db): + fresh_db.execute( + "CREATE TABLE dogs (id integer primary key, age integer CHECK (age >= 0))" + ) + dogs = fresh_db["dogs"] + dogs.transform() + once = dogs.schema + dogs.transform() + assert dogs.schema == once + assert "CHECK (age >= 0)" in once + + +@pytest.mark.parametrize( + "sql,expected", + [ + ( + "CREATE TABLE t (id integer, age integer CHECK (age >= 0), " + "CHECK (age < 200))", + ["age >= 0", "age < 200"], + ), + # Nested parentheses inside the expression + ("CREATE TABLE t (a int CHECK (a > 0 AND (b < 10)))", ["a > 0 AND (b < 10)"]), + # A comma and the word CHECK inside a string literal must not confuse it + ( + "CREATE TABLE t (name text CHECK (name != 'a,b CHECK ('))", + ["name != 'a,b CHECK ('"], + ), + # CONSTRAINT name form, and no whitespace before the parenthesis + ("CREATE TABLE t (a int CONSTRAINT positive CHECK(a>0))", ["a>0"]), + # A DEFAULT expression in parentheses is not a CHECK + ("CREATE TABLE t (a int DEFAULT (1 + 2), b int CHECK (b > 0))", ["b > 0"]), + ("CREATE TABLE t (a int, b text)", []), + ], +) +def test_extract_check_constraints(sql, expected): + assert _extract_check_constraints(sql) == expected + + +@pytest.mark.parametrize( + "expression,rename,drop,expected", + [ + ("age >= 0", {"age": "years"}, set(), '"years" >= 0'), + # A substring match must not be rewritten + ( + "agent = 1 AND age > 0", + {"age": "years"}, + set(), + 'agent = 1 AND "years" > 0', + ), + ("length(name) > 0", {"name": "full_name"}, set(), 'length("full_name") > 0'), + # A rename to a name needing quoting stays valid + ("age >= 0", {"age": "order by"}, set(), '"order by" >= 0'), + # Identifiers inside string literals are left alone + ("x != 'age'", {"age": "years"}, set(), "x != 'age'"), + # Referencing a dropped column removes the constraint entirely + ("age >= 0", {}, {"age"}, None), + ("age >= 0", {}, {"weight"}, "age >= 0"), + ], +) +def test_rewrite_check_expression(expression, rename, drop, expected): + assert _rewrite_check_expression(expression, rename, drop) == expected