diff --git a/docs/changelog.rst b/docs/changelog.rst index d5c41aa72..1fa3ac8c6 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -9,6 +9,7 @@ Unreleased ---------- +- New ``db.atomic()`` context manager for transactions, with nested transaction support using SQLite savepoints. Internal multi-step operations such as ``table.transform()`` now use this mechanism to avoid unexpectedly committing an existing transaction. - New :ref:`database migrations system `, incorporating functionality that was previously provided by the separate `sqlite-migrate `__ plugin. Define migration sets using the new :class:`sqlite_utils.Migrations` class and apply them using the ``sqlite-utils migrate`` command or the :ref:`migrations Python API `. (:issue:`752`) .. _v3_39: diff --git a/docs/python-api.rst b/docs/python-api.rst index 267591acb..eab858ee7 100644 --- a/docs/python-api.rst +++ b/docs/python-api.rst @@ -239,6 +239,35 @@ The ``db.execute()`` and ``db.executescript()`` methods provide wrappers around Other cursor methods such as ``.fetchone()`` and ``.fetchall()`` are also available, see the `standard library documentation `__. +.. _python_api_atomic: + +Transactions with db.atomic() +----------------------------- + +Use ``db.atomic()`` to group multiple operations in a transaction: + +.. code-block:: python + + with db.atomic(): + db.table("dogs").insert({"id": 1, "name": "Cleo"}, pk="id") + db.table("dogs").insert({"id": 2, "name": "Pancakes"}) + +If an exception is raised, changes made inside the block will be rolled back. + +``db.atomic()`` can be nested. Nested blocks use SQLite savepoints, so an exception in an inner block can roll back to that savepoint without rolling back the entire outer transaction: + +.. code-block:: python + + with db.atomic(): + db.table("dogs").insert({"id": 1, "name": "Cleo"}, pk="id") + try: + with db.atomic(): + db.table("dogs").insert({"id": 2, "name": "Pancakes"}) + raise ValueError("skip this one") + except ValueError: + pass + db.table("dogs").insert({"id": 3, "name": "Marnie"}) + .. _python_api_parameters: Passing parameters @@ -955,7 +984,7 @@ You can delete all records in a table that match a specific WHERE statement usin >>> db = sqlite_utils.Database("dogs.db") >>> # Delete every dog with age less than 3 - >>> with db.conn: + >>> with db.atomic(): >>> db.table("dogs").delete_where("age < ?", [3]) Calling ``table.delete_where()`` with no other arguments will delete every row in the table. diff --git a/sqlite_utils/db.py b/sqlite_utils/db.py index ed3fc7af2..677a1f878 100644 --- a/sqlite_utils/db.py +++ b/sqlite_utils/db.py @@ -196,6 +196,19 @@ class Default: Tracer = Callable[[str, Optional[Union[Sequence[Any], Dict[str, Any]]]], None] +def _iter_complete_sql_statements(sql: str) -> Generator[str, None, None]: + statement = [] + for char in sql: + statement.append(char) + statement_sql = "".join(statement).strip() + if statement_sql and sqlite3.complete_statement(statement_sql): + yield statement_sql + statement = [] + statement_sql = "".join(statement).strip() + if statement_sql: + yield statement_sql + + COLUMN_TYPE_MAPPING: Dict[Any, str] = { float: "REAL", int: "INTEGER", @@ -406,6 +419,38 @@ def close(self) -> None: "Close the SQLite connection, and the underlying database file" self.conn.close() + @contextlib.contextmanager + def atomic(self) -> Generator["Database", None, None]: + """ + Context manager for wrapping multiple database operations in a transaction. + + Nested blocks use SQLite savepoints. + """ + if self.conn.in_transaction: + savepoint = "sqlite_utils_{}".format(secrets.token_hex(16)) + self.conn.execute("SAVEPOINT {};".format(savepoint)) + try: + yield self + except BaseException: + self.conn.execute("ROLLBACK TO SAVEPOINT {};".format(savepoint)) + self.conn.execute("RELEASE SAVEPOINT {};".format(savepoint)) + raise + else: + self.conn.execute("RELEASE SAVEPOINT {};".format(savepoint)) + else: + self.conn.execute("BEGIN") + try: + yield self + except BaseException: + self.conn.rollback() + raise + else: + try: + self.conn.commit() + except BaseException: + self.conn.rollback() + raise + @contextlib.contextmanager def ensure_autocommit_off(self) -> Generator[None, None, None]: """ @@ -581,6 +626,15 @@ def executescript(self, sql: str) -> sqlite3.Cursor: """ if self._tracer: self._tracer(sql, None) + return self._executescript(sql) + + def _executescript(self, sql: str) -> sqlite3.Cursor: + if self.conn.in_transaction: + cursor = self.conn.cursor() + # avoid sqlite3.executescript()'s implicit commit: + for statement in _iter_complete_sql_statements(sql): + cursor.execute(statement) + return cursor return self.conn.executescript(sql) def table(self, table_name: str, **kwargs: Any) -> "Table": @@ -729,7 +783,7 @@ def supports_strict(self) -> bool: if not hasattr(self, "_supports_strict"): try: table_name = "t{}".format(secrets.token_hex(16)) - with self.conn: + with self.atomic(): self.conn.execute( "create table {} (name text) strict".format(table_name) ) @@ -745,7 +799,7 @@ def supports_on_conflict(self) -> bool: if not hasattr(self, "_supports_on_conflict"): table_name = "t{}".format(secrets.token_hex(16)) try: - with self.conn: + with self.atomic(): self.conn.execute( "create table {} (id integer primary key, name text)".format( table_name @@ -797,7 +851,7 @@ def disable_wal(self) -> None: self.execute("PRAGMA journal_mode=delete;") def _ensure_counts_table(self) -> None: - with self.conn: + with self.atomic(): self.execute(_COUNTS_TABLE_CREATE_SQL.format(self._counts_table_name)) def enable_counts(self) -> None: @@ -833,7 +887,7 @@ def cached_counts(self, tables: Optional[Iterable[str]] = None) -> Dict[str, int def reset_counts(self) -> None: "Re-calculate cached counts for tables." tables = [table for table in self.tables if table.has_counts_triggers] - with self.conn: + with self.atomic(): self._ensure_counts_table() counts_table = self.table(self._counts_table_name) counts_table.delete_where() @@ -1288,7 +1342,8 @@ def add_foreign_keys( for table, fks in by_table.items(): self.table(table).transform(add_foreign_keys=fks) - self.vacuum() + if not self.conn.in_transaction: + self.vacuum() def index_foreign_keys(self) -> None: "Create indexes for every foreign key column on every table in the database." @@ -1820,7 +1875,7 @@ def create( self._defaults["strict"] = strict columns = {name: value for (name, value) in columns.items()} - with self.db.conn: + with self.db.atomic(): self.db.create_table( self.name, columns, @@ -1848,7 +1903,7 @@ def duplicate(self, new_name: str) -> "Table": """ if not self.exists(): raise NoTable(f"Table {self.name} does not exist") - with self.db.conn: + with self.db.atomic(): sql = "CREATE TABLE {} AS SELECT * FROM {};".format( quote_identifier(new_name), quote_identifier(self.name), @@ -1905,20 +1960,40 @@ def transform( column_order=column_order, keep_table=keep_table, ) - pragma_foreign_keys_was_on = self.db.execute("PRAGMA foreign_keys").fetchone()[ - 0 - ] + pragma_foreign_keys_was_on = bool( + self.db.execute("PRAGMA foreign_keys").fetchone()[0] + ) + already_in_transaction = self.db.conn.in_transaction + should_disable_foreign_keys = ( + pragma_foreign_keys_was_on and not already_in_transaction + ) + should_defer_foreign_keys = ( + pragma_foreign_keys_was_on and already_in_transaction + ) + defer_foreign_keys_was_on = False try: - if pragma_foreign_keys_was_on: + if should_disable_foreign_keys: self.db.execute("PRAGMA foreign_keys=0;") - with self.db.conn: + elif should_defer_foreign_keys: + defer_foreign_keys_was_on = bool( + self.db.execute("PRAGMA defer_foreign_keys").fetchone()[0] + ) + if not defer_foreign_keys_was_on: + self.db.execute("PRAGMA defer_foreign_keys=ON;") + with self.db.atomic(): for sql in sqls: self.db.execute(sql) # Run the foreign_key_check before we commit if pragma_foreign_keys_was_on: - self.db.execute("PRAGMA foreign_key_check;") + foreign_key_violations = self.db.execute( + "PRAGMA foreign_key_check;" + ).fetchall() + if foreign_key_violations: + raise sqlite3.IntegrityError("FOREIGN KEY constraint failed") finally: - if pragma_foreign_keys_was_on: + if should_defer_foreign_keys and not defer_foreign_keys_was_on: + self.db.execute("PRAGMA defer_foreign_keys=OFF;") + if should_disable_foreign_keys: self.db.execute("PRAGMA foreign_keys=1;") return self @@ -2158,86 +2233,89 @@ def extract( columns, list(self.columns_dict.keys()) ) ) - table = table or "_".join(columns) - lookup_table = self.db.table(table) - fk_column = fk_column or "{}_id".format(table) - magic_lookup_column = "{}_{}".format(fk_column, os.urandom(6).hex()) - - # Populate the lookup table with all of the extracted unique values - lookup_columns_definition = { - (rename.get(col) or col): typ - for col, typ in self.columns_dict.items() - if col in columns - } - if lookup_table.exists(): - if not set(lookup_columns_definition.items()).issubset( - lookup_table.columns_dict.items() - ): - raise InvalidColumns( - "Lookup table {} already exists but does not have columns {}".format( - table, lookup_columns_definition + with self.db.atomic(): + table = table or "_".join(columns) + lookup_table = self.db.table(table) + fk_column = fk_column or "{}_id".format(table) + magic_lookup_column = "{}_{}".format(fk_column, os.urandom(6).hex()) + + # Populate the lookup table with all of the extracted unique values + lookup_columns_definition = { + (rename.get(col) or col): typ + for col, typ in self.columns_dict.items() + if col in columns + } + if lookup_table.exists(): + if not set(lookup_columns_definition.items()).issubset( + lookup_table.columns_dict.items() + ): + raise InvalidColumns( + "Lookup table {} already exists but does not have columns {}".format( + table, lookup_columns_definition + ) ) - ) - else: - lookup_table.create( - { - **{ - "id": int, + else: + lookup_table.create( + { + **{ + "id": int, + }, + **lookup_columns_definition, }, - **lookup_columns_definition, - }, - pk="id", - ) - lookup_columns = [(rename.get(col) or col) for col in columns] - lookup_table.create_index(lookup_columns, unique=True, if_not_exists=True) - self.db.execute( - "INSERT OR IGNORE INTO {} ({lookup_columns}) SELECT DISTINCT {table_cols} FROM {}".format( - quote_identifier(table), - quote_identifier(self.name), - lookup_columns=", ".join(quote_identifier(c) for c in lookup_columns), - table_cols=", ".join(quote_identifier(c) for c in columns), + pk="id", + ) + lookup_columns = [(rename.get(col) or col) for col in columns] + lookup_table.create_index(lookup_columns, unique=True, if_not_exists=True) + self.db.execute( + "INSERT OR IGNORE INTO {} ({lookup_columns}) SELECT DISTINCT {table_cols} FROM {}".format( + quote_identifier(table), + quote_identifier(self.name), + lookup_columns=", ".join( + quote_identifier(c) for c in lookup_columns + ), + table_cols=", ".join(quote_identifier(c) for c in columns), + ) ) - ) - # Now add the new fk_column - self.add_column(magic_lookup_column, int) + # Now add the new fk_column + self.add_column(magic_lookup_column, int) - # And populate it - self.db.execute( - "UPDATE {} SET {} = (SELECT id FROM {} WHERE {where})".format( - quote_identifier(self.name), - quote_identifier(magic_lookup_column), - quote_identifier(table), - where=" AND ".join( - "{}.{} IS {}.{}".format( - quote_identifier(self.name), - quote_identifier(column), - quote_identifier(table), - quote_identifier(rename.get(column) or column), - ) - for column in columns - ), + # And populate it + self.db.execute( + "UPDATE {} SET {} = (SELECT id FROM {} WHERE {where})".format( + quote_identifier(self.name), + quote_identifier(magic_lookup_column), + quote_identifier(table), + where=" AND ".join( + "{}.{} IS {}.{}".format( + quote_identifier(self.name), + quote_identifier(column), + quote_identifier(table), + quote_identifier(rename.get(column) or column), + ) + for column in columns + ), + ) ) - ) - # Figure out the right column order - column_order = [] - for c in self.columns: - if c.name in columns and magic_lookup_column not in column_order: - column_order.append(magic_lookup_column) - elif c.name == magic_lookup_column: - continue - else: - column_order.append(c.name) + # Figure out the right column order + column_order = [] + for c in self.columns: + if c.name in columns and magic_lookup_column not in column_order: + column_order.append(magic_lookup_column) + elif c.name == magic_lookup_column: + continue + else: + column_order.append(c.name) - # Drop the unnecessary columns and rename lookup column - self.transform( - drop=set(columns), - rename={magic_lookup_column: fk_column}, - column_order=column_order, - ) + # Drop the unnecessary columns and rename lookup column + self.transform( + drop=set(columns), + rename={magic_lookup_column: fk_column}, + column_order=column_order, + ) - # And add the foreign key constraint - self.add_foreign_key(fk_column, table, "id") + # And add the foreign key constraint + self.add_foreign_key(fk_column, table, "id") return self def create_index( @@ -2521,8 +2599,8 @@ def enable_counts(self) -> None: ), ) ) - with self.db.conn: - self.db.conn.executescript(sql) + with self.db.atomic(): + self.db._executescript(sql) self.db.use_counts_table = True @property @@ -2663,7 +2741,7 @@ def disable_fts(self) -> "Table": trigger_names = [] for row in self.db.execute(sql).fetchall(): trigger_names.append(row[0]) - with self.db.conn: + with self.db.atomic(): for trigger_name in trigger_names: self.db.execute( "DROP TRIGGER IF EXISTS {}".format(quote_identifier(trigger_name)) @@ -2862,7 +2940,7 @@ def delete(self, pk_values: Union[list, tuple, str, int, float]) -> "Table": sql = "delete from {} where {wheres}".format( quote_identifier(self.name), wheres=" and ".join(wheres) ) - with self.db.conn: + with self.db.atomic(): self.db.execute(sql, pk_values) return self @@ -2935,7 +3013,7 @@ def update( sets=", ".join(sets), wheres=" and ".join(wheres), ) - with self.db.conn: + with self.db.atomic(): try: rowcount = self.db.execute(sql, args).rowcount except OperationalError as e: @@ -3024,7 +3102,7 @@ def convert_value(v): ), where=" where {}".format(where) if where is not None else "", ) - with self.db.conn: + with self.db.atomic(): self.db.execute(sql, where_args or []) if drop: self.transform(drop=columns) @@ -3072,7 +3150,7 @@ def _convert_multi( with progressbar( length=self.count, silent=not show_progress, label="2: Updating" ) as bar: - with self.db.conn: + with self.db.atomic(): for pk, updates in pk_to_values.items(): self.update(pk, updates) bar.update(1) @@ -3291,7 +3369,7 @@ def insert_chunk( list_mode, ) result = None - with self.db.conn: + with self.db.atomic(): for query, params in queries_and_params: try: result = self.db.execute(query, params) @@ -3528,7 +3606,8 @@ def insert_all( self.last_rowid = None self.last_pk = None if truncate and self.exists(): - self.db.execute("DELETE FROM {};".format(quote_identifier(self.name))) + with self.db.atomic(): + self.db.execute("DELETE FROM {};".format(quote_identifier(self.name))) result = None for chunk in chunks(itertools.chain([first_record], records_iter), batch_size): chunk = list(chunk) diff --git a/tests/test_atomic.py b/tests/test_atomic.py new file mode 100644 index 000000000..acd547457 --- /dev/null +++ b/tests/test_atomic.py @@ -0,0 +1,174 @@ +import pytest + +from sqlite_utils.db import _iter_complete_sql_statements +from sqlite_utils.utils import sqlite3 + + +@pytest.mark.parametrize( + "sql,expected", + ( + ( + "CREATE TABLE t(id); INSERT INTO t VALUES (1)", + ["CREATE TABLE t(id);", "INSERT INTO t VALUES (1)"], + ), + ( + "INSERT INTO t VALUES ('a;b');", + ["INSERT INTO t VALUES ('a;b');"], + ), + ( + "-- comment;\nCREATE TABLE t(id);", + ["-- comment;\nCREATE TABLE t(id);"], + ), + ( + """ + CREATE TRIGGER t_ai AFTER INSERT ON t + BEGIN + UPDATE t SET value = 'a;b' WHERE id = new.id; + INSERT INTO log VALUES ('x;y'); + END; + """, + [ + "CREATE TRIGGER t_ai AFTER INSERT ON t\n" + " BEGIN\n" + " UPDATE t SET value = 'a;b' WHERE id = new.id;\n" + " INSERT INTO log VALUES ('x;y');\n" + " END;" + ], + ), + ), +) +def test_iter_complete_sql_statements(sql, expected): + assert list(_iter_complete_sql_statements(sql)) == expected + + +def test_atomic_commits(fresh_db): + with fresh_db.atomic(): + fresh_db["dogs"].insert({"id": 1, "name": "Cleo"}, pk="id") + + assert list(fresh_db["dogs"].rows) == [{"id": 1, "name": "Cleo"}] + + +def test_atomic_rolls_back(fresh_db): + with pytest.raises(RuntimeError): + with fresh_db.atomic(): + fresh_db["dogs"].insert({"id": 1, "name": "Cleo"}, pk="id") + raise RuntimeError("boom") + + assert not fresh_db["dogs"].exists() + + +def test_nested_atomic_rolls_back_to_savepoint(fresh_db): + fresh_db["dogs"].create({"id": int, "name": str}, pk="id") + + with fresh_db.atomic(): + fresh_db["dogs"].insert({"id": 1, "name": "Cleo"}) + with pytest.raises(RuntimeError): + with fresh_db.atomic(): + fresh_db["dogs"].insert({"id": 2, "name": "Pancakes"}) + raise RuntimeError("boom") + fresh_db["dogs"].insert({"id": 3, "name": "Marnie"}) + + assert list(fresh_db["dogs"].rows) == [ + {"id": 1, "name": "Cleo"}, + {"id": 3, "name": "Marnie"}, + ] + + +def test_outer_atomic_rolls_back_released_savepoint(fresh_db): + with pytest.raises(RuntimeError): + with fresh_db.atomic(): + fresh_db["dogs"].insert({"id": 1, "name": "Cleo"}, pk="id") + with fresh_db.atomic(): + fresh_db["dogs"].insert({"id": 2, "name": "Pancakes"}) + raise RuntimeError("boom") + + assert not fresh_db["dogs"].exists() + + +def test_executescript_does_not_commit_open_atomic_block(fresh_db): + with pytest.raises(RuntimeError): + with fresh_db.atomic(): + fresh_db.executescript(""" + CREATE TABLE dogs(id INTEGER PRIMARY KEY, name TEXT); + CREATE TRIGGER dogs_ai AFTER INSERT ON dogs + BEGIN + UPDATE dogs SET name = upper(new.name) || '; updated' WHERE id = new.id; + END; + -- This comment has a semicolon; + INSERT INTO dogs VALUES (1, 'Cleo; the first'); + """) + raise RuntimeError("boom") + + assert not fresh_db["dogs"].exists() + + +def test_transform_does_not_commit_open_atomic_block(fresh_db): + fresh_db["dogs"].insert({"id": 1, "name": "Cleo", "age": "5"}, pk="id") + + with pytest.raises(RuntimeError): + with fresh_db.atomic(): + fresh_db["dogs"].insert({"id": 2, "name": "Pancakes", "age": "6"}) + fresh_db["dogs"].transform(rename={"age": "dog_age"}) + raise RuntimeError("boom") + + assert ( + fresh_db["dogs"].schema + == 'CREATE TABLE "dogs" (\n "id" INTEGER PRIMARY KEY,\n "name" TEXT,\n "age" TEXT\n)' + ) + assert list(fresh_db["dogs"].rows) == [ + {"id": 1, "name": "Cleo", "age": "5"}, + ] + + +def test_transform_parent_table_with_foreign_keys_in_atomic(fresh_db): + fresh_db.conn.execute("PRAGMA foreign_keys=ON") + fresh_db["authors"].insert({"id": 1, "name": "Tina"}, pk="id") + fresh_db["books"].insert( + {"id": 1, "title": "Book", "author_id": 1}, + pk="id", + foreign_keys={"author_id"}, + ) + + with fresh_db.atomic(): + fresh_db["authors"].transform(rename={"name": "full_name"}) + assert fresh_db.conn.execute("PRAGMA foreign_keys").fetchone()[0] + + assert ( + fresh_db["authors"].schema + == 'CREATE TABLE "authors" (\n "id" INTEGER PRIMARY KEY,\n "full_name" TEXT\n)' + ) + assert fresh_db.execute("PRAGMA foreign_key_check").fetchall() == [] + + +def test_transform_parent_table_with_foreign_keys_rolls_back(fresh_db): + fresh_db.conn.execute("PRAGMA foreign_keys=ON") + fresh_db["authors"].insert({"id": 1, "name": "Tina"}, pk="id") + fresh_db["books"].insert( + {"id": 1, "title": "Book", "author_id": 1}, + pk="id", + foreign_keys={"author_id"}, + ) + + with pytest.raises(RuntimeError): + with fresh_db.atomic(): + fresh_db["authors"].transform(rename={"name": "full_name"}) + raise RuntimeError("boom") + + assert ( + fresh_db["authors"].schema + == 'CREATE TABLE "authors" (\n "id" INTEGER PRIMARY KEY,\n "name" TEXT\n)' + ) + assert fresh_db.conn.execute("PRAGMA foreign_keys").fetchone()[0] + assert fresh_db.execute("PRAGMA foreign_key_check").fetchall() == [] + + +def test_transform_detects_foreign_key_check_violations(fresh_db): + fresh_db.conn.execute("PRAGMA foreign_keys=ON") + fresh_db["authors"].insert({"id": 1, "name": "Tina"}, pk="id") + fresh_db["books"].insert({"id": 1, "author_id": 2}, pk="id") + + with pytest.raises(sqlite3.IntegrityError): + fresh_db["books"].transform(add_foreign_keys=(("author_id", "authors", "id"),)) + + assert fresh_db["books"].foreign_keys == [] + assert fresh_db.conn.execute("PRAGMA foreign_keys").fetchone()[0]