From b64cddce830daf550bb55325a596b7874715c76e Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Thu, 9 Apr 2026 19:04:50 +0000 Subject: [PATCH 01/22] V1 --- .../catalog/TransactionalCatalogPlugin.java | 39 ++ .../catalog/transactions/Transaction.java | 78 ++++ .../catalog/transactions/TransactionInfo.java | 30 ++ .../sql/catalyst/analysis/Analyzer.scala | 29 +- .../analysis/RelationResolution.scala | 15 +- .../UnresolveTransactionRelations.scala | 56 +++ .../catalyst/analysis/V2TableReference.scala | 28 +- .../catalyst/plans/logical/statements.scala | 2 +- .../catalyst/plans/logical/v2Commands.scala | 23 +- .../transactions/TransactionUtils.scala | 55 +++ .../connector/catalog/CatalogManager.scala | 8 +- .../sql/connector/catalog/LookupCatalog.scala | 13 + .../TransactionAwareCatalogManager.scala | 57 +++ .../transactions/TransactionInfoImpl.scala | 20 + .../transactions/TransactionUtilsSuite.scala | 124 +++++ .../connector/catalog/InMemoryBaseTable.scala | 3 + ...nMemoryRowLevelOperationTableCatalog.scala | 15 +- .../sql/connector/catalog/InMemoryTable.scala | 15 +- .../catalog/InMemoryTableCatalog.scala | 4 + .../spark/sql/connector/catalog/txns.scala | 147 ++++++ .../apache/spark/sql/classic/Catalog.scala | 17 + .../spark/sql/execution/CacheManager.scala | 3 + .../spark/sql/execution/QueryExecution.scala | 132 +++++- .../datasources/v2/DataSourceV2Strategy.scala | 9 +- .../datasources/v2/DeleteFromTableExec.scala | 9 +- .../v2/WriteToDataSourceV2Exec.scala | 36 +- .../AppendDataTransactionSuite.scala | 228 ++++++++++ .../connector/DataSourceV2OptionSuite.scala | 16 +- .../connector/DeleteFromTableSuiteBase.scala | 188 +++++++- .../DeltaBasedDeleteFromTableSuite.scala | 2 + .../DeltaBasedUpdateTableSuiteBase.scala | 2 + .../connector/MergeIntoDataFrameSuite.scala | 71 ++- .../connector/MergeIntoTableSuiteBase.scala | 427 +++++++++++++++++- .../RowLevelOperationSuiteBase.scala | 45 +- .../sql/connector/UpdateTableSuiteBase.scala | 323 ++++++++++++- 
.../benchmark/AnalyzerBenchmark.scala | 118 +++++ 36 files changed, 2310 insertions(+), 77 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TransactionalCatalogPlugin.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/Transaction.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/TransactionInfo.java create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolveTransactionRelations.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtils.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/TransactionAwareCatalogManager.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/transactions/TransactionInfoImpl.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtilsSuite.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AnalyzerBenchmark.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TransactionalCatalogPlugin.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TransactionalCatalogPlugin.java new file mode 100644 index 0000000000000..34a4fc68e9649 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TransactionalCatalogPlugin.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog; + +import org.apache.spark.sql.connector.catalog.transactions.Transaction; +import org.apache.spark.sql.connector.catalog.transactions.TransactionInfo; + +/** + * A {@link CatalogPlugin} that supports transactions. + *

+ * Catalogs that implement this interface opt in to transactional query execution. A catalog + * implementing this interface is responsible for starting transactions. + * + * @since 4.2.0 + */ +public interface TransactionalCatalogPlugin extends CatalogPlugin { + + /** + * Begins a new transaction and returns a {@link Transaction} representing it. + * + * @param info metadata about the transaction being started. + */ + Transaction beginTransaction(TransactionInfo info); +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/Transaction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/Transaction.java new file mode 100644 index 0000000000000..80513aff31506 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/Transaction.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog.transactions; + +import org.apache.spark.sql.connector.catalog.CatalogPlugin; +import org.apache.spark.sql.connector.catalog.TransactionalCatalogPlugin; + +import java.io.Closeable; + +/** + * Represents a transaction. + *

+ * Spark begins a transaction with {@link TransactionalCatalogPlugin#beginTransaction} and + * executes read/write operations against the transaction's catalog. On success, Spark + * calls {@link #commit()}; on failure, Spark calls {@link #abort()}. In both cases Spark + * subsequently calls {@link #close()} to release resources. + * + * @since 4.2.0 + */ +public interface Transaction extends Closeable { + + /** + * Returns the catalog associated with this transaction. This catalog is responsible for tracking + * read/write operations that occur within the boundaries of a transaction. This allows + * connectors to perform conflict resolution at commit time. + */ + CatalogPlugin catalog(); + + /** + * Commits the transaction. All writes performed under it become visible to other readers. + *

+ * The connector is responsible for detecting and resolving conflicting commits or throwing + * an exception if resolution is not possible. + *

+ * This method will be called exactly once per transaction. Spark calls {@link #close()} + * immediately after this method returns. + * + * @throws IllegalStateException if the transaction has already been committed or aborted. + */ + void commit(); + + /** + * Aborts the transaction, discarding any staged changes. + *

+ * This method must be idempotent. If the transaction has already been committed or aborted, + * invoking it must have no effect. + *

+ * Spark calls {@link #close()} immediately after this method returns. + */ + void abort(); + + /** + * Releases any resources held by this transaction. + *

+ * Spark always calls this method after {@link #commit()} or {@link #abort()}, regardless of + * whether those methods succeed or not. + *

+ * This method must be idempotent. If the transaction has already been closed, + * invoking it must have no effect. + */ + @Override + void close(); +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/TransactionInfo.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/TransactionInfo.java new file mode 100644 index 0000000000000..a9c17d4b88274 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/TransactionInfo.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog.transactions; + +/** + * Metadata about a transaction. + * + * @since 4.2.0 + */ +public interface TransactionInfo { + /** + * Returns a unique identifier for this transaction. 
+ */ + String id(); +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 3b4d725840935..9c3bc2c29ec56 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -336,6 +336,31 @@ class Analyzer( } } + /** + * Returns a copy of this analyzer that uses the given [[CatalogManager]] for all catalog + * lookups. All other configuration (extended rules, checks, etc.) is preserved. Used by + * [[QueryExecution]] to create a per-query analyzer for transactional queries so that + * transaction-aware catalog resolution is an instance-level property rather than thread-local + * state. + */ + def withCatalogManager(newCatalogManager: CatalogManager): Analyzer = { + val self = this + new Analyzer(newCatalogManager, sharedRelationCache) { + override val hintResolutionRules: Seq[Rule[LogicalPlan]] = self.hintResolutionRules + override val extendedResolutionRules: Seq[Rule[LogicalPlan]] = self.extendedResolutionRules + override val postHocResolutionRules: Seq[Rule[LogicalPlan]] = self.postHocResolutionRules + override val extendedCheckRules: Seq[LogicalPlan => Unit] = self.extendedCheckRules + override val singlePassResolverExtensions: Seq[ResolverExtension] = + self.singlePassResolverExtensions + override val singlePassMetadataResolverExtensions: Seq[ResolverExtension] = + self.singlePassMetadataResolverExtensions + override val singlePassPostHocResolutionRules: Seq[Rule[LogicalPlan]] = + self.singlePassPostHocResolutionRules + override val singlePassExtendedResolutionChecks: Seq[LogicalPlan => Unit] = + self.singlePassExtendedResolutionChecks + } + } + override def execute(plan: LogicalPlan): LogicalPlan = { AnalysisContext.withNewAnalysisContext { executeSameContext(plan) @@ -437,7 +462,9 @@ class Analyzer( Batch("Simple 
Sanity Check", Once, LookupFunctions), Batch("Keep Legacy Outputs", Once, - KeepLegacyOutputs) + KeepLegacyOutputs), + Batch("Unresolve Relations", Once, + new UnresolveTransactionRelations(catalogManager)) ) override def batches: Seq[Batch] = earlyBatches ++ Seq( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala index e86248febd2eb..1f78ca3371f59 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala @@ -384,6 +384,8 @@ class RelationResolution( } } + // TODO: how to validate the output is compatible? + // TODO: what shall we do if the output mismatches (schema changes?) def resolveReference(ref: V2TableReference): LogicalPlan = { val relation = getOrLoadRelation(ref) val planId = ref.getTagValue(LogicalPlan.PLAN_ID_TAG) @@ -391,6 +393,11 @@ class RelationResolution( } private def getOrLoadRelation(ref: V2TableReference): LogicalPlan = { + // Skip cache when a transaction is active. 
+ if (catalogManager.transaction.isDefined) { + return loadRelation(ref) + } + val key = toCacheKey(ref.catalog, ref.identifier) relationCache.get(key) match { case Some(cached) => @@ -403,9 +410,13 @@ class RelationResolution( } private def loadRelation(ref: V2TableReference): LogicalPlan = { - val table = ref.catalog.loadTable(ref.identifier) + val resolvedCatalog = catalogManager.catalog(ref.catalog.name).asTableCatalog + val table = resolvedCatalog.loadTable(ref.identifier) + // val table = ref.catalog.loadTable(ref.identifier) V2TableReferenceUtils.validateLoadedTable(table, ref) - ref.toRelation(table) + // ref.toRelation(table) + DataSourceV2Relation( + table, ref.output, Some(resolvedCatalog), Some(ref.identifier), ref.options) } private def adaptCachedRelation(cached: LogicalPlan, ref: V2TableReference): LogicalPlan = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolveTransactionRelations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolveTransactionRelations.scala new file mode 100644 index 0000000000000..4b175dd44ef08 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolveTransactionRelations.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, TransactionalWrite} +import org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.allowInvokingTransformsInAnalyzer +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, LookupCatalog} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation + +class UnresolveTransactionRelations(val catalogManager: CatalogManager) + extends Rule[LogicalPlan] with LookupCatalog { + + override def apply(plan: LogicalPlan): LogicalPlan = + catalogManager.transaction match { + case Some(transaction) => + allowInvokingTransformsInAnalyzer { + plan.transform { + case tw: TransactionalWrite => + unresolveRelations(tw, transaction.catalog) + } + } + case _ => plan + } + + private def unresolveRelations( + plan: LogicalPlan, + catalog: CatalogPlugin): LogicalPlan = { + plan transform { + case r: DataSourceV2Relation if isLoadedFromCatalog(r, catalog) => + V2TableReference.createForRelation(r, Seq.empty) + } + } + + private def isLoadedFromCatalog( + relation: DataSourceV2Relation, + catalog: CatalogPlugin): Boolean = { + // relation.catalog.exists(_ eq catalog) && relation.identifier.isDefined + relation.catalog.exists(_.name == catalog.name) && relation.identifier.isDefined + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala index 85c36d452b309..a2379f33e14ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala @@ -22,6 +22,7 @@ import 
org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.analysis.V2TableReference.Context import org.apache.spark.sql.catalyst.analysis.V2TableReference.TableInfo import org.apache.spark.sql.catalyst.analysis.V2TableReference.TemporaryViewContext +import org.apache.spark.sql.catalyst.analysis.V2TableReference.TestContext import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.LeafNode import org.apache.spark.sql.catalyst.plans.logical.Statistics @@ -37,7 +38,7 @@ import org.apache.spark.sql.connector.catalog.V2TableUtil import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.util.CaseInsensitiveStringMap -import org.apache.spark.sql.util.SchemaValidationMode.ALLOW_NEW_TOP_LEVEL_FIELDS +import org.apache.spark.sql.util.SchemaValidationMode.{ALLOW_NEW_TOP_LEVEL_FIELDS, PROHIBIT_CHANGES} import org.apache.spark.util.ArrayImplicits._ /** @@ -84,11 +85,19 @@ private[sql] object V2TableReference { sealed trait Context case class TemporaryViewContext(viewName: Seq[String]) extends Context + // TODO(achatzis): Fix naming and complete implementation. 
+ case class TestContext(tableName: Seq[String]) extends Context def createForTempView(relation: DataSourceV2Relation, viewName: Seq[String]): V2TableReference = { create(relation, TemporaryViewContext(viewName)) } + def createForRelation( + relation: DataSourceV2Relation, + relationName: Seq[String]): V2TableReference = { + create(relation, TestContext(relationName)) + } + private def create(relation: DataSourceV2Relation, context: Context): V2TableReference = { val ref = V2TableReference( relation.catalog.get.asTableCatalog, @@ -110,11 +119,28 @@ private[sql] object V2TableReferenceUtils extends SQLConfHelper { ref.context match { case ctx: TemporaryViewContext => validateLoadedTableInTempView(table, ref, ctx) + case _: TestContext => + validateLoadedTableInTransaction(table, ref) case ctx => throw SparkException.internalError(s"Unknown table ref context: ${ctx.getClass.getName}") } } + private def validateLoadedTableInTransaction(table: Table, ref: V2TableReference): Unit = { + val dataErrors = V2TableUtil.validateCapturedColumns( + table, + ref.info.columns, + mode = PROHIBIT_CHANGES) + if (dataErrors.nonEmpty) { + throw QueryCompilationErrors.columnsChangedAfterAnalysis(ref.name, dataErrors) + } + + val metaErrors = V2TableUtil.validateCapturedMetadataColumns(table, ref.info.metadataColumns) + if (metaErrors.nonEmpty) { + throw QueryCompilationErrors.metadataColumnsChangedAfterAnalysis(ref.name, metaErrors) + } + } + private def validateLoadedTableInTempView( table: Table, ref: V2TableReference, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala index c38377582c156..fb54af2344d1b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala @@ -188,7 +188,7 @@ case class 
InsertIntoStatement( byName: Boolean = false, replaceCriteriaOpt: Option[InsertReplaceCriteria] = None, withSchemaEvolution: Boolean = false) - extends UnaryParsedStatement { + extends UnaryParsedStatement with TransactionalWrite { require(overwrite || !ifPartitionNotExists, "IF NOT EXISTS is only valid in INSERT OVERWRITE") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index 500d648d23acb..5406c5d6a35fc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -156,7 +156,7 @@ case class AppendData( isByName: Boolean, withSchemaEvolution: Boolean, write: Option[Write] = None, - analyzedQuery: Option[LogicalPlan] = None) extends V2WriteCommand { + analyzedQuery: Option[LogicalPlan] = None) extends V2WriteCommand with TransactionalWrite { override val writePrivileges: Set[TableWritePrivilege] = Set(TableWritePrivilege.INSERT) override def withNewQuery(newQuery: LogicalPlan): AppendData = copy(query = newQuery) override def withNewTable(newTable: NamedRelation): AppendData = copy(table = newTable) @@ -204,7 +204,7 @@ case class OverwriteByExpression( isByName: Boolean, withSchemaEvolution: Boolean, write: Option[Write] = None, - analyzedQuery: Option[LogicalPlan] = None) extends V2WriteCommand { + analyzedQuery: Option[LogicalPlan] = None) extends V2WriteCommand with TransactionalWrite { override val writePrivileges: Set[TableWritePrivilege] = Set(TableWritePrivilege.INSERT, TableWritePrivilege.DELETE) override lazy val resolved: Boolean = { @@ -264,7 +264,7 @@ case class OverwritePartitionsDynamic( writeOptions: Map[String, String], isByName: Boolean, withSchemaEvolution: Boolean, - write: Option[Write] = None) extends V2WriteCommand { + write: Option[Write] = None) extends 
V2WriteCommand with TransactionalWrite { override val writePrivileges: Set[TableWritePrivilege] = Set(TableWritePrivilege.INSERT, TableWritePrivilege.DELETE) override def withNewQuery(newQuery: LogicalPlan): OverwritePartitionsDynamic = { @@ -943,7 +943,8 @@ object DescribeColumn { */ case class DeleteFromTable( table: LogicalPlan, - condition: Expression) extends UnaryCommand with SupportsSubquery { + condition: Expression) + extends UnaryCommand with TransactionalWrite with SupportsSubquery { override def child: LogicalPlan = table override protected def withNewChildInternal(newChild: LogicalPlan): DeleteFromTable = copy(table = newChild) @@ -965,7 +966,8 @@ case class DeleteFromTableWithFilters( case class UpdateTable( table: LogicalPlan, assignments: Seq[Assignment], - condition: Option[Expression]) extends UnaryCommand with SupportsSubquery { + condition: Option[Expression]) + extends UnaryCommand with TransactionalWrite with SupportsSubquery { lazy val aligned: Boolean = AssignmentUtils.aligned(table.output, assignments) @@ -998,8 +1000,13 @@ case class MergeIntoTable( notMatchedActions: Seq[MergeAction], notMatchedBySourceActions: Seq[MergeAction], withSchemaEvolution: Boolean) - extends BinaryCommand with WriteWithSchemaEvolution with SupportsSubquery { + extends BinaryCommand + with WriteWithSchemaEvolution + with TransactionalWrite + with SupportsSubquery { + // Implements SupportsSchemaEvolution.table. + // Implements TransactionalWrite.table, identifying the MERGE target as the table being written. 
override val table: LogicalPlan = EliminateSubqueryAliases(targetTable) override def withNewTable(newTable: NamedRelation): MergeIntoTable = { @@ -1270,6 +1277,10 @@ case class Assignment(key: Expression, value: Expression) extends Expression newLeft: Expression, newRight: Expression): Assignment = copy(key = newLeft, value = newRight) } +trait TransactionalWrite extends LogicalPlan { + def table: LogicalPlan +} + /** * The logical plan of the DROP TABLE command. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtils.scala new file mode 100644 index 0000000000000..d160aafdea34e --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtils.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.transactions + +import java.util.UUID + +import org.apache.spark.sql.connector.catalog.TransactionalCatalogPlugin +import org.apache.spark.sql.connector.catalog.transactions.{Transaction, TransactionInfoImpl} +import org.apache.spark.util.Utils + +object TransactionUtils { + def commit(transaction: Transaction): Unit = { + Utils.tryWithSafeFinally { + transaction.commit() + } { + transaction.close() + } + } + + def abort(transaction: Transaction): Unit = { + Utils.tryWithSafeFinally { + transaction.abort() + } { + transaction.close() + } + } + + def beginTransaction(catalog: TransactionalCatalogPlugin): Transaction = { + val info = TransactionInfoImpl(id = UUID.randomUUID.toString) + val transaction = catalog.beginTransaction(info) + if (transaction.catalog.name != catalog.name) { + abort(transaction) + throw new IllegalStateException( + s"""Transaction catalog name (${transaction.catalog.name}) + |must match original catalog name (${catalog.name}). 
+ |""".stripMargin) + } + transaction + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala index 8663bb65b6a88..152334fdfc600 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala @@ -23,6 +23,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.catalog.{SessionCatalog, TempVariableManager} import org.apache.spark.sql.catalyst.util.StringUtils +import org.apache.spark.sql.connector.catalog.transactions.Transaction import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf @@ -39,7 +40,7 @@ import org.apache.spark.sql.internal.SQLConf // need to track current database at all. private[sql] class CatalogManager( - defaultSessionCatalog: CatalogPlugin, + val defaultSessionCatalog: CatalogPlugin, val v1SessionCatalog: SessionCatalog) extends SQLConfHelper with Logging { import CatalogManager.SESSION_CATALOG_NAME import CatalogV2Util._ @@ -57,6 +58,11 @@ class CatalogManager( } } + def transaction: Option[Transaction] = None + + def withTransaction(transaction: Transaction): CatalogManager = + new TransactionAwareCatalogManager(this, transaction) + def isCatalogRegistered(name: String): Boolean = { try { catalog(name) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala index 203cfc23452a8..fbb2938fd3da2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala @@ -19,6 +19,8 @@ package 
org.apache.spark.sql.connector.catalog import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedRelation} +import org.apache.spark.sql.catalyst.plans.logical.TransactionalWrite import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} @@ -163,4 +165,15 @@ private[sql] trait LookupCatalog extends Logging { } } } + + object TransactionalWrite { + def unapply(write: TransactionalWrite): Option[TransactionalCatalogPlugin] = { + EliminateSubqueryAliases(write.table) match { + case UnresolvedRelation(CatalogAndIdentifier(c: TransactionalCatalogPlugin, _), _, _) => + Some(c) + case _ => + None + } + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/TransactionAwareCatalogManager.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/TransactionAwareCatalogManager.scala new file mode 100644 index 0000000000000..9403219f596da --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/TransactionAwareCatalogManager.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog + +import org.apache.spark.sql.connector.catalog.transactions.Transaction + +/** + * A [[CatalogManager]] decorator that redirects catalog lookups to the transaction's catalog + * instance when names match, ensuring table loads during analysis are scoped to the transaction. + * All mutable state (current catalog, current namespace, loaded catalogs) is delegated to the + * wrapped [[CatalogManager]]. + */ +// TODO: Consider extracting a CatalogManager trait that both the real +// implementation and the decorator implement +private[sql] class TransactionAwareCatalogManager( + delegate: CatalogManager, + txn: Transaction) + extends CatalogManager(delegate.defaultSessionCatalog, delegate.v1SessionCatalog) { + + override def transaction: Option[Transaction] = Some(txn) + + override def catalog(name: String): CatalogPlugin = { + val resolved = delegate.catalog(name) + if (txn.catalog.name() == resolved.name()) txn.catalog else resolved + } + + override def currentCatalog: CatalogPlugin = { + val c = delegate.currentCatalog + if (txn.catalog.name() == c.name()) txn.catalog else c + } + + override def currentNamespace: Array[String] = delegate.currentNamespace + + override def setCurrentNamespace(namespace: Array[String]): Unit = + delegate.setCurrentNamespace(namespace) + + override def setCurrentCatalog(catalogName: String): Unit = + delegate.setCurrentCatalog(catalogName) + + override def listCatalogs(pattern: Option[String]): Seq[String] = + delegate.listCatalogs(pattern) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/transactions/TransactionInfoImpl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/transactions/TransactionInfoImpl.scala new file mode 100644 index 0000000000000..4cb53da0a59e2 --- /dev/null +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/transactions/TransactionInfoImpl.scala @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog.transactions + +case class TransactionInfoImpl(id: String) extends TransactionInfo diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtilsSuite.scala new file mode 100644 index 0000000000000..d409316e667b1 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtilsSuite.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.transactions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.connector.catalog.{CatalogPlugin, TransactionalCatalogPlugin} +import org.apache.spark.sql.connector.catalog.transactions.{Transaction, TransactionInfo} +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +class TransactionUtilsSuite extends SparkFunSuite { + val testCatalogName = "test_catalog" + + // --- Helpers --------------------------------------------------------------- + private def mockCatalog(catalogName: String): CatalogPlugin = new CatalogPlugin { + override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = () + override def name(): String = catalogName + } + + private val emptyFunction = () => () + private class TestTransaction( + catalogName: String, + onCommit: () => Unit = emptyFunction, + onAbort: () => Unit = emptyFunction, + onClose: () => Unit = emptyFunction) extends Transaction { + var committed = false + var aborted = false + var closed = false + + override def catalog(): CatalogPlugin = mockCatalog(catalogName) + override def commit(): Unit = { committed = true; onCommit() } + override def abort(): Unit = { aborted = true; onAbort() } + override def close(): Unit = { closed = true; onClose() } + } + + private def mockTransactionalCatalog( + catalogName: String, + txnCatalogName: String = null): TransactionalCatalogPlugin = { + val resolvedTxnCatalogName = Option(txnCatalogName).getOrElse(catalogName) + new TransactionalCatalogPlugin { + override def initialize(name: 
String, options: CaseInsensitiveStringMap): Unit = () + override def name(): String = catalogName + override def beginTransaction(info: TransactionInfo): Transaction = + new TestTransaction(resolvedTxnCatalogName) + } + } + + // --- Commit ---------------------------------------------------------------- + test("commit: calls commit then close") { + val txn = new TestTransaction(testCatalogName) + TransactionUtils.commit(txn) + assert(txn.committed) + assert(txn.closed) + } + + test("commit: close is called even if commit fails") { + val txn = new TestTransaction( + testCatalogName, onCommit = () => throw new RuntimeException("commit failed")) + intercept[RuntimeException] { TransactionUtils.commit(txn) } + assert(txn.closed) + } + + // --- Abort ----------------------------------------------------------------- + test("abort: calls abort then close") { + val txn = new TestTransaction(testCatalogName) + TransactionUtils.abort(txn) + assert(txn.aborted) + assert(txn.closed) + } + + test("abort: close is called even if abort fails") { + val txn = new TestTransaction(testCatalogName, + onAbort = () => throw new RuntimeException("abort failed")) + intercept[RuntimeException] { TransactionUtils.abort(txn) } + assert(txn.closed) + } + + // --- Begin Transaction ----------------------------------------------------- + test("beginTransaction: returns transaction when catalog names match") { + val catalog = mockTransactionalCatalog(testCatalogName) + val txn = TransactionUtils.beginTransaction(catalog) + assert(txn.catalog().name() == testCatalogName) + } + + test("beginTransaction: fails when transaction catalog name does not match") { + val catalog = mockTransactionalCatalog(catalogName = testCatalogName, txnCatalogName = "other") + val e = intercept[IllegalStateException] { + TransactionUtils.beginTransaction(catalog) + } + assert(e.getMessage.contains("other")) + assert(e.getMessage.contains(testCatalogName)) + } + + test("beginTransaction: aborts and closes transaction on 
catalog name mismatch") { + var aborted = false + var closed = false + val catalog = new TransactionalCatalogPlugin { + override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = () + override def name(): String = testCatalogName + override def beginTransaction(info: TransactionInfo): Transaction = + new TestTransaction( + "other", + onAbort = () => { aborted = true }, + onClose = () => { closed = true }) + } + intercept[IllegalStateException] { TransactionUtils.beginTransaction(catalog) } + assert(aborted) + assert(closed) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala index 8791677999810..422632d074a3b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala @@ -94,6 +94,8 @@ abstract class InMemoryBaseTable( validatedTableVersion = version } + protected def recordScanEvent(filters: Array[Filter]): Unit = {} + protected object PartitionKeyColumn extends MetadataColumn { override def name: String = "_partition" override def dataType: DataType = StringType @@ -455,6 +457,7 @@ abstract class InMemoryBaseTable( if (evaluableFilters.nonEmpty) { scan.filter(evaluableFilters) } + recordScanEvent(_pushedFilters) scan } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala index bbb9041bab37c..27231447a1273 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala @@ -18,10 +18,23 @@ package 
org.apache.spark.sql.connector.catalog import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException +import org.apache.spark.sql.connector.catalog.transactions.{Transaction, TransactionInfo} -class InMemoryRowLevelOperationTableCatalog extends InMemoryTableCatalog { +class InMemoryRowLevelOperationTableCatalog + extends InMemoryTableCatalog with TransactionalCatalogPlugin { import CatalogV2Implicits._ + var transaction: Txn = _ + // Tracks the last completed transaction for test assertions; cleared when a new one begins. + var lastTransaction: Txn = _ + + override def beginTransaction(info: TransactionInfo): Transaction = { + assert(transaction == null || transaction.currentState != Active) + this.transaction = new Txn(new TxnTableCatalog(this)) + this.lastTransaction = transaction + transaction + } + override def createTable(ident: Identifier, tableInfo: TableInfo): Table = { if (tables.containsKey(ident)) { throw new TableAlreadyExistsException(ident.asMultipartIdentifier) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala index d5738475031dc..2f3c65924d6a7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.connector.write.{LogicalWriteInfo, SupportsOverwrite import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{LongType, StructField, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.ArrayImplicits._ /** @@ -215,6 +216,16 @@ class InMemoryTable( object InMemoryTable { + // V1 filter values (from PredicateUtils.toV1) are Scala types (e.g. 
String), but partition + // keys stored in dataMap are Catalyst internal types (e.g. UTF8String). Normalize both sides + // before comparing so that string partitions work correctly. + private def valuesEqual(filterValue: Any, partitionValue: Any): Boolean = + (filterValue, partitionValue) match { + case (s: String, u: UTF8String) => u.toString == s + case (u: UTF8String, s: String) => u.toString == s + case _ => filterValue == partitionValue + } + def filtersToKeys( keys: Iterable[Seq[Any]], partitionNames: Seq[String], @@ -222,7 +233,7 @@ object InMemoryTable { keys.filter { partValues => filters.flatMap(splitAnd).forall { case EqualTo(attr, value) => - value == InMemoryBaseTable.extractValue(attr, partitionNames, partValues) + valuesEqual(value, InMemoryBaseTable.extractValue(attr, partitionNames, partValues)) case EqualNullSafe(attr, value) => val attrVal = InMemoryBaseTable.extractValue(attr, partitionNames, partValues) if (attrVal == null && value == null) { @@ -230,7 +241,7 @@ object InMemoryTable { } else if (attrVal == null || value == null) { false } else { - value == attrVal + valuesEqual(value, attrVal) } case IsNull(attr) => null == InMemoryBaseTable.extractValue(attr, partitionNames, partValues) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala index ff7995ad6697e..c7195b512b8d6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala @@ -59,6 +59,10 @@ class BasicInMemoryTableCatalog extends TableCatalog { tables.keySet.asScala.filter(_.namespace.sameElements(namespace)).toArray } + def loadTableAs[T <: Table](ident: Identifier): T = { + loadTable(ident).asInstanceOf[T] + } + // load table for scans override def loadTable(ident: Identifier): Table = { 
Option(tables.get(ident)) match { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala new file mode 100644 index 0000000000000..4feb89c78f56c --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connector.catalog + +import java.util +import java.util.concurrent.ConcurrentHashMap + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.sql.connector.catalog.transactions.Transaction +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +sealed trait TransactionState +case object Active extends TransactionState +case object Committed extends TransactionState +case object Aborted extends TransactionState + +class Txn(override val catalog: TxnTableCatalog) extends Transaction { + + private[this] var state: TransactionState = Active + private[this] var closed: Boolean = false + + def currentState: TransactionState = state + + def isClosed: Boolean = closed + + override def commit(): Unit = { + if (closed) throw new IllegalStateException("Can't commit, already closed") + catalog.commit() + this.state = Committed + } + + override def abort(): Unit = { + if (state == Committed || state == Aborted) return + // if (closed) throw new IllegalStateException("Can't abort, already closed") + this.state = Aborted + } + + override def close(): Unit = { + catalog.clearActiveTransaction() + this.closed = true + } +} + +// a special table used in row-level operation transactions +// it inherits data from the base table upon construction and +// propagates staged transaction state back after an explicit commit +class TxnTable(val delegate: InMemoryRowLevelOperationTable) + extends InMemoryRowLevelOperationTable( + delegate.name, + delegate.schema, + delegate.partitioning, + delegate.properties, + delegate.constraints) { + + // TODO(achatzis): Rethink how schema evolution works on top of transactions. + alterTableWithData(delegate.data, schema) + + // a tracker of filters used in each scan + // achatzis: Non-deterministic filters? 
+ val scanEvents = new ArrayBuffer[Array[Filter]]() + + override protected def recordScanEvent(filters: Array[Filter]): Unit = { + scanEvents += filters + } + + def commit(): Unit = { + delegate.dataMap.clear() + // TODO(achatzis): Rethink how schema evolution works on top of transactions. + delegate.alterTableWithData(data, delegate.schema) + delegate.replacedPartitions = replacedPartitions + delegate.lastWriteInfo = lastWriteInfo + delegate.lastWriteLog = lastWriteLog + delegate.commits ++= commits + delegate.increaseVersion() + } +} + +// a special table catalog used in row-level operation transactions +// table changes are initially staged in memory and propagated only after an explicit commit +class TxnTableCatalog(delegate: InMemoryRowLevelOperationTableCatalog) extends TableCatalog { + + private val tables: util.Map[Identifier, TxnTable] = new ConcurrentHashMap[Identifier, TxnTable]() + + override def name: String = delegate.name + + override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = {} + + override def listTables(namespace: Array[String]): Array[Identifier] = { + throw new UnsupportedOperationException() + } + + override def loadTable(ident: Identifier): Table = { + tables.computeIfAbsent(ident, _ => { + val table = delegate.loadTableAs[InMemoryRowLevelOperationTable](ident) + new TxnTable(table) + }) + } + + override def alterTable(ident: Identifier, changes: TableChange*): Table = { + val newDelegateTable = delegate.alterTable(ident, changes: _*) + // Compute again if absent. 
+ tables.remove(ident) + newDelegateTable + } + + override def dropTable(ident: Identifier): Boolean = { + throw new UnsupportedOperationException() + } + + override def renameTable(oldIdent: Identifier, newIdent: Identifier): Unit = { + throw new UnsupportedOperationException() + } + + def commit(): Unit = { + tables.values.forEach(table => table.commit()) + } + + def clearActiveTransaction(): Unit = { + delegate.transaction = null + } + + override def equals(obj: Any): Boolean = { + obj match { + case that: CatalogPlugin => this.name == that.name + case _ => false + } + } + + override def hashCode(): Int = name.hashCode() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/Catalog.scala index 35041feca9e18..81fc534423d8c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/Catalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/Catalog.scala @@ -931,6 +931,23 @@ class Catalog(sparkSession: SparkSession) extends catalog.Catalog with Logging { // caches referencing this relation. If this relation is cached as an InMemoryRelation, // this will clear the relation cache and caches of all its dependents. 
CommandUtils.recacheTableOrView(sparkSession, relation) + /* + EliminateSubqueryAliases(relation) match { + case r @ ExtractV2CatalogAndIdentifier(catalog, ident) if r.timeTravelSpec.isEmpty => + val nameParts = ident.toQualifiedNameParts(catalog) + sparkSession.sharedState.cacheManager.recacheTableOrView(sparkSession, nameParts) + case _ => + sparkSession.sharedState.cacheManager.recacheByPlan(sparkSession, relation) + */ + /* + relation match { + case r: DataSourceV2Relation if r.catalog.isDefined && r.identifier.isDefined => + val nameParts = r.identifier.get.toQualifiedNameParts(r.catalog.get) + sparkSession.sharedState.cacheManager.recacheTableOrView(sparkSession, nameParts) + case _ => + sparkSession.sharedState.cacheManager.recacheByPlan(sparkSession, relation) + } + */ } private def resolveRelation(tableName: String): LogicalPlan = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index 3f92f24156d3c..66f406d39f263 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -260,6 +260,9 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper { val nameInCache = v2Ident.toQualifiedNameParts(catalog) isSameName(name, nameInCache, resolver) && (includeTimeTravel || timeTravelSpec.isEmpty) + // case r: TableReference => + // isSameName(name, r.identifier.toQualifiedNameParts(r.catalog), resolver) + case v: View => isSameName(name, v.desc.identifier.nameParts, resolver) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index f08b561d6ef9a..32ffbfc32acb3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -27,23 +27,25 @@ import scala.util.control.NonFatal import org.apache.hadoop.fs.Path import org.apache.spark.SparkException -import org.apache.spark.internal.Logging import org.apache.spark.internal.LogKeys.EXTENDED_EXPLAIN_GENERATOR import org.apache.spark.rdd.RDD import org.apache.spark.sql.{AnalysisException, ExtendedExplainGenerator, Row} import org.apache.spark.sql.catalyst.{InternalRow, QueryPlanningTracker} -import org.apache.spark.sql.catalyst.analysis.{LazyExpression, NameParameterizedQuery, UnsupportedOperationChecker} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, LazyExpression, NameParameterizedQuery, UnsupportedOperationChecker} import org.apache.spark.sql.catalyst.expressions.codegen.ByteCodeStats import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.logical.{AppendData, Command, CommandResult, CompoundBody, CreateTableAsSelect, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect, ReturnAnswer, Union, WithCTE} +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, Command, CommandResult, CompoundBody, CreateTableAsSelect, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect, ReturnAnswer, Union, UnresolvedWith, WithCTE} import org.apache.spark.sql.catalyst.rules.{PlanChangeLogger, Rule} +import org.apache.spark.sql.catalyst.transactions.TransactionUtils import org.apache.spark.sql.catalyst.util.StringUtils.PlanStringConcat import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.classic.SparkSession +import org.apache.spark.sql.connector.catalog.LookupCatalog +import org.apache.spark.sql.connector.catalog.transactions.Transaction import org.apache.spark.sql.execution.SQLExecution.EXECUTION_ROOT_ID_KEY import org.apache.spark.sql.execution.adaptive.{AdaptiveExecutionContext, InsertAdaptiveSparkPlan} import 
org.apache.spark.sql.execution.bucketing.{CoalesceBucketsInJoin, DisableUnnecessaryBucketedScan} -import org.apache.spark.sql.execution.datasources.v2.V2TableRefreshUtil +import org.apache.spark.sql.execution.datasources.v2.{TransactionalExec, V2TableRefreshUtil} import org.apache.spark.sql.execution.dynamicpruning.PlanDynamicPruningFilters import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.execution.reuse.ReuseExchangeAndSubquery @@ -69,7 +71,8 @@ class QueryExecution( val mode: CommandExecutionMode.Value = CommandExecutionMode.ALL, val shuffleCleanupModeOpt: Option[ShuffleCleanupMode] = None, val refreshPhaseEnabled: Boolean = true, - val queryId: UUID = UUIDv7Generator.generate()) extends Logging { + val queryId: UUID = UUIDv7Generator.generate(), + val analyzerOpt: Option[Analyzer] = None) extends LookupCatalog { val id: Long = QueryExecution.nextExecutionId @@ -79,6 +82,8 @@ class QueryExecution( // TODO: Move the planner an optimizer into here from SessionState. protected def planner = sparkSession.sessionState.planner + protected val catalogManager = sparkSession.sessionState.catalogManager + /** * Check whether the query represented by this QueryExecution is a SQL script. * @return True if the query is a SQL script, False otherwise. @@ -90,6 +95,46 @@ class QueryExecution( logical.exists(_.expressions.exists(_.exists(_.isInstanceOf[LazyExpression]))) } + + // 1. At the pre-Analyzed plan we look for nodes that implement the TransactionalWrite trait. + // When a plan contains such a node we initiate a transaction. Note, we should never start + // a transaction for operations that are not executed, e.g. EXPLAIN. + // 2. Create an analyzer clone with a transaction aware Catalog Manager. The latter is the single + // choke point of all catalog access, and it is also the transaction context carrier. + // This is then passed to all rules during analysis that need to check the catalog. 
Rules + // that are specifically interested in transactionality can access the transaction directly + // from the Catalog Manager. The transaction catalog, is potentially the place where connectors + // should keep state about the reads (tables+predicates) that occurred during the transaction. + // 3. The analyzer instance is passed to nested Query Execution instances. These need to respect + // the open transaction instead of creating their own. + private lazy val transactionOpt: Option[Transaction] = + // Always inherit an active transaction from the outer analyzer, regardless of mode. + analyzerOpt.flatMap(_.catalogManager.transaction).orElse { + // Only begin a new transaction for outer QEs that lead to execution. + if (mode != CommandExecutionMode.SKIP) { + val catalog = logical match { + case UnresolvedWith(TransactionalWrite(c), _, _) => Some(c) + case TransactionalWrite(c) => Some(c) + case _ => None + } + catalog.map(TransactionUtils.beginTransaction) + } else { + None + } + } + + // Per-query analyzer: uses a transaction-aware CatalogManager when a transaction is active, + // so that all catalog lookups and rule applications during analysis see the correct state + // without relying on thread-local context. + private lazy val analyzer: Analyzer = analyzerOpt.getOrElse { + transactionOpt match { + case Some(txn) => + sparkSession.sessionState.analyzer.withCatalogManager(catalogManager.withTransaction(txn)) + case None => + sparkSession.sessionState.analyzer + } + } + def assertAnalyzed(): Unit = { try { analyzed @@ -102,7 +147,7 @@ class QueryExecution( } } - def assertSupported(): Unit = { + def assertSupported(): Unit = executeWithTransactionContext { if (sparkSession.sessionState.conf.isUnsupportedOperationCheckEnabled) { UnsupportedOperationChecker.checkForBatch(analyzed) } @@ -141,7 +186,7 @@ class QueryExecution( try { val plan = executePhase(QueryPlanningTracker.ANALYSIS) { // We can't clone `logical` here, which will reset the `_analyzed` flag. 
- sparkSession.sessionState.analyzer.executeAndCheck(sqlScriptExecuted, tracker) + analyzer.executeAndCheck(sqlScriptExecuted, tracker) } tracker.setAnalyzed(plan) plan @@ -152,7 +197,9 @@ class QueryExecution( } } - def analyzed: LogicalPlan = lazyAnalyzed.get + def analyzed: LogicalPlan = executeWithTransactionContext { + lazyAnalyzed.get + } private val lazyCommandExecuted = LazyTry { mode match { @@ -162,7 +209,9 @@ class QueryExecution( } } - def commandExecuted: LogicalPlan = lazyCommandExecuted.get + def commandExecuted: LogicalPlan = executeWithTransactionContext { + lazyCommandExecuted.get + } private def commandExecutionName(command: Command): String = command match { case _: CreateTableAsSelect => "create" @@ -184,7 +233,8 @@ class QueryExecution( // for eagerly executed commands we mark this place as beginning of execution. tracker.setReadyForExecution() val (qe, result) = QueryExecution.runCommand( - sparkSession, p, name, refreshPhaseEnabled, mode, Some(shuffleCleanupMode)) + sparkSession, p, name, refreshPhaseEnabled, mode, Some(shuffleCleanupMode), + analyzerOpt = Some(analyzer)) CommandResult( qe.analyzed.output, qe.commandExecuted, @@ -222,7 +272,9 @@ class QueryExecution( } // The plan that has been normalized by custom rules, so that it's more likely to hit cache. - def normalized: LogicalPlan = lazyNormalized.get + def normalized: LogicalPlan = executeWithTransactionContext { + lazyNormalized.get + } private val lazyWithCachedData = LazyTry { sparkSession.withActive { @@ -230,11 +282,19 @@ class QueryExecution( assertSupported() // clone the plan to avoid sharing the plan instance between different stages like analyzing, // optimizing and planning. - sparkSession.sharedState.cacheManager.useCachedData(normalized.clone()) + val plan = normalized.clone() + // During a transaction, skip cache substitution. 
useCachedData replaces DataSourceV2Relation + // nodes (loaded via the transaction catalog) with InMemoryRelation, which bypasses read + // tracking in the transaction catalog and may serve stale data. + // if (transactionOpt.isDefined) plan + // else sparkSession.sharedState.cacheManager.useCachedData(plan) + sparkSession.sharedState.cacheManager.useCachedData(plan) } } - def withCachedData: LogicalPlan = lazyWithCachedData.get + def withCachedData: LogicalPlan = executeWithTransactionContext { + lazyWithCachedData.get + } def assertCommandExecuted(): Unit = commandExecuted @@ -256,7 +316,9 @@ class QueryExecution( } } - def optimizedPlan: LogicalPlan = lazyOptimizedPlan.get + def optimizedPlan: LogicalPlan = executeWithTransactionContext { + lazyOptimizedPlan.get + } def assertOptimized(): Unit = optimizedPlan @@ -264,14 +326,21 @@ class QueryExecution( // We need to materialize the optimizedPlan here because sparkPlan is also tracked under // the planning phase assertOptimized() - executePhase(QueryPlanningTracker.PLANNING) { + val plan = executePhase(QueryPlanningTracker.PLANNING) { // Clone the logical plan here, in case the planner rules change the states of the logical // plan. QueryExecution.createSparkPlan(planner, optimizedPlan.clone()) } + transactionOpt match { + case Some(txn) => + plan.transformDown { case w: TransactionalExec => w.withTransaction(Some(txn)) } + case None => plan + } } - def sparkPlan: SparkPlan = lazySparkPlan.get + def sparkPlan: SparkPlan = executeWithTransactionContext { + lazySparkPlan.get + } def assertSparkPlanPrepared(): Unit = sparkPlan @@ -292,7 +361,9 @@ class QueryExecution( // executedPlan should not be used to initialize any SparkPlan. It should be // only used for execution. 
- def executedPlan: SparkPlan = lazyExecutedPlan.get + def executedPlan: SparkPlan = executeWithTransactionContext { + lazyExecutedPlan.get + } def assertExecutedPlanPrepared(): Unit = executedPlan @@ -310,7 +381,9 @@ class QueryExecution( * Given QueryExecution is not a public class, end users are discouraged to use this: please * use `Dataset.rdd` instead where conversion will be applied. */ - def toRdd: RDD[InternalRow] = lazyToRdd.get + def toRdd: RDD[InternalRow] = executeWithTransactionContext { + lazyToRdd.get + } private val observedMetricsLock = new Object @@ -512,6 +585,23 @@ class QueryExecution( } } + /** + * Execute the given block with the transaction context if exists. If there is an exception thrown + * during the execution, the transaction will be aborted. + * + * Note 1: The transaction is not committed in this method. The caller should commit the + * transaction if the execution is successful. + * + * Note 2: In some cases, post commit execution might generate an exception. The abort operation + * should be no-op in this case. + */ + private def executeWithTransactionContext[T](block: => T): T = transactionOpt match { + case Some(transaction) => + try block + catch { case e: Throwable => TransactionUtils.abort(transaction); throw e } + case None => block + } + /** A special namespace for commands that can be used to debug query execution. 
*/ // scalastyle:off object debug { @@ -796,14 +886,16 @@ object QueryExecution { name: String, refreshPhaseEnabled: Boolean = true, mode: CommandExecutionMode.Value = CommandExecutionMode.SKIP, - shuffleCleanupModeOpt: Option[ShuffleCleanupMode] = None) + shuffleCleanupModeOpt: Option[ShuffleCleanupMode] = None, + analyzerOpt: Option[Analyzer] = None) : (QueryExecution, Array[InternalRow]) = { val qe = new QueryExecution( sparkSession, command, mode = mode, shuffleCleanupModeOpt = shuffleCleanupModeOpt, - refreshPhaseEnabled = refreshPhaseEnabled) + refreshPhaseEnabled = refreshPhaseEnabled, + analyzerOpt = analyzerOpt) val result = QueryExecution.withInternalError(s"Executed $name failed.") { SQLExecution.withNewExecutionId(qe, Some(name)) { qe.executedPlan.executeCollect() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 3d3b4d1cae11c..82fd732e9de8d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -62,9 +62,6 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat private def hadoopConf = session.sessionState.newHadoopConf() - // recaches all cache entries without time travel for the given table - // after a write operation that moves the state of the table forward (e.g. 
append, overwrite) - // this method accounts for V2 tables loaded via TableProvider (no catalog/identifier) private def refreshCache(r: DataSourceV2Relation)(): Unit = r match { case ExtractV2CatalogAndIdentifier(catalog, ident) => val nameParts = ident.toQualifiedNameParts(catalog) @@ -381,12 +378,14 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat case ReplaceData(_: DataSourceV2Relation, _, query, r: DataSourceV2Relation, projections, _, Some(write)) => // use the original relation to refresh the cache - ReplaceDataExec(planLater(query), refreshCache(r), projections, write) :: Nil + ReplaceDataExec( + planLater(query), refreshCache(r), projections, write) :: Nil case WriteDelta(_: DataSourceV2Relation, _, query, r: DataSourceV2Relation, projections, Some(write)) => // use the original relation to refresh the cache - WriteDeltaExec(planLater(query), refreshCache(r), projections, write) :: Nil + WriteDeltaExec( + planLater(query), refreshCache(r), projections, write) :: Nil case MergeRows(isSourceRowPresent, isTargetRowPresent, matchedInstructions, notMatchedInstructions, notMatchedBySourceInstructions, checkCardinality, output, child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DeleteFromTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DeleteFromTableExec.scala index 8d5ee6038e80f..c6b1bae89b156 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DeleteFromTableExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DeleteFromTableExec.scala @@ -19,16 +19,23 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.transactions.TransactionUtils import org.apache.spark.sql.connector.catalog.SupportsDeleteV2 +import 
org.apache.spark.sql.connector.catalog.transactions.Transaction import org.apache.spark.sql.connector.expressions.filter.Predicate case class DeleteFromTableExec( table: SupportsDeleteV2, condition: Array[Predicate], - refreshCache: () => Unit) extends LeafV2CommandExec { + refreshCache: () => Unit, + transaction: Option[Transaction] = None) extends LeafV2CommandExec with TransactionalExec { + + override def withTransaction(txn: Option[Transaction]): DeleteFromTableExec = + copy(transaction = txn) override protected def run(): Seq[InternalRow] = { table.deleteWhere(condition) + transaction.foreach(TransactionUtils.commit) refreshCache() Seq.empty } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 804e694f92e9c..a548749a972e0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -26,9 +26,11 @@ import org.apache.spark.sql.catalyst.{InternalRow, ProjectingInternalRow} import org.apache.spark.sql.catalyst.analysis.NoSuchTableException import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal} import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, TableSpec, UnaryNode} +import org.apache.spark.sql.catalyst.transactions.TransactionUtils import org.apache.spark.sql.catalyst.util.{removeInternalMetadata, CharVarcharUtils, ReplaceDataProjections, WriteDeltaProjections} import org.apache.spark.sql.catalyst.util.RowDeltaUtils.{DELETE_OPERATION, INSERT_OPERATION, REINSERT_OPERATION, UPDATE_OPERATION, WRITE_OPERATION, WRITE_WITH_METADATA_OPERATION} import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, StagedTable, StagingTableCatalog, Table, TableCatalog, 
TableInfo, TableWritePrivilege} +import org.apache.spark.sql.connector.catalog.transactions.Transaction import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.metric.CustomMetric import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, DeltaWrite, DeltaWriter, MergeSummaryImpl, PhysicalWriteInfoImpl, Write, WriterCommitMessage, WriteSummary} @@ -271,7 +273,9 @@ case class AtomicReplaceTableAsSelectExec( case class AppendDataExec( query: SparkPlan, refreshCache: () => Unit, - write: Write) extends V2ExistingTableWriteExec { + write: Write, + transaction: Option[Transaction] = None) extends V2ExistingTableWriteExec { + override def withTransaction(txn: Option[Transaction]): AppendDataExec = copy(transaction = txn) override protected def withNewChildInternal(newChild: SparkPlan): AppendDataExec = copy(query = newChild) } @@ -289,7 +293,10 @@ case class AppendDataExec( case class OverwriteByExpressionExec( query: SparkPlan, refreshCache: () => Unit, - write: Write) extends V2ExistingTableWriteExec { + write: Write, + transaction: Option[Transaction] = None) extends V2ExistingTableWriteExec { + override def withTransaction(txn: Option[Transaction]): OverwriteByExpressionExec = + copy(transaction = txn) override protected def withNewChildInternal(newChild: SparkPlan): OverwriteByExpressionExec = copy(query = newChild) } @@ -306,7 +313,10 @@ case class OverwriteByExpressionExec( case class OverwritePartitionsDynamicExec( query: SparkPlan, refreshCache: () => Unit, - write: Write) extends V2ExistingTableWriteExec { + write: Write, + transaction: Option[Transaction] = None) extends V2ExistingTableWriteExec { + override def withTransaction(txn: Option[Transaction]): OverwritePartitionsDynamicExec = + copy(transaction = txn) override protected def withNewChildInternal(newChild: SparkPlan): OverwritePartitionsDynamicExec = copy(query = newChild) } @@ -318,7 +328,8 @@ case class ReplaceDataExec( query: 
SparkPlan, refreshCache: () => Unit, projections: ReplaceDataProjections, - write: Write) extends V2ExistingTableWriteExec { + write: Write, + transaction: Option[Transaction] = None) extends V2ExistingTableWriteExec { override def writingTask: WritingSparkTask[_] = { projections match { @@ -329,6 +340,7 @@ case class ReplaceDataExec( } } + override def withTransaction(txn: Option[Transaction]): ReplaceDataExec = copy(transaction = txn) override protected def withNewChildInternal(newChild: SparkPlan): ReplaceDataExec = { copy(query = newChild) } @@ -341,7 +353,8 @@ case class WriteDeltaExec( query: SparkPlan, refreshCache: () => Unit, projections: WriteDeltaProjections, - write: DeltaWrite) extends V2ExistingTableWriteExec { + write: DeltaWrite, + transaction: Option[Transaction] = None) extends V2ExistingTableWriteExec { override lazy val writingTask: WritingSparkTask[_] = { if (projections.metadataProjection.isDefined) { @@ -351,6 +364,7 @@ case class WriteDeltaExec( } } + override def withTransaction(txn: Option[Transaction]): WriteDeltaExec = copy(transaction = txn) override protected def withNewChildInternal(newChild: SparkPlan): WriteDeltaExec = { copy(query = newChild) } @@ -378,7 +392,16 @@ case class WriteToDataSourceV2Exec( copy(query = newChild) } -trait V2ExistingTableWriteExec extends V2TableWriteExec { +/** + * Trait for physical plan nodes that write to an existing table as part of a transaction. + * The [[transaction]] is injected post-planning by [[QueryExecution]]. 
+ */ +trait TransactionalExec extends SparkPlan { + def transaction: Option[Transaction] + def withTransaction(txn: Option[Transaction]): SparkPlan +} + +trait V2ExistingTableWriteExec extends V2TableWriteExec with TransactionalExec { def refreshCache: () => Unit def write: Write @@ -395,6 +418,7 @@ trait V2ExistingTableWriteExec extends V2TableWriteExec { } finally { postDriverMetrics() } + transaction.foreach(TransactionUtils.commit) refreshCache() writtenRows } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala new file mode 100644 index 0000000000000..1c8e7fc5a0fd4 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala @@ -0,0 +1,228 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connector + +import org.apache.spark.sql.Row +import org.apache.spark.sql.connector.catalog.Committed +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf + +class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { + + test("writeTo append with transactional checks") { + // create table with initial data + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + // create a source on top of itself that will be fully resolved and analyzed + val sourceDF = spark.table(tableNameAsString) + .where("pk == 1") + .select(col("pk") + 10 as "pk", col("salary"), col("dep")) + sourceDF.queryExecution.assertAnalyzed() + + // append data using the DataFrame API + val (txn, txnTables) = executeTransaction { + sourceDF.writeTo(tableNameAsString).append() + } + + // check txn was properly committed and closed + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 1) + + // check the source scan was tracked via the transaction catalog + val targetTxnTable = txnTables(tableNameAsString) + assert(targetTxnTable.scanEvents.size >= 1) + + // check data was appended correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr"), + Row(2, 200, "software"), + Row(11, 100, "hr"))) // appended + } + + test("SQL INSERT INTO with transactional checks") { + // create table with initial data + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + // SQL INSERT INTO using VALUES + val (txn, _) = executeTransaction { + sql(s"INSERT INTO $tableNameAsString VALUES (3, 300, 'hr'), (4, 400, 'finance')") + } + + // check txn was properly committed and closed + assert(txn.currentState == Committed) + 
assert(txn.isClosed) + + // check data was inserted correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr"), + Row(2, 200, "software"), + Row(3, 300, "hr"), + Row(4, 400, "finance"))) + } + + test("SQL INSERT OVERWRITE with transactional checks") { + // create table with initial data; table is partitioned by dep + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // INSERT OVERWRITE with static partition predicate -> OverwriteByExpression + val (txn, _) = executeTransaction { + sql(s"""INSERT OVERWRITE $tableNameAsString + |PARTITION (dep = 'hr') + |SELECT pk + 10, salary FROM $tableNameAsString WHERE dep = 'hr' + |""".stripMargin) + } + + assert(txn.currentState == Committed) + assert(txn.isClosed) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(2, 200, "software"), // unchanged + Row(11, 100, "hr"), // overwritten + Row(13, 300, "hr"))) // overwritten + } + + test("SQL INSERT OVERWRITE dynamic partition with transactional checks") { + // create table with initial data; table is partitioned by dep + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // INSERT OVERWRITE with dynamic partitioning -> OverwritePartitionsDynamic + withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> "dynamic") { + val (txn, _) = executeTransaction { + sql(s"""INSERT OVERWRITE $tableNameAsString + |SELECT pk + 10, salary, dep FROM $tableNameAsString WHERE dep = 'hr' + |""".stripMargin) + } + + assert(txn.currentState == Committed) + assert(txn.isClosed) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(2, 200, "software"), // unchanged (different partition) + 
Row(11, 100, "hr"), // overwrote hr partition + Row(13, 300, "hr"))) // overwrote hr partition + } + } + + test("writeTo overwrite with transactional checks") { + // create table with initial data; table is partitioned by dep + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // overwrite using a condition that covers the hr partition -> OverwriteByExpression + val sourceDF = spark.createDataFrame(Seq((11, 999, "hr"), (12, 888, "hr"))). + toDF("pk", "salary", "dep") + + val (txn, _) = executeTransaction { + sourceDF.writeTo(tableNameAsString).overwrite(col("dep") === "hr") + } + + assert(txn.currentState == Committed) + assert(txn.isClosed) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(2, 200, "software"), // unchanged (different partition) + Row(11, 999, "hr"), // overwrote hr partition + Row(12, 888, "hr"))) // overwrote hr partition + } + + test("writeTo overwritePartitions with transactional checks") { + // create table with initial data; table is partitioned by dep + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // overwrite partitions dynamically -> OverwritePartitionsDynamic + val sourceDF = spark.createDataFrame(Seq((11, 999, "hr"), (12, 888, "hr"))). 
+ toDF("pk", "salary", "dep") + + val (txn, _) = executeTransaction { + sourceDF.writeTo(tableNameAsString).overwritePartitions() + } + + assert(txn.currentState == Committed) + assert(txn.isClosed) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(2, 200, "software"), // unchanged (different partition) + Row(11, 999, "hr"), // overwrote hr partition + Row(12, 888, "hr"))) // overwrote hr partition + } + + test("SQL INSERT INTO SELECT with transactional checks") { + // create table with initial data + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // SQL INSERT INTO using SELECT from the same table (self-insert) + val (txn, txnTables) = executeTransaction { + sql(s"""INSERT INTO $tableNameAsString + |SELECT pk + 10, salary, dep FROM $tableNameAsString WHERE dep = 'hr' + |""".stripMargin) + } + + // check txn was properly committed and closed + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 1) + + // check data was inserted correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr"), + Row(2, 200, "software"), + Row(3, 300, "hr"), + Row(11, 100, "hr"), // inserted from pk=1 + Row(13, 300, "hr"))) // inserted from pk=3 + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2OptionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2OptionSuite.scala index 30890200df79d..fbcfdfb20c6ec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2OptionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2OptionSuite.scala @@ -109,7 +109,7 @@ class DataSourceV2OptionSuite extends DatasourceV2SQLBase { collected = df.queryExecution.executedPlan.collect { case CommandResultExec( - _, 
AppendDataExec(_, _, write), + _, AppendDataExec(_, _, write, _), _) => val append = write.toBatch.asInstanceOf[InMemoryBaseTable#Append] assert(append.info.options.get("write.split-size") === "10") @@ -141,7 +141,7 @@ class DataSourceV2OptionSuite extends DatasourceV2SQLBase { assert (collected.size == 1) collected = qe.executedPlan.collect { - case AppendDataExec(_, _, write) => + case AppendDataExec(_, _, write, _) => val append = write.toBatch.asInstanceOf[InMemoryBaseTable#Append] assert(append.info.options.get("write.split-size") === "10") } @@ -168,7 +168,7 @@ class DataSourceV2OptionSuite extends DatasourceV2SQLBase { assert (collected.size == 1) collected = qe.executedPlan.collect { - case AppendDataExec(_, _, write) => + case AppendDataExec(_, _, write, _) => val append = write.toBatch.asInstanceOf[InMemoryBaseTable#Append] assert(append.info.options.get("write.split-size") === "10") } @@ -194,7 +194,7 @@ class DataSourceV2OptionSuite extends DatasourceV2SQLBase { collected = df.queryExecution.executedPlan.collect { case CommandResultExec( - _, OverwriteByExpressionExec(_, _, write), + _, OverwriteByExpressionExec(_, _, write, _), _) => val append = write.toBatch.asInstanceOf[InMemoryBaseTable#TruncateAndAppend] assert(append.info.options.get("write.split-size") === "10") @@ -227,7 +227,7 @@ class DataSourceV2OptionSuite extends DatasourceV2SQLBase { assert (collected.size == 1) collected = qe.executedPlan.collect { - case OverwritePartitionsDynamicExec(_, _, write) => + case OverwritePartitionsDynamicExec(_, _, write, _) => val dynOverwrite = write.toBatch.asInstanceOf[InMemoryBaseTable#DynamicOverwrite] assert(dynOverwrite.info.options.get("write.split-size") === "10") } @@ -254,7 +254,7 @@ class DataSourceV2OptionSuite extends DatasourceV2SQLBase { collected = df.queryExecution.executedPlan.collect { case CommandResultExec( - _, OverwriteByExpressionExec(_, _, write), + _, OverwriteByExpressionExec(_, _, write, _), _) => val append = 
write.toBatch.asInstanceOf[InMemoryBaseTable#TruncateAndAppend] assert(append.info.options.get("write.split-size") === "10") @@ -287,7 +287,7 @@ class DataSourceV2OptionSuite extends DatasourceV2SQLBase { assert (collected.size == 1) collected = qe.executedPlan.collect { - case OverwriteByExpressionExec(_, _, write) => + case OverwriteByExpressionExec(_, _, write, _) => val append = write.toBatch.asInstanceOf[InMemoryBaseTable#TruncateAndAppend] assert(append.info.options.get("write.split-size") === "10") } @@ -317,7 +317,7 @@ class DataSourceV2OptionSuite extends DatasourceV2SQLBase { assert (collected.size == 1) collected = qe.executedPlan.collect { - case OverwriteByExpressionExec(_, _, write) => + case OverwriteByExpressionExec(_, _, write, _) => val append = write.toBatch.asInstanceOf[InMemoryBaseTable#TruncateAndAppend] assert(append.info.options.get("write.split-size") === "10") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala index 0f7f4cefe2feb..e5cf1d2b53975 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.connector -import org.apache.spark.sql.Row +import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.catalyst.expressions.CheckInvariant import org.apache.spark.sql.catalyst.plans.logical.Filter +import org.apache.spark.sql.connector.catalog.{Aborted, Committed} import org.apache.spark.sql.execution.datasources.v2.{DeleteFromTableExec, ReplaceDataExec, WriteDeltaExec} +import org.apache.spark.sql.sources abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { @@ -28,6 +30,10 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { protected def 
enforceCheckConstraintOnDelete: Boolean = true + // true when the table engine uses delta-based deletes (WriteDeltaExec), false for group-based + // (ReplaceDataExec); controls expected scan counts in transactional tests + protected def deltaDelete: Boolean = false + test("delete from table containing added column with default value") { createAndInitTable("pk INT NOT NULL, dep STRING", """{ "pk": 1, "dep": "hr" }""") @@ -659,6 +665,186 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { } } + test("delete with analysis failure and transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val exception = intercept[AnalysisException] { + sql(s"DELETE FROM $tableNameAsString WHERE invalid_column = 1") + } + + assert(exception.getMessage.contains("invalid_column")) + assert(catalog.lastTransaction.currentState == Aborted) + assert(catalog.lastTransaction.isClosed) + } + + test("delete with transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // simple predicate delete: goes through SupportsDelete.deleteWhere (no Spark-side scan) + val (txn, _) = executeTransaction { + sql(s"DELETE FROM $tableNameAsString WHERE dep = 'hr'") + } + + assert(txn.currentState == Committed) + assert(txn.isClosed) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq(Row(2, 200, "software"))) + } + + test("delete with subquery on source table and transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + sql(s"CREATE TABLE 
$sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'hr'), (4, 400, 'finance')") + + val (txn, txnTables) = executeTransaction { + sql( + s"""DELETE FROM $tableNameAsString + |WHERE pk IN (SELECT pk FROM $sourceNameAsString WHERE dep = 'hr') + |""".stripMargin) + } + + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 2) + + val sourceTxnTable = txnTables(sourceNameAsString) + val expectedNumSourceScans = if (deltaDelete) 1 else 4 + assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans) + + val numSubquerySourceScans = sourceTxnTable.scanEvents.flatten.count { + case sources.EqualTo("dep", "hr") => true + case _ => false + } + assert(numSubquerySourceScans == expectedNumSourceScans) + + val targetTxnTable = txnTables(tableNameAsString) + val expectedNumTargetScans = if (deltaDelete) 1 else 3 + assert(targetTxnTable.scanEvents.size == expectedNumTargetScans) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(2, 200, "software"), // unchanged + Row(3, 300, "hr"))) // unchanged (pk 3 not in subquery result) + } + + test("delete with CTE and transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'hr'), (4, 400, 'finance')") + + val (txn, txnTables) = executeTransaction { + sql( + s"""WITH cte AS ( + | SELECT pk FROM $sourceNameAsString WHERE dep = 'hr' + |) + |DELETE FROM $tableNameAsString + |WHERE pk IN (SELECT pk FROM cte) + |""".stripMargin) + } + + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 2) + + val targetTxnTable = txnTables(tableNameAsString) + val 
expectedNumTargetScans = if (deltaDelete) 1 else 3 + assert(targetTxnTable.scanEvents.size == expectedNumTargetScans) + + val sourceTxnTable = txnTables(sourceNameAsString) + val expectedNumSourceScans = if (deltaDelete) 1 else 4 + assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans) + + val numCteSourceScans = sourceTxnTable.scanEvents.flatten.count { + case sources.EqualTo("dep", "hr") => true + case _ => false + } + assert(numCteSourceScans == expectedNumSourceScans) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(2, 200, "software"), // unchanged + Row(3, 300, "hr"))) // unchanged (pk 3 not in source) + } + + test("delete using view with transactional checks") { + withView("temp_view") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'hr'), (4, 400, 'finance')") + + sql( + s"""CREATE VIEW temp_view AS + |SELECT pk FROM $sourceNameAsString WHERE dep = 'hr' + |""".stripMargin) + + val (txn, txnTables) = executeTransaction { + sql(s"DELETE FROM $tableNameAsString WHERE pk IN (SELECT pk FROM temp_view)") + } + + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 2) + + val targetTxnTable = txnTables(tableNameAsString) + val expectedNumTargetScans = if (deltaDelete) 1 else 3 + assert(targetTxnTable.scanEvents.size == expectedNumTargetScans) + + val sourceTxnTable = txnTables(sourceNameAsString) + val expectedNumSourceScans = if (deltaDelete) 1 else 4 + assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(2, 200, "software"), // unchanged + Row(3, 300, "hr"))) // unchanged (pk 3 not in source) + } 
+ } + + test("EXPLAIN DELETE SQL with transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + sql(s"EXPLAIN DELETE FROM $tableNameAsString WHERE dep = 'hr'") + + // EXPLAIN should not start a new transaction + assert(catalog.transaction === null) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr"), + Row(2, 200, "software"))) + } + private def executeDeleteWithFilters(query: String): Unit = { val executedPlan = executeAndKeepPlan { sql(query) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedDeleteFromTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedDeleteFromTableSuite.scala index 9046123ddbd3f..56a689364d1c0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedDeleteFromTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedDeleteFromTableSuite.scala @@ -33,6 +33,8 @@ class DeltaBasedDeleteFromTableSuite extends DeleteFromTableSuiteBase { override def enforceCheckConstraintOnDelete: Boolean = false + override protected def deltaDelete: Boolean = true + test("delete handles metadata columns correctly") { createAndInitTable("pk INT NOT NULL, id INT, dep STRING", """{ "pk": 1, "id": 1, "dep": "hr" } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuiteBase.scala index 89b42b5e6db7b..e821fc3f660da 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuiteBase.scala @@ -21,6 +21,8 @@ import org.apache.spark.sql.{AnalysisException, Row} abstract class DeltaBasedUpdateTableSuiteBase extends 
UpdateTableSuiteBase { + override protected def deltaUpdate: Boolean = true + test("nullable row ID attrs") { createAndInitTable("pk INT, salary INT, dep STRING", """{ "pk": 1, "salary": 300, "dep": 'hr' } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala index e1c574ec7ba65..687aae91438da 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.connector +import org.apache.spark.sql.{sources, Column, Row} import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.Row import org.apache.spark.sql.classic.MergeIntoWriter -import org.apache.spark.sql.connector.catalog.Column +import org.apache.spark.sql.connector.catalog.Committed import org.apache.spark.sql.connector.catalog.Identifier import org.apache.spark.sql.connector.catalog.TableInfo import org.apache.spark.sql.functions._ @@ -31,6 +31,71 @@ class MergeIntoDataFrameSuite extends RowLevelOperationSuiteBase { import testImplicits._ + private def targetTableCol(colName: String): Column = { + col(tableNameAsString + "." 
+ colName) + } + + test("self merge with transactional checks") { + // create table + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // create a source on top of itself that will be fully resolved and analyzed + val sourceDF = spark.table(tableNameAsString) + .where("salary == 100") + .as("source") + sourceDF.queryExecution.assertAnalyzed() + + // merge into table using the source on top of itself + val (txn, txnTables) = executeTransaction { + sourceDF + .mergeInto( + tableNameAsString, + $"source.pk" === targetTableCol("pk") && targetTableCol("dep") === "hr") + .whenMatched() + .update(Map("salary" -> targetTableCol("salary").plus(1))) + .whenNotMatched() + .insertAll() + .merge() + } + + // check txn was properly committed and closed + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 1) + + // check all table scans + val targetTxnTable = txnTables(tableNameAsString) + assert(targetTxnTable.scanEvents.size == 4) + + // check table scans as MERGE target + val numTargetScans = targetTxnTable.scanEvents.flatten.count { + case sources.EqualTo("dep", "hr") => true + case _ => false + } + assert(numTargetScans == 2) + + // check table scans as MERGE source + val numSourceScans = targetTxnTable.scanEvents.flatten.count { + case sources.EqualTo("salary", 100) => true + case _ => false + } + assert(numSourceScans == 2) + + // check txn state was propagated correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 101, "hr"), // update + Row(2, 200, "software"), // unchanged + Row(3, 300, "hr"))) // unchanged + + // TODO Achatzis check version. 
+ } + test("merge into empty table with NOT MATCHED clause") { withTempView("source") { createTable("pk INT NOT NULL, salary INT, dep STRING") @@ -979,6 +1044,7 @@ class MergeIntoDataFrameSuite extends RowLevelOperationSuiteBase { } test("SPARK-54157: version is refreshed when source is V2 table") { + import org.apache.spark.sql.connector.catalog.Column val sourceTable = "cat.ns1.source_table" withTable(sourceTable) { createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", @@ -1026,6 +1092,7 @@ class MergeIntoDataFrameSuite extends RowLevelOperationSuiteBase { } test("SPARK-54444: any schema changes after analysis are prohibited") { + import org.apache.spark.sql.connector.catalog.Column val sourceTable = "cat.ns1.source_table" withTable(sourceTable) { createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala index 7c0e503705c7e..9889b9b53742d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkRuntimeException import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, In, Not} import org.apache.spark.sql.catalyst.optimizer.BuildLeft -import org.apache.spark.sql.connector.catalog.{Column, ColumnDefaultValue, InMemoryTable, TableInfo} +import org.apache.spark.sql.connector.catalog.{Aborted, Column, ColumnDefaultValue, Committed, InMemoryTable, TableInfo} import org.apache.spark.sql.connector.expressions.{GeneralScalarExpression, LiteralValue} import org.apache.spark.sql.connector.write.MergeSummary import org.apache.spark.sql.execution.SparkPlan @@ -29,6 +29,7 @@ import 
org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.datasources.v2.MergeRowsExec import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, CartesianProductExec} import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources import org.apache.spark.sql.types.{IntegerType, StringType} abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase @@ -38,6 +39,305 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase protected def deltaMerge: Boolean = false + test("self merge with transactional checks") { + // create table + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // merge into table using a subquery on top of itself + val (txn, txnTables) = executeTransaction { + sql( + s"""MERGE INTO $tableNameAsString t + |USING (SELECT * FROM $tableNameAsString WHERE salary = 100) s + |ON t.pk = s.pk AND t.dep = 'hr' + |WHEN MATCHED THEN + | UPDATE SET salary = t.salary + 1 + |WHEN NOT MATCHED THEN + | INSERT * + |""".stripMargin) + } + + // check txn was properly committed and closed + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 1) + + // check all table scans + val targetTxnTable = txnTables(tableNameAsString) + val expectedNumScans = if (deltaMerge) 2 else 4 + assert(targetTxnTable.scanEvents.size == expectedNumScans) + + // check table scans as MERGE target + val numTargetScans = targetTxnTable.scanEvents.flatten.count { + case sources.EqualTo("dep", "hr") => true + case _ => false + } + val expectedNumTargetScans = if (deltaMerge) 1 else 2 + assert(numTargetScans == expectedNumTargetScans) + + // check table scans as MERGE source + val numSourceScans = targetTxnTable.scanEvents.flatten.count { + case 
sources.EqualTo("salary", 100) => true + case _ => false + } + val expectedNumSourceScans = if (deltaMerge) 1 else 2 + assert(numSourceScans == expectedNumSourceScans) + + // check txn state was propagated correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 101, "hr"), // update + Row(2, 200, "software"), // unchanged + Row(3, 300, "hr"))) // unchanged + } + + test("merge into table with analysis failure and transactional checks") { + createAndInitTable( + "pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'support'), (4, 400, 'finance')") + + val exception = intercept[AnalysisException] { + sql( + s"""MERGE INTO $tableNameAsString t + |USING $sourceNameAsString s + |ON t.pk = s.pk + |WHEN MATCHED THEN + | UPDATE SET salary = s.invalid_column + |WHEN NOT MATCHED THEN + | INSERT (pk, salary, dep) VALUES (s.pk, s.salary, 'pending') + |""".stripMargin) + } + + assert(exception.getMessage.contains("invalid_column")) + assert(catalog.lastTransaction.currentState == Aborted) + assert(catalog.lastTransaction.isClosed) + } + + test("merge into table using view with transactional checks") { + withView("temp_view") { + // create target table + createAndInitTable( + "pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // create source table + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT)") + sql(s"INSERT INTO $sourceNameAsString (pk, salary) VALUES (1, 150), (4, 400)") + + // create view on top of source and target tables + sql( + s"""CREATE VIEW temp_view AS + |SELECT s.pk, s.salary, 
t.dep + |FROM $sourceNameAsString s + |LEFT JOIN ( + | SELECT * FROM $tableNameAsString WHERE pk < 10 + |) t ON s.pk = t.pk + |""".stripMargin) + + // merge into target table using view + val (txn, txnTables) = executeTransaction { + sql(s"""MERGE INTO $tableNameAsString t + |USING temp_view s + |ON t.pk = s.pk AND t.dep = 'hr' + |WHEN MATCHED THEN + | UPDATE SET salary = s.salary, dep = s.dep + |WHEN NOT MATCHED THEN + | INSERT (pk, salary, dep) VALUES (s.pk, s.salary, 'pending') + |""".stripMargin) + } + + // check txn covers both tables and was properly committed and closed + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 2) + + // check target table was scanned correctly + val targetTxnTable = txnTables(tableNameAsString) + val expectedNumTargetScans = if (deltaMerge) 2 else 4 + assert(targetTxnTable.scanEvents.size == expectedNumTargetScans) + + // check target table scans as MERGE target (dep = 'hr') + val numMergeTargetScans = targetTxnTable.scanEvents.flatten.count { + case sources.EqualTo("dep", "hr") => true + case _ => false + } + val expectedNumMergeTargetScans = if (deltaMerge) 1 else 2 + assert(numMergeTargetScans == expectedNumMergeTargetScans) + + // check target table scans in view as MERGE source (pk < 10) + val numViewTargetScans = targetTxnTable.scanEvents.flatten.count { + case sources.LessThan("pk", 10L) => true + case _ => false + } + val expectedNumViewTargetScans = if (deltaMerge) 1 else 2 + assert(numViewTargetScans == expectedNumViewTargetScans) + + // check source table scans in view as MERGE source (no predicate) + val sourceTxnTable = txnTables(sourceNameAsString) + val expectedNumSourceScans = if (deltaMerge) 1 else 2 + assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans) + + // check txn state was propagated correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 150, "hr"), // update + Row(2, 200, "software"), // unchanged + Row(3, 300, "hr"), // 
unchanged + Row(4, 400, "pending"))) // new + } + } + + test("merge into table using nested view with transactional checks") { + withView("base_view", "nested_view") { + withTable(sourceNameAsString) { + // create target table + createAndInitTable( + "pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // create source table + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT)") + sql(s"INSERT INTO $sourceNameAsString (pk, salary) VALUES (1, 150), (4, 400)") + + // create base view + sql( + s"""CREATE VIEW base_view AS + |SELECT s.pk, s.salary, t.dep + |FROM $sourceNameAsString s + |LEFT JOIN ( + | SELECT * FROM $tableNameAsString WHERE pk < 10 + |) t ON s.pk = t.pk + |""".stripMargin) + + // create nested view on top of base view + sql( + s"""CREATE VIEW nested_view AS + |SELECT * FROM base_view + |""".stripMargin) + + // merge into target table using nested view + val (txn, txnTables) = executeTransaction { + sql( + s"""MERGE INTO $tableNameAsString t + |USING nested_view s + |ON t.pk = s.pk AND t.dep = 'hr' + |WHEN MATCHED THEN + | UPDATE SET salary = s.salary, dep = s.dep + |WHEN NOT MATCHED THEN + | INSERT (pk, salary, dep) VALUES (s.pk, s.salary, 'pending') + |""".stripMargin) + } + + // check txn covers both tables and was properly committed and closed + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 2) + + // check target table was scanned correctly + val targetTxnTable = txnTables(tableNameAsString) + val expectedNumTargetScans = if (deltaMerge) 2 else 4 + assert(targetTxnTable.scanEvents.size == expectedNumTargetScans) + + // check target table scans as MERGE target (dep = 'hr') + val numMergeTargetScans = targetTxnTable.scanEvents.flatten.count { + case sources.EqualTo("dep", "hr") => true + case _ => false + } + val expectedNumMergeTargetScans = 
if (deltaMerge) 1 else 2 + assert(numMergeTargetScans == expectedNumMergeTargetScans) + + // check target table scans in view as MERGE source (pk < 10) + val numViewTargetScans = targetTxnTable.scanEvents.flatten.count { + case sources.LessThan("pk", 10L) => true + case _ => false + } + val expectedNumViewTargetScans = if (deltaMerge) 1 else 2 + assert(numViewTargetScans == expectedNumViewTargetScans) + + // check source table scans in view as MERGE source (no predicate) + val sourceTxnTable = txnTables(sourceNameAsString) + val expectedNumSourceScans = if (deltaMerge) 1 else 2 + assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans) + + // check txn state was propagated correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 150, "hr"), // update + Row(2, 200, "software"), // unchanged + Row(3, 300, "hr"), // unchanged + Row(4, 400, "pending"))) // new + } + } + } + + test("merge into table rewritten as INSERT with transactional checks") { + withTable(sourceNameAsString) { + // create target table + createAndInitTable( + "pk INT, value STRING, dep STRING", + """{ "pk": 1, "value": "a", "dep": "hr" } + |{ "pk": 2, "value": "b", "dep": "finance" } + |""".stripMargin) + + // create source table + sql(s"CREATE TABLE $sourceNameAsString (pk INT, value STRING, dep STRING)") + sql(s"INSERT INTO $sourceNameAsString VALUES (3, 'c', 'hr'), (4, 'd', 'software')") + + // merge into target with only WHEN NOT MATCHED clauses (rewritten as insert) + val (txn, txnTables) = executeTransaction { + sql( + s"""MERGE INTO $tableNameAsString t + |USING $sourceNameAsString s + |ON t.pk = s.pk + |WHEN NOT MATCHED AND s.pk < 4 THEN + | INSERT (pk, value, dep) VALUES (s.pk, concat(s.value, '_low'), s.dep) + |WHEN NOT MATCHED AND s.pk >= 4 THEN + | INSERT (pk, value, dep) VALUES (s.pk, concat(s.value, '_high'), s.dep) + |""".stripMargin) + } + + // check txn covers both tables and was properly committed and closed + assert(txn.currentState == 
Committed) + assert(txn.isClosed) + assert(txnTables.size == 2) + + // check target table was scanned correctly + val targetTxnTable = txnTables(tableNameAsString) + assert(targetTxnTable.scanEvents.size == 1) + + // check source table was scanned correctly + val sourceTxnTable = txnTables(sourceNameAsString) + assert(sourceTxnTable.scanEvents.size == 1) + + // check txn state was propagated correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, "a", "hr"), // unchanged + Row(2, "b", "finance"), // unchanged + Row(3, "c_low", "hr"), // inserted via first NOT MATCHED clause + Row(4, "d_high", "software"))) // inserted via second NOT MATCHED clause + } + } + test("merge into table with expression-based default values") { val columns = Array( Column.create("pk", IntegerType), @@ -671,6 +971,129 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase } } + test("merge with CTE with transactional checks") { + withTable(sourceNameAsString) { + // create target table + createAndInitTable( + "pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // create source table + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'hr'), (4, 400, 'finance')") + + // merge into target table using CTE + val (txn, txnTables) = executeTransaction { + sql( + s"""WITH cte AS ( + | SELECT pk, salary + 50 AS salary, dep + | FROM $sourceNameAsString + | WHERE salary > 100 + |) + |MERGE INTO $tableNameAsString t + |USING cte s + |ON t.pk = s.pk AND t.dep = 'hr' + |WHEN MATCHED THEN + | UPDATE SET salary = s.salary + |WHEN NOT MATCHED THEN + | INSERT (pk, salary, dep) VALUES (s.pk, s.salary, 'pending') + |""".stripMargin) + } + + // check txn covers both tables and was properly committed and closed + 
assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 2) + + // check target table was scanned correctly + val targetTxnTable = txnTables(tableNameAsString) + val expectedNumTargetScans = if (deltaMerge) 1 else 2 + assert(targetTxnTable.scanEvents.size == expectedNumTargetScans) + + // check target table scans as MERGE target (dep = 'hr') + val numMergeTargetScans = targetTxnTable.scanEvents.flatten.count { + case sources.EqualTo("dep", "hr") => true + case _ => false + } + assert(numMergeTargetScans == expectedNumTargetScans) + + // check source table was scanned correctly + val sourceTxnTable = txnTables(sourceNameAsString) + val expectedNumSourceScans = if (deltaMerge) 1 else 2 + assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans) + + // check source table scans in CTE (salary > 100) + val numCteSourceScans = sourceTxnTable.scanEvents.flatten.count { + case sources.GreaterThan("salary", 100) => true + case _ => false + } + assert(numCteSourceScans == expectedNumSourceScans) + + // check txn state was propagated correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 200, "hr"), // updated (150 + 50) + Row(2, 200, "software"), // unchanged + Row(3, 300, "hr"), // unchanged + Row(4, 450, "pending"))) // inserted (400 + 50) + } + } + + test("merge with cached source and transactional checks") { + withTable(sourceNameAsString) { + // create target table + createAndInitTable( + "pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // create and populate source table + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'support'), (4, 400, 'finance')") + + // Cache source table before the transaction. 
Make sure when the transaction is active the + // catalog still creates a transaction table. + spark.table(sourceNameAsString).cache() + + try { + val (txn, txnTables) = executeTransaction { + sql( + s"""MERGE INTO $tableNameAsString t + |USING $sourceNameAsString s + |ON t.pk = s.pk AND t.dep = 'hr' + |WHEN MATCHED THEN + | UPDATE SET salary = s.salary, dep = s.dep + |WHEN NOT MATCHED THEN + | INSERT (pk, salary, dep) VALUES (s.pk, s.salary, 'pending') + |""".stripMargin) + } + + assert(txn.currentState == Committed) + assert(txn.isClosed) + // both target and source must have been read through the transaction catalog + assert(txnTables.size == 2) + assert(txnTables(sourceNameAsString).scanEvents.nonEmpty) + assert(txnTables(tableNameAsString).scanEvents.nonEmpty) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 150, "support"), // matched and updated + Row(2, 200, "software"), // unchanged + Row(3, 300, "hr"), // unchanged (no match in source) + Row(4, 400, "pending"))) // not matched, inserted + } finally { + spark.catalog.uncacheTable(sourceNameAsString) + } + } + } + test("merge with subquery as source") { withTempView("source") { createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", @@ -2223,6 +2646,8 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase sql(query) } assert(e.getMessage.contains("ON search condition of the MERGE statement")) + assert(catalog.lastTransaction.currentState == Aborted) + assert(catalog.lastTransaction.isClosed) + } private def assertMetric( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala index 8c51fb17b2cf4..73983283bcb8b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala @@ -28,16 +28,16 @@ 
import org.apache.spark.sql.catalyst.expressions.{DynamicPruningExpression, Expr import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReplaceData, WriteDelta} import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.METADATA_COL_ATTR_KEY -import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Delete, Identifier, InMemoryRowLevelOperationTable, InMemoryRowLevelOperationTableCatalog, Insert, MetadataColumn, Operation, Reinsert, TableInfo, Update, Write} +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Delete, Identifier, InMemoryRowLevelOperationTable, InMemoryRowLevelOperationTableCatalog, Insert, MetadataColumn, Operation, Reinsert, Table, TableInfo, Txn, TxnTable, Update, Write} import org.apache.spark.sql.connector.expressions.LogicalExpressions.{identity, reference} import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.connector.write.RowLevelOperationTable import org.apache.spark.sql.execution.{InSubqueryExec, QueryExecution, SparkPlan} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation, DataSourceV2ScanRelation} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{IntegerType, MetadataBuilder, StringType, StructField, StructType} -import org.apache.spark.sql.util.QueryExecutionListener import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.ArrayImplicits._ @@ -82,6 +82,7 @@ abstract class RowLevelOperationSuiteBase protected val namespace: Array[String] = Array("ns1") protected val ident: Identifier = Identifier.of(namespace, "test_table") protected val tableNameAsString: String = "cat." 
+ ident.toString + protected val sourceNameAsString: String = "cat.ns1.source_table" protected def extraTableProps: java.util.Map[String, String] = { Collections.emptyMap[String, String] @@ -133,24 +134,36 @@ abstract class RowLevelOperationSuiteBase } } - // executes an operation and keeps the executed plan - protected def executeAndKeepPlan(func: => Unit): SparkPlan = { - var executedPlan: SparkPlan = null - - val listener = new QueryExecutionListener { - override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { - executedPlan = qe.executedPlan - } - override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = { - } + protected def executeTransaction(func: => Unit): (Txn, Map[String, TxnTable]) = { + val qe = execute(func) + val tables = collectWithSubqueries(qe.executedPlan) { + case BatchScanExec(_, _, _, _, table: TxnTable, _) => + table + case BatchScanExec(_, _, _, _, RowLevelOperationTable(table: TxnTable, _), _) => + table } - spark.listenerManager.register(listener) + (catalog.lastTransaction, indexByName(tables)) + } - func + private def indexByName[T <: Table](tables: Seq[T]): Map[String, T] = { + tables.groupBy(_.name).map { + case (name, sameNameTables) => + val Seq(table) = sameNameTables.distinct + name -> table + } + } - sparkContext.listenerBus.waitUntilEmpty() + // executes an operation and keeps the executed plan + protected def executeAndKeepPlan(func: => Unit): SparkPlan = { + val qe = execute(func) + stripAQEPlan(qe.executedPlan) + } - stripAQEPlan(executedPlan) + private def execute(func: => Unit): QueryExecution = { + withQueryExecutionsCaptured(spark)(func) match { + case Seq(qe) => qe + case other => fail(s"expected only one query execution, but got ${other.size}") + } } // executes an operation and extracts conditions from ReplaceData or WriteDelta diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala index ac0bf3bdba9ce..0061856ea5da8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.connector import org.apache.spark.SparkRuntimeException -import org.apache.spark.sql.Row -import org.apache.spark.sql.connector.catalog.{Column, ColumnDefaultValue, TableChange, TableInfo} +import org.apache.spark.sql.{sources, AnalysisException, Row} +import org.apache.spark.sql.connector.catalog.{Aborted, Column, ColumnDefaultValue, Committed, TableChange, TableInfo} import org.apache.spark.sql.connector.expressions.{GeneralScalarExpression, LiteralValue} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{IntegerType, StringType} @@ -28,6 +28,8 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { import testImplicits._ + protected def deltaUpdate: Boolean = false + test("update table containing added column with default value") { createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", """{ "pk": 1, "salary": 100, "dep": "hr" } @@ -762,4 +764,321 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { Row(5))) } } + + test("update with analysis failure and transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + val exception = intercept[AnalysisException] { + sql(s"UPDATE $tableNameAsString SET invalid_column = -1") + } + + assert(exception.getMessage.contains("invalid_column")) + assert(catalog.lastTransaction.currentState == Aborted) + assert(catalog.lastTransaction.isClosed) + } + + test("update with CTE and transactional checks") { + // create table + 
createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // create source table + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'hr'), (4, 400, 'finance')") + + // update using CTE + val (txn, txnTables) = executeTransaction { + sql( + s"""WITH cte AS ( + | SELECT pk, salary + 50 AS adjusted_salary, dep + | FROM $sourceNameAsString + | WHERE salary > 100 + |) + |UPDATE $tableNameAsString t + |SET salary = -1 + |WHERE t.dep = 'hr' AND EXISTS (SELECT 1 FROM cte WHERE cte.pk = t.pk) + |""".stripMargin) + } + + // check txn was properly committed and closed + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 2) + + // check target table was scanned correctly + val targetTxnTable = txnTables(tableNameAsString) + val expectedNumTargetScans = if (deltaUpdate) 1 else 3 + assert(targetTxnTable.scanEvents.size == expectedNumTargetScans) + + // check target table scans for UPDATE condition (dep = 'hr') + val numUpdateTargetScans = targetTxnTable.scanEvents.flatten.count { + case sources.EqualTo("dep", "hr") => true + case _ => false + } + assert(numUpdateTargetScans == expectedNumTargetScans) + + // check source table was scanned correctly + val sourceTxnTable = txnTables(sourceNameAsString) + val expectedNumSourceScans = if (deltaUpdate) 1 else 4 + assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans) + + // check source table scans in CTE (salary > 100) + val numCteSourceScans = sourceTxnTable.scanEvents.flatten.count { + case sources.GreaterThan("salary", 100) => true + case _ => false + } + assert(numCteSourceScans == expectedNumSourceScans) + + // check txn state was propagated correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, -1, 
"hr"), // updated + Row(2, 200, "software"), // unchanged + Row(3, 300, "hr"))) // unchanged (no matching pk in source) + } + + test("update with subquery on source table and transactional checks") { + // create target table + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // create source table + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'hr'), (4, 400, 'finance')") + + // update using an uncorrelated IN subquery that reads from a transactional catalog table + val (txn, txnTables) = executeTransaction { + sql( + s"""UPDATE $tableNameAsString + |SET salary = -1 + |WHERE pk IN (SELECT pk FROM $sourceNameAsString WHERE dep = 'hr') + |""".stripMargin) + } + + // check txn was properly committed and closed + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 2) + + // check source table was scanned correctly (dep = 'hr' filter in the subquery) + val sourceTxnTable = txnTables(sourceNameAsString) + val expectedNumSourceScans = if (deltaUpdate) 1 else 4 + assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans) + + val numSubquerySourceScans = sourceTxnTable.scanEvents.flatten.count { + case sources.EqualTo("dep", "hr") => true + case _ => false + } + assert(numSubquerySourceScans == expectedNumSourceScans) + + // check target table was scanned correctly + val targetTxnTable = txnTables(tableNameAsString) + val expectedNumTargetScans = if (deltaUpdate) 1 else 3 + assert(targetTxnTable.scanEvents.size == expectedNumTargetScans) + + // check txn state was propagated correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, -1, "hr"), // updated (pk 1 is in subquery result) + Row(2, 200, "software"), // unchanged + Row(3, 300, "hr"))) 
// unchanged (pk 3 not in subquery result) + } + + test("update with uncorrelated scalar subquery on source table and transactional checks") { + // create target table + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // create source table + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'hr'), (4, 400, 'finance')") + + // update using an uncorrelated scalar subquery in the SET clause that reads from a + // transactional catalog table; scalar subqueries are executed as SubqueryExec at runtime + // and cannot be rewritten as joins + val (txn, txnTables) = executeTransaction { + sql( + s"""UPDATE $tableNameAsString + |SET salary = (SELECT max(salary) FROM $sourceNameAsString WHERE dep = 'hr') + |WHERE dep = 'hr' + |""".stripMargin) + } + + // check txn was properly committed and closed + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 2) + + // check source table was scanned via the transaction catalog + val sourceTxnTable = txnTables(sourceNameAsString) + assert(sourceTxnTable.scanEvents.nonEmpty) + assert(sourceTxnTable.scanEvents.flatten.exists { + case sources.EqualTo("dep", "hr") => true + case _ => false + }) + + // check target table was scanned via the transaction catalog + val targetTxnTable = txnTables(tableNameAsString) + assert(targetTxnTable.scanEvents.nonEmpty) + + // check txn state was propagated correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 150, "hr"), // updated (max salary in source for 'hr' is 150) + Row(2, 200, "software"), // unchanged + Row(3, 150, "hr"))) // updated + } + + test("update with constraint violation and transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", 
+ """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + val exception = intercept[SparkRuntimeException] { + executeTransaction { + sql( + s"""UPDATE $tableNameAsString + |SET pk = NULL + |WHERE dep = 'hr' + |""".stripMargin) // NULL violates NOT NULL constraint + } + } + + assert(exception.getMessage.contains("NOT_NULL_ASSERT_VIOLATION")) + assert(catalog.lastTransaction.currentState == Aborted) + assert(catalog.lastTransaction.isClosed) + } + + test("update using view with transactional checks") { + withView("temp_view") { + // create target table + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // create source table + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT)") + sql(s"INSERT INTO $sourceNameAsString (pk, salary) VALUES (1, 150), (4, 400)") + + // create view on top of source and target tables + sql( + s"""CREATE VIEW temp_view AS + |SELECT s.pk, s.salary, t.dep + |FROM $sourceNameAsString s + |LEFT JOIN ( + | SELECT * FROM $tableNameAsString WHERE pk < 10 + |) t ON s.pk = t.pk + |""".stripMargin) + + // update target table using view + val (txn, txnTables) = executeTransaction { + sql( + s"""UPDATE $tableNameAsString t + |SET salary = -1 + |WHERE t.dep = 'hr' AND EXISTS (SELECT 1 FROM temp_view v WHERE v.pk = t.pk) + |""".stripMargin) + } + + // check txn covers both tables and was properly committed and closed + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 2) + + // check target table was scanned correctly + val targetTxnTable = txnTables(tableNameAsString) + val expectedNumTargetScans = if (deltaUpdate) 2 else 7 + assert(targetTxnTable.scanEvents.size == expectedNumTargetScans) + + // check target table scans as 
UPDATE target (dep = 'hr') + val numUpdateTargetScans = targetTxnTable.scanEvents.flatten.count { + case sources.EqualTo("dep", "hr") => true + case _ => false + } + val expectedNumUpdateTargetScans = if (deltaUpdate) 1 else 3 + assert(numUpdateTargetScans == expectedNumUpdateTargetScans) + + // check target table scans in view as source (pk < 10) + val numViewTargetScans = targetTxnTable.scanEvents.flatten.count { + case sources.LessThan("pk", 10L) => true + case _ => false + } + val expectedNumViewTargetScans = if (deltaUpdate) 1 else 4 + assert(numViewTargetScans == expectedNumViewTargetScans) + + // check source table scans in view + val sourceTxnTable = txnTables(sourceNameAsString) + val expectedNumSourceScans = if (deltaUpdate) 1 else 4 + assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans) + + // check txn state was propagated correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, -1, "hr"), // updated from view + Row(2, 200, "software"), // unchanged + Row(3, 300, "hr"))) // unchanged (no matching pk in source) + } + } + + test("df.explain() on update with transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + // NOTE: df.explain() on a DML command actually executes the write. + // TODO(achatzis): This is existing behavior but check why this is OK. Shouldn't sql() be lazy? 
+ sql(s"UPDATE $tableNameAsString SET salary = -1 WHERE dep = 'hr'").explain() + + assert(catalog.lastTransaction != null) + assert(catalog.lastTransaction.currentState == Committed) + assert(catalog.lastTransaction.isClosed) + + // the UPDATE was actually executed, not just planned + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, -1, "hr"), // updated + Row(2, 200, "software"))) // unchanged + } + + test("EXPLAIN UPDATE SQL with transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + // EXPLAIN UPDATE only plans the command, it does not execute the write. + sql(s"EXPLAIN UPDATE $tableNameAsString SET salary = -1 WHERE dep = 'hr'") + + // A transaction should not have started at all. + assert(catalog.transaction === null) + + // The UPDATE was not executed. Data is unchanged. + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr"), + Row(2, 200, "software"))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AnalyzerBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AnalyzerBenchmark.scala new file mode 100644 index 0000000000000..141d5966b4b6c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AnalyzerBenchmark.scala @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.benchmark + +import scala.concurrent.duration._ + +import org.apache.spark.benchmark.Benchmark +import org.apache.spark.sql.classic +import org.apache.spark.sql.execution.QueryExecution + +/** + * Benchmark to measure the overhead of cloning the analyzer for transactional query execution. + * Each transactional query creates a new [[Analyzer]] instance via + * [[Analyzer.withCatalogManager]], which shares all rules with the original but carries a + * per-query [[org.apache.spark.sql.connector.catalog.CatalogManager]]. This benchmark checks + * whether the cloning introduces measurable overhead. + * + * To run this benchmark: + * {{{ + * build/sbt "sql/Test/runMain " + * }}} + */ +object AnalyzerBenchmark extends SqlBasedBenchmark { + + private val numRows = 100 + private val queries = Seq( + "simple select" -> "SELECT id, val FROM t1", + "join" -> "SELECT t1.id, t2.val FROM t1 JOIN t2 ON t1.id = t2.id", + "wide schema" -> s"SELECT ${(1 to 100).map(i => s"col_$i").mkString(", ")} FROM wide_t" + ) + + private def setupTables(): Unit = { + spark.range(numRows).selectExpr("id", "id * 2 as val").createOrReplaceTempView("t1") + spark.range(numRows).selectExpr("id", "id * 3 as val").createOrReplaceTempView("t2") + spark.range(numRows) + .selectExpr((1 to numRows).map(i => s"id as col_$i"): _*) + .createOrReplaceTempView("wide_t") + } + + /** + * Measures analysis time for a pre-parsed plan, comparing the session analyzer against a + * cloned analyzer created via [[Analyzer.withCatalogManager]]. 
+ * + * Two cases: + * - "session analyzer" : baseline, uses the session analyzer directly. + * - "cloned analyzer (per query)": analyzer cloned every iteration; reflects the full + * per-transactional-query cost (clone + analysis). + */ + def analysisBenchmark(): Unit = { + for ((name, sql) <- queries) { + runBenchmark(s"analysis overhead $name") { + val benchmark = new Benchmark( + name = s"analysis overhead $name", + // Per row measurements are not meaningful here. + valuesPerIteration = numRows, + minTime = 10.seconds, + output = output) + val catalogManager = spark.sessionState.catalogManager + + benchmark.addCase("session analyzer") { _ => + val plan = spark.sessionState.sqlParser.parsePlan(sql) + new QueryExecution(spark.asInstanceOf[classic.SparkSession], plan).analyzed + } + + benchmark.addCase("cloned analyzer (per query)") { _ => + val cloned = spark.sessionState.analyzer.withCatalogManager(catalogManager) + val plan = spark.sessionState.sqlParser.parsePlan(sql) + new QueryExecution(spark.asInstanceOf[classic.SparkSession], + plan, analyzerOpt = Some(cloned)).analyzed + } + + benchmark.run() + } + } + } + + /** + * Micro-benchmark for [[Analyzer.withCatalogManager]] in isolation: measures the cost of + * instantiating the anonymous [[Analyzer]] subclass, independent of analysis work. + */ + def cloneCostBenchmark(): Unit = { + runBenchmark("analyzer clone cost") { + val numRows = 1 // Per row measurements are not meaningful here. 
+ val benchmark = new Benchmark( + name = "analyzer clone cost", + valuesPerIteration = numRows, + output = output) + val catalogManager = spark.sessionState.catalogManager + + benchmark.addCase("withCatalogManager") { _ => + spark.sessionState.analyzer.withCatalogManager(catalogManager) + } + + benchmark.run() + } + } + + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + setupTables() + cloneCostBenchmark() + analysisBenchmark() + } +} From a6a1eb70f207e9236354720327af06f1a6d2e538 Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Thu, 2 Apr 2026 13:52:01 +0000 Subject: [PATCH 02/22] Fix delete failures --- .../sql/connector/DeleteFromTableSuiteBase.scala | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala index e5cf1d2b53975..3d188af4701a8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala @@ -722,7 +722,7 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { assert(txnTables.size == 2) val sourceTxnTable = txnTables(sourceNameAsString) - val expectedNumSourceScans = if (deltaDelete) 1 else 4 + val expectedNumSourceScans = if (deltaDelete) 1 else 2 assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans) val numSubquerySourceScans = sourceTxnTable.scanEvents.flatten.count { @@ -732,7 +732,7 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { assert(numSubquerySourceScans == expectedNumSourceScans) val targetTxnTable = txnTables(tableNameAsString) - val expectedNumTargetScans = if (deltaDelete) 1 else 3 + val expectedNumTargetScans = if (deltaDelete) 1 else 2 assert(targetTxnTable.scanEvents.size == expectedNumTargetScans) checkAnswer( @@ 
-767,11 +767,11 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { assert(txnTables.size == 2) val targetTxnTable = txnTables(tableNameAsString) - val expectedNumTargetScans = if (deltaDelete) 1 else 3 + val expectedNumTargetScans = if (deltaDelete) 1 else 2 assert(targetTxnTable.scanEvents.size == expectedNumTargetScans) val sourceTxnTable = txnTables(sourceNameAsString) - val expectedNumSourceScans = if (deltaDelete) 1 else 4 + val expectedNumSourceScans = if (deltaDelete) 1 else 2 assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans) val numCteSourceScans = sourceTxnTable.scanEvents.flatten.count { @@ -812,11 +812,11 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { assert(txnTables.size == 2) val targetTxnTable = txnTables(tableNameAsString) - val expectedNumTargetScans = if (deltaDelete) 1 else 3 + val expectedNumTargetScans = if (deltaDelete) 1 else 2 assert(targetTxnTable.scanEvents.size == expectedNumTargetScans) val sourceTxnTable = txnTables(sourceNameAsString) - val expectedNumSourceScans = if (deltaDelete) 1 else 4 + val expectedNumSourceScans = if (deltaDelete) 1 else 2 assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans) checkAnswer( From 4e82fbf7db03950044e792b256f4348b88faea75 Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Thu, 9 Apr 2026 19:05:49 +0000 Subject: [PATCH 03/22] Cleaning pass 1 --- .../sql/catalyst/analysis/Analyzer.scala | 12 ++++---- .../analysis/RelationResolution.scala | 11 +++++-- .../UnresolveTransactionRelations.scala | 12 ++++++-- .../catalyst/analysis/V2TableReference.scala | 18 +++++------- .../catalyst/plans/logical/statements.scala | 4 +++ .../catalyst/plans/logical/v2Commands.scala | 20 +++++++++---- .../sql/connector/catalog/LookupCatalog.scala | 4 ++- ...nMemoryRowLevelOperationTableCatalog.scala | 7 +++-- .../sql/connector/catalog/InMemoryTable.scala | 5 ++-- .../catalog/InMemoryTableCatalog.scala | 4 --- 
.../spark/sql/connector/catalog/txns.scala | 29 +++++++++++-------- .../apache/spark/sql/classic/Catalog.scala | 17 ----------- .../spark/sql/execution/CacheManager.scala | 3 -- 13 files changed, 76 insertions(+), 70 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 9c3bc2c29ec56..09928553839ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -339,9 +339,8 @@ class Analyzer( /** * Returns a copy of this analyzer that uses the given [[CatalogManager]] for all catalog * lookups. All other configuration (extended rules, checks, etc.) is preserved. Used by - * [[QueryExecution]] to create a per-query analyzer for transactional queries so that - * transaction-aware catalog resolution is an instance-level property rather than thread-local - * state. + * [[QueryExecution]] to create a per-query analyzer for transactional operations for + * transaction-aware catalog resolution. */ def withCatalogManager(newCatalogManager: CatalogManager): Analyzer = { val self = this @@ -1032,9 +1031,10 @@ class Analyzer( } } - // Resolve V2TableReference nodes in a plan. V2TableReference is only created for temp views - // (via V2TableReference.createForTempView), so we only need to resolve it when returning - // the plan of temp views (in resolveViews and unwrapRelationPlan). + // Resolve V2TableReference nodes created for: + // 1 Temp views (via createForTempView). + // 2. Transaction references (via createForTransaction). These are resolved by a + // separate analysis batch in the transaction-aware analyzer instance. 
private def resolveTableReferences(plan: LogicalPlan): LogicalPlan = { plan.resolveOperatorsUp { case r: V2TableReference => relationResolution.resolveReference(r) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala index 1f78ca3371f59..feceec60488a1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala @@ -410,13 +410,18 @@ class RelationResolution( } private def loadRelation(ref: V2TableReference): LogicalPlan = { + // Resolve catalog. When a transaction is active we return the transaction + // aware catalog instance. val resolvedCatalog = catalogManager.catalog(ref.catalog.name).asTableCatalog val table = resolvedCatalog.loadTable(ref.identifier) - // val table = ref.catalog.loadTable(ref.identifier) V2TableReferenceUtils.validateLoadedTable(table, ref) - // ref.toRelation(table) + // Create relation with resolved Catalog. 
DataSourceV2Relation( - table, ref.output, Some(resolvedCatalog), Some(ref.identifier), ref.options) + table = table, + output = ref.output, + catalog = Some(resolvedCatalog), + identifier = Some(ref.identifier), + options = ref.options) } private def adaptCachedRelation(cached: LogicalPlan, ref: V2TableReference): LogicalPlan = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolveTransactionRelations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolveTransactionRelations.scala index 4b175dd44ef08..0e344173d7892 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolveTransactionRelations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolveTransactionRelations.scala @@ -23,6 +23,15 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, LookupCatalog} import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +/** + * When a transaction is active, converts resolved [[DataSourceV2Relation]] nodes back to + * [[V2TableReference]] placeholders for all relations loaded by a catalog with the same + * name as the transaction catalog. + * + * This forces re-resolution of those relations against the transaction's catalog, which + * intercepts [[TableCatalog#loadTable]] calls to track which tables are read as part of + * the transaction. 
+ */ class UnresolveTransactionRelations(val catalogManager: CatalogManager) extends Rule[LogicalPlan] with LookupCatalog { @@ -43,14 +52,13 @@ class UnresolveTransactionRelations(val catalogManager: CatalogManager) catalog: CatalogPlugin): LogicalPlan = { plan transform { case r: DataSourceV2Relation if isLoadedFromCatalog(r, catalog) => - V2TableReference.createForRelation(r, Seq.empty) + V2TableReference.createForTransaction(r) } } private def isLoadedFromCatalog( relation: DataSourceV2Relation, catalog: CatalogPlugin): Boolean = { - // relation.catalog.exists(_ eq catalog) && relation.identifier.isDefined relation.catalog.exists(_.name == catalog.name) && relation.identifier.isDefined } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala index a2379f33e14ee..76226056ffe65 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.analysis.V2TableReference.Context import org.apache.spark.sql.catalyst.analysis.V2TableReference.TableInfo import org.apache.spark.sql.catalyst.analysis.V2TableReference.TemporaryViewContext -import org.apache.spark.sql.catalyst.analysis.V2TableReference.TestContext +import org.apache.spark.sql.catalyst.analysis.V2TableReference.TransactionContext import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.LeafNode import org.apache.spark.sql.catalyst.plans.logical.Statistics @@ -85,17 +85,15 @@ private[sql] object V2TableReference { sealed trait Context case class TemporaryViewContext(viewName: Seq[String]) extends Context - // TODO(achatzis): Fix naming and complete implementation. 
- case class TestContext(tableName: Seq[String]) extends Context + /** Context for relations that are re-resolved through a transaction catalog. */ + case object TransactionContext extends Context def createForTempView(relation: DataSourceV2Relation, viewName: Seq[String]): V2TableReference = { create(relation, TemporaryViewContext(viewName)) } - def createForRelation( - relation: DataSourceV2Relation, - relationName: Seq[String]): V2TableReference = { - create(relation, TestContext(relationName)) + def createForTransaction(relation: DataSourceV2Relation): V2TableReference = { + create(relation, TransactionContext) } private def create(relation: DataSourceV2Relation, context: Context): V2TableReference = { @@ -119,7 +117,7 @@ private[sql] object V2TableReferenceUtils extends SQLConfHelper { ref.context match { case ctx: TemporaryViewContext => validateLoadedTableInTempView(table, ref, ctx) - case _: TestContext => + case TransactionContext => validateLoadedTableInTransaction(table, ref) case ctx => throw SparkException.internalError(s"Unknown table ref context: ${ctx.getClass.getName}") @@ -128,8 +126,8 @@ private[sql] object V2TableReferenceUtils extends SQLConfHelper { private def validateLoadedTableInTransaction(table: Table, ref: V2TableReference): Unit = { val dataErrors = V2TableUtil.validateCapturedColumns( - table, - ref.info.columns, + table = table, + originCols = ref.info.columns, mode = PROHIBIT_CHANGES) if (dataErrors.nonEmpty) { throw QueryCompilationErrors.columnsChangedAfterAnalysis(ref.name, dataErrors) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala index fb54af2344d1b..774c783ecf8a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala @@ -188,6 +188,10 @@ case 
class InsertIntoStatement( byName: Boolean = false, replaceCriteriaOpt: Option[InsertReplaceCriteria] = None, withSchemaEvolution: Boolean = false) + // Extends TransactionalWrite so that QueryExecution can detect a potential transaction on the + // unresolved logical plan before analysis runs. InsertIntoStatement is shared between V1 and V2 + // inserts, but the LookupCatalog.TransactionalWrite extractor only matches when the target + // catalog implements TransactionalCatalogPlugin, so V1 inserts are never assigned a transaction. extends UnaryParsedStatement with TransactionalWrite { require(overwrite || !ifPartitionNotExists, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index 5406c5d6a35fc..f9e7774e1448d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -520,8 +520,10 @@ case class WriteDelta( trait V2CreateTableAsSelectPlan extends V2CreateTablePlan with AnalysisOnlyCommand - with CTEInChildren { + with CTEInChildren + with TransactionalWrite { def query: LogicalPlan + override def table: LogicalPlan = name override def withCTEDefs(cteDefs: Seq[CTERelationDef]): LogicalPlan = { withNameAndQuery(newName = name, newQuery = WithCTE(query, cteDefs)) @@ -1000,13 +1002,13 @@ case class MergeIntoTable( notMatchedActions: Seq[MergeAction], notMatchedBySourceActions: Seq[MergeAction], withSchemaEvolution: Boolean) - extends BinaryCommand - with WriteWithSchemaEvolution - with TransactionalWrite - with SupportsSubquery { + extends BinaryCommand + with WriteWithSchemaEvolution + with SupportsSubquery + with TransactionalWrite { // Implements SupportsSchemaEvolution.table. - // Implements TransactionalWrite.table, identifying the MERGE target as the table being written. 
+ // Implements TransactionalWrite.table. override val table: LogicalPlan = EliminateSubqueryAliases(targetTable) override def withNewTable(newTable: NamedRelation): MergeIntoTable = { @@ -1277,6 +1279,12 @@ case class Assignment(key: Expression, value: Expression) extends Expression newLeft: Expression, newRight: Expression): Assignment = copy(key = newLeft, value = newRight) } +/** + * Marker trait for write operations that participate in a DSv2 transaction. + * + * Implementations are expected to target a DSv2 catalog backed by a + * [[org.apache.spark.sql.connector.catalog.TransactionalCatalogPlugin]]. + */ trait TransactionalWrite extends LogicalPlan { def table: LogicalPlan } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala index fbb2938fd3da2..dd5be45bfc5f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.connector.catalog import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedRelation} +import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedIdentifier, UnresolvedRelation} import org.apache.spark.sql.catalyst.plans.logical.TransactionalWrite import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} @@ -171,6 +171,8 @@ private[sql] trait LookupCatalog extends Logging { EliminateSubqueryAliases(write.table) match { case UnresolvedRelation(CatalogAndIdentifier(c: TransactionalCatalogPlugin, _), _, _) => Some(c) + case UnresolvedIdentifier(CatalogAndIdentifier(c: TransactionalCatalogPlugin, _), _) => + Some(c) case _ => None } 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala index 27231447a1273..7ba1e9747f52e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala @@ -21,17 +21,18 @@ import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException import org.apache.spark.sql.connector.catalog.transactions.{Transaction, TransactionInfo} class InMemoryRowLevelOperationTableCatalog - extends InMemoryTableCatalog with TransactionalCatalogPlugin { + extends InMemoryTableCatalog + with TransactionalCatalogPlugin { import CatalogV2Implicits._ + // The current active transaction. var transaction: Txn = _ - // Tracks the last completed transaction for test assertions; cleared when a new one begins. + // The last completed transaction. var lastTransaction: Txn = _ override def beginTransaction(info: TransactionInfo): Transaction = { assert(transaction == null || transaction.currentState != Active) this.transaction = new Txn(new TxnTableCatalog(this)) - this.lastTransaction = transaction transaction } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala index 2f3c65924d6a7..15ed4136dbda8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala @@ -216,9 +216,8 @@ class InMemoryTable( object InMemoryTable { - // V1 filter values (from PredicateUtils.toV1) are Scala types (e.g. 
String), but partition - // keys stored in dataMap are Catalyst internal types (e.g. UTF8String). Normalize both sides - // before comparing so that string partitions work correctly. + // Convert UTF8String to string to make sure equality checks between filters and partitions + // work correctly. private def valuesEqual(filterValue: Any, partitionValue: Any): Boolean = (filterValue, partitionValue) match { case (s: String, u: UTF8String) => u.toString == s diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala index c7195b512b8d6..ff7995ad6697e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala @@ -59,10 +59,6 @@ class BasicInMemoryTableCatalog extends TableCatalog { tables.keySet.asScala.filter(_.namespace.sameElements(namespace)).toArray } - def loadTableAs[T <: Table](ident: Identifier): T = { - loadTable(ident).asInstanceOf[T] - } - // load table for scans override def loadTable(ident: Identifier): Table = { Option(tables.get(ident)) match { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala index 4feb89c78f56c..f4f56d59f7851 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala @@ -42,13 +42,13 @@ class Txn(override val catalog: TxnTableCatalog) extends Transaction { override def commit(): Unit = { if (closed) throw new IllegalStateException("Can't commit, already closed") + if (state == Aborted) throw new IllegalStateException("Can't commit, already aborted") catalog.commit() this.state = Committed } override def abort(): 
Unit = { if (state == Committed || state == Aborted) return - // if (closed) throw new IllegalStateException("Can't abort, already closed") this.state = Aborted } @@ -58,9 +58,9 @@ class Txn(override val catalog: TxnTableCatalog) extends Transaction { } } -// a special table used in row-level operation transactions -// it inherits data from the base table upon construction and -// propagates staged transaction state back after an explicit commit +// A special table used in row-level operation transactions. It inherits data +// from the base table upon construction and propagates staged transaction state +// back after an explicit commit. class TxnTable(val delegate: InMemoryRowLevelOperationTable) extends InMemoryRowLevelOperationTable( delegate.name, @@ -72,8 +72,7 @@ class TxnTable(val delegate: InMemoryRowLevelOperationTable) // TODO(achatzis): Rethink how schema evolution works on top of transactions. alterTableWithData(delegate.data, schema) - // a tracker of filters used in each scan - // achatzis: Non-deterministic filters? + // A tracker of filters used in each scan. val scanEvents = new ArrayBuffer[Array[Filter]]() override protected def recordScanEvent(filters: Array[Filter]): Unit = { @@ -92,8 +91,8 @@ class TxnTable(val delegate: InMemoryRowLevelOperationTable) } } -// a special table catalog used in row-level operation transactions -// table changes are initially staged in memory and propagated only after an explicit commit +// A special table catalog used in row-level operation transactions. +// Table changes are initially staged in memory and propagated only after an explicit commit. 
class TxnTableCatalog(delegate: InMemoryRowLevelOperationTableCatalog) extends TableCatalog { private val tables: util.Map[Identifier, TxnTable] = new ConcurrentHashMap[Identifier, TxnTable]() @@ -108,20 +107,25 @@ class TxnTableCatalog(delegate: InMemoryRowLevelOperationTableCatalog) extends T override def loadTable(ident: Identifier): Table = { tables.computeIfAbsent(ident, _ => { - val table = delegate.loadTableAs[InMemoryRowLevelOperationTable](ident) + val table = delegate.loadTable(ident).asInstanceOf[InMemoryRowLevelOperationTable] new TxnTable(table) }) } + override def createTable(ident: Identifier, tableInfo: TableInfo): Table = { + delegate.createTable(ident, tableInfo) + loadTable(ident) + } + override def alterTable(ident: Identifier, changes: TableChange*): Table = { val newDelegateTable = delegate.alterTable(ident, changes: _*) - // Compute again if absent. - tables.remove(ident) + tables.remove(ident) // Load again. newDelegateTable } override def dropTable(ident: Identifier): Boolean = { - throw new UnsupportedOperationException() + tables.remove(ident) + delegate.dropTable(ident) } override def renameTable(oldIdent: Identifier, newIdent: Identifier): Unit = { @@ -133,6 +137,7 @@ class TxnTableCatalog(delegate: InMemoryRowLevelOperationTableCatalog) extends T } def clearActiveTransaction(): Unit = { + delegate.lastTransaction = delegate.transaction delegate.transaction = null } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/Catalog.scala index 81fc534423d8c..35041feca9e18 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/Catalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/Catalog.scala @@ -931,23 +931,6 @@ class Catalog(sparkSession: SparkSession) extends catalog.Catalog with Logging { // caches referencing this relation. 
If this relation is cached as an InMemoryRelation, // this will clear the relation cache and caches of all its dependents. CommandUtils.recacheTableOrView(sparkSession, relation) - /* - EliminateSubqueryAliases(relation) match { - case r @ ExtractV2CatalogAndIdentifier(catalog, ident) if r.timeTravelSpec.isEmpty => - val nameParts = ident.toQualifiedNameParts(catalog) - sparkSession.sharedState.cacheManager.recacheTableOrView(sparkSession, nameParts) - case _ => - sparkSession.sharedState.cacheManager.recacheByPlan(sparkSession, relation) - */ - /* - relation match { - case r: DataSourceV2Relation if r.catalog.isDefined && r.identifier.isDefined => - val nameParts = r.identifier.get.toQualifiedNameParts(r.catalog.get) - sparkSession.sharedState.cacheManager.recacheTableOrView(sparkSession, nameParts) - case _ => - sparkSession.sharedState.cacheManager.recacheByPlan(sparkSession, relation) - } - */ } private def resolveRelation(tableName: String): LogicalPlan = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index 66f406d39f263..3f92f24156d3c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -260,9 +260,6 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper { val nameInCache = v2Ident.toQualifiedNameParts(catalog) isSameName(name, nameInCache, resolver) && (includeTimeTravel || timeTravelSpec.isEmpty) - // case r: TableReference => - // isSameName(name, r.identifier.toQualifiedNameParts(r.catalog), resolver) - case v: View => isSameName(name, v.desc.identifier.nameParts, resolver) From 4f8d0a05e95dee203afd154ff87e3ab8fffe5dd9 Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Tue, 7 Apr 2026 12:06:10 +0000 Subject: [PATCH 04/22] CTAS/RTAS support plus more cleaning --- .../analysis/RelationResolution.scala | 2 
- .../TransactionAwareCatalogManager.scala | 7 +- ...nMemoryRowLevelOperationTableCatalog.scala | 4 +- .../spark/sql/connector/catalog/txns.scala | 4 +- .../spark/sql/execution/QueryExecution.scala | 49 +++--- .../datasources/v2/DataSourceV2Strategy.scala | 9 +- .../v2/WriteToDataSourceV2Exec.scala | 47 ++++-- .../sql/connector/CTASTransactionSuite.scala | 140 ++++++++++++++++++ .../RowLevelOperationSuiteBase.scala | 13 +- 9 files changed, 230 insertions(+), 45 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/connector/CTASTransactionSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala index feceec60488a1..fd0394098b0ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala @@ -384,8 +384,6 @@ class RelationResolution( } } - // TODO: how to validate the output is compatible? - // TODO: what shall we do if the output mismatches (schema changes?) 
def resolveReference(ref: V2TableReference): LogicalPlan = { val relation = getOrLoadRelation(ref) val planId = ref.getTagValue(LogicalPlan.PLAN_ID_TAG) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/TransactionAwareCatalogManager.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/TransactionAwareCatalogManager.scala index 9403219f596da..aaeef4c2dea76 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/TransactionAwareCatalogManager.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/TransactionAwareCatalogManager.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.connector.catalog +import org.apache.spark.sql.catalyst.catalog.TempVariableManager import org.apache.spark.sql.connector.catalog.transactions.Transaction /** @@ -25,13 +26,15 @@ import org.apache.spark.sql.connector.catalog.transactions.Transaction * All mutable state (current catalog, current namespace, loaded catalogs) is delegated to the * wrapped [[CatalogManager]]. */ -// TODO: Consider extracting a CatalogManager trait that both the real -// implementation and the decorator implement +// TODO: Extracting a CatalogManager trait (so this class can implement it instead of extending +// CatalogManager) would eliminate the inherited mutable state that this decorator doesn't use. 
private[sql] class TransactionAwareCatalogManager( delegate: CatalogManager, txn: Transaction) extends CatalogManager(delegate.defaultSessionCatalog, delegate.v1SessionCatalog) { + override val tempVariableManager: TempVariableManager = delegate.tempVariableManager + override def transaction: Option[Transaction] = Some(txn) override def catalog(name: String): CatalogPlugin = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala index 7ba1e9747f52e..78c350b5145a9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala @@ -21,8 +21,8 @@ import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException import org.apache.spark.sql.connector.catalog.transactions.{Transaction, TransactionInfo} class InMemoryRowLevelOperationTableCatalog - extends InMemoryTableCatalog - with TransactionalCatalogPlugin { + extends InMemoryTableCatalog + with TransactionalCatalogPlugin { import CatalogV2Implicits._ // The current active transaction. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala index f4f56d59f7851..e76fdf97c6888 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala @@ -69,7 +69,7 @@ class TxnTable(val delegate: InMemoryRowLevelOperationTable) delegate.properties, delegate.constraints) { - // TODO(achatzis): Rethink how schema evolution works on top of transactions. + // TODO: Revise schema evolution. 
alterTableWithData(delegate.data, schema) // A tracker of filters used in each scan. @@ -81,7 +81,7 @@ class TxnTable(val delegate: InMemoryRowLevelOperationTable) def commit(): Unit = { delegate.dataMap.clear() - // TODO(achatzis): Rethink how schema evolution works on top of transactions. + // TODO: Revise schema evolution. delegate.alterTableWithData(data, delegate.schema) delegate.replacedPartitions = replacedPartitions delegate.lastWriteInfo = lastWriteInfo diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 32ffbfc32acb3..638419f900b34 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -99,8 +99,8 @@ class QueryExecution( // 1. At the pre-Analyzed plan we look for nodes that implement the TransactionalWrite trait. // When a plan contains such a node we initiate a transaction. Note, we should never start // a transaction for operations that are not executed, e.g. EXPLAIN. - // 2. Create an analyzer clone with a transaction aware Catalog Manager. The latter is the single - // choke point of all catalog access, and it is also the transaction context carrier. + // 2. Create an analyzer clone with a transaction aware Catalog Manager. The latter is the + // narrow waist of all catalog accesses, and it is also the transaction context carrier. // This is then passed to all rules during analysis that need to check the catalog. Rules // that are specifically interested in transactionality can access the transaction directly // from the Catalog Manager. 
The transaction catalog is potentially the place where connectors @@ -280,15 +280,18 @@ class QueryExecution( sparkSession.withActive { assertAnalyzed() assertSupported() - // clone the plan to avoid sharing the plan instance between different stages like analyzing, - // optimizing and planning. - val plan = normalized.clone() - // During a transaction, skip cache substitution. useCachedData replaces DataSourceV2Relation - // nodes (loaded via the transaction catalog) with InMemoryRelation, which bypasses read - // tracking in the transaction catalog and may serve stale data. - // if (transactionOpt.isDefined) plan - // else sparkSession.sharedState.cacheManager.useCachedData(plan) - sparkSession.sharedState.cacheManager.useCachedData(plan) + + // During a transaction, skip cache substitution. This is to avoid replacing relations + // loaded by the transactional catalog with potentially stale relations cached before + // the transaction was active. + val plan = if (transactionOpt.isDefined) { + plan + } + else { + // clone the plan to avoid sharing the plan instance between different stages like + // analyzing, optimizing and planning. + sparkSession.sharedState.cacheManager.useCachedData(normalized.clone()) + } } } @@ -331,11 +334,7 @@ class QueryExecution( // plan. QueryExecution.createSparkPlan(planner, optimizedPlan.clone()) } - transactionOpt match { - case Some(txn) => - plan.transformDown { case w: TransactionalExec => w.withTransaction(Some(txn)) } - case None => plan - } + attachTransaction(plan) } def sparkPlan: SparkPlan = executeWithTransactionContext { @@ -586,14 +585,11 @@ class QueryExecution( } /** - * Execute the given block with the transaction context if exists. If there is an exception thrown - * during the execution, the transaction will be aborted. + * Executes the given block with the transaction context if one exists. If there is an exception + * thrown during the execution, the transaction will be aborted.
* - * Note 1: The transaction is not committed in this method. The caller should commit the + * Note: The transaction is not committed in this method. The caller should commit the transaction if the execution is successful. - * - * Note 2: In some cases, post commit execution might generate an exception. The abort operation - should be no-op in this case. */ private def executeWithTransactionContext[T](block: => T): T = transactionOpt match { case Some(transaction) => @@ -602,6 +598,15 @@ case None => block } + + /** Attaches the transaction, if one exists, to the transactional execution nodes of the given SparkPlan. */ + private def attachTransaction(plan: SparkPlan): SparkPlan = transactionOpt match { + case Some(txn) => plan.transformDown { + case w: TransactionalExec => w.withTransaction(Some(txn)) + } + case None => plan + } + /** A special namespace for commands that can be used to debug query execution. */ // scalastyle:off object debug { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 82fd732e9de8d..3d3b4d1cae11c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -62,6 +62,9 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat private def hadoopConf = session.sessionState.newHadoopConf() + + // Recaches all cache entries without time travel for the given table + after a write operation that moves the state of the table forward (e.g.
append, overwrite) + // this method accounts for V2 tables loaded via TableProvider (no catalog/identifier) private def refreshCache(r: DataSourceV2Relation)(): Unit = r match { case ExtractV2CatalogAndIdentifier(catalog, ident) => val nameParts = ident.toQualifiedNameParts(catalog) @@ -378,14 +381,12 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat case ReplaceData(_: DataSourceV2Relation, _, query, r: DataSourceV2Relation, projections, _, Some(write)) => // use the original relation to refresh the cache - ReplaceDataExec( - planLater(query), refreshCache(r), projections, write) :: Nil + ReplaceDataExec(planLater(query), refreshCache(r), projections, write) :: Nil case WriteDelta(_: DataSourceV2Relation, _, query, r: DataSourceV2Relation, projections, Some(write)) => // use the original relation to refresh the cache - WriteDeltaExec( - planLater(query), refreshCache(r), projections, write) :: Nil + WriteDeltaExec(planLater(query), refreshCache(r), projections, write) :: Nil case MergeRows(isSourceRowPresent, isTargetRowPresent, matchedInstructions, notMatchedInstructions, notMatchedBySourceInstructions, checkCardinality, output, child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index a548749a972e0..b860dc3347f9f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -75,7 +75,12 @@ case class CreateTableAsSelectExec( query: LogicalPlan, tableSpec: TableSpec, writeOptions: Map[String, String], - ifNotExists: Boolean) extends V2CreateTableAsSelectBaseExec { + ifNotExists: Boolean, + transaction: Option[Transaction] = None) + extends V2CreateTableAsSelectBaseExec with TransactionalExec { + + 
override def withTransaction(txn: Option[Transaction]): CreateTableAsSelectExec = + copy(transaction = txn) val properties = CatalogV2Util.convertTableProperties(tableSpec) @@ -93,7 +98,9 @@ case class CreateTableAsSelectExec( .build() val table = Option(catalog.createTable(ident, tableInfo)) .getOrElse(catalog.loadTable(ident, Set(TableWritePrivilege.INSERT).asJava)) - writeToTable(catalog, table, writeOptions, ident, query, overwrite = false) + val result = writeToTable(catalog, table, writeOptions, ident, query, overwrite = false) + transaction.foreach(TransactionUtils.commit) + result } } @@ -113,7 +120,13 @@ case class AtomicCreateTableAsSelectExec( query: LogicalPlan, tableSpec: TableSpec, writeOptions: Map[String, String], - ifNotExists: Boolean) extends V2CreateTableAsSelectBaseExec { + ifNotExists: Boolean, + transaction: Option[Transaction] = None) + extends V2CreateTableAsSelectBaseExec + with TransactionalExec { + + override def withTransaction(txn: Option[Transaction]): AtomicCreateTableAsSelectExec = + copy(transaction = txn) val properties = CatalogV2Util.convertTableProperties(tableSpec) @@ -134,7 +147,9 @@ case class AtomicCreateTableAsSelectExec( .build() val stagedTable = Option(catalog.stageCreate(ident, tableInfo) ).getOrElse(catalog.loadTable(ident, Set(TableWritePrivilege.INSERT).asJava)) - writeToTable(catalog, stagedTable, writeOptions, ident, query, overwrite = false) + val result = writeToTable(catalog, stagedTable, writeOptions, ident, query, overwrite = false) + transaction.foreach(TransactionUtils.commit) + result } } @@ -156,8 +171,12 @@ case class ReplaceTableAsSelectExec( tableSpec: TableSpec, writeOptions: Map[String, String], orCreate: Boolean, - invalidateCache: (TableCatalog, Identifier) => Unit) - extends V2CreateTableAsSelectBaseExec { + invalidateCache: (TableCatalog, Identifier) => Unit, + transaction: Option[Transaction] = None) + extends V2CreateTableAsSelectBaseExec with TransactionalExec { + + override def 
withTransaction(txn: Option[Transaction]): ReplaceTableAsSelectExec = + copy(transaction = txn) val properties = CatalogV2Util.convertTableProperties(tableSpec) @@ -193,9 +212,11 @@ case class ReplaceTableAsSelectExec( .build() val table = Option(catalog.createTable(ident, tableInfo)) .getOrElse(catalog.loadTable(ident, Set(TableWritePrivilege.INSERT).asJava)) - writeToTable( + val result = writeToTable( catalog, table, writeOptions, ident, refreshedQuery, overwrite = true, refreshPhaseEnabled = false) + transaction.foreach(TransactionUtils.commit) + result } } @@ -219,8 +240,12 @@ case class AtomicReplaceTableAsSelectExec( tableSpec: TableSpec, writeOptions: Map[String, String], orCreate: Boolean, - invalidateCache: (TableCatalog, Identifier) => Unit) - extends V2CreateTableAsSelectBaseExec { + invalidateCache: (TableCatalog, Identifier) => Unit, + transaction: Option[Transaction] = None) + extends V2CreateTableAsSelectBaseExec with TransactionalExec { + + override def withTransaction(txn: Option[Transaction]): AtomicReplaceTableAsSelectExec = + copy(transaction = txn) val properties = CatalogV2Util.convertTableProperties(tableSpec) @@ -261,7 +286,9 @@ case class AtomicReplaceTableAsSelectExec( } val table = Option(staged).getOrElse( catalog.loadTable(ident, Set(TableWritePrivilege.INSERT).asJava)) - writeToTable(catalog, table, writeOptions, ident, query, overwrite = true) + val result = writeToTable(catalog, table, writeOptions, ident, query, overwrite = true) + transaction.foreach(TransactionUtils.commit) + result } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/CTASTransactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/CTASTransactionSuite.scala new file mode 100644 index 0000000000000..1643fa6879525 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/CTASTransactionSuite.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor 
license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector + +import org.apache.spark.sql.Row +import org.apache.spark.sql.connector.catalog.Committed + +class CTASTransactionSuite extends RowLevelOperationSuiteBase { + + private val newTableNameAsString = "cat.ns1.new_table" + + test("CTAS with transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val (txn, txnTables) = executeTransactionMultiQE { + sql(s"""CREATE TABLE $newTableNameAsString + |AS SELECT * FROM $tableNameAsString WHERE dep = 'hr' + |""".stripMargin) + } + + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 1) + + val sourceTxnTable = txnTables(tableNameAsString) + assert(sourceTxnTable.scanEvents.size >= 1) + + checkAnswer( + sql(s"SELECT * FROM $newTableNameAsString"), + Seq(Row(1, 100, "hr"))) + } + + test("CTAS with cached source and transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + // cache the source table before running CTAS + 
spark.catalog.cacheTable(tableNameAsString) + + try { + val (txn, txnTables) = executeTransactionMultiQE { + sql(s"""CREATE TABLE $newTableNameAsString + |AS SELECT * FROM $tableNameAsString WHERE dep = 'hr' + |""".stripMargin) + } + + assert(txn.currentState == Committed) + assert(txn.isClosed) + + // cache miss: TxnTable-based relation is not structurally equal to the cached one, + // so the scan goes through the transaction catalog and scan events are captured + val sourceTxnTable = txnTables(tableNameAsString) + assert(sourceTxnTable.scanEvents.size >= 1) + + checkAnswer( + sql(s"SELECT * FROM $newTableNameAsString"), + Seq(Row(1, 100, "hr"))) + } finally { + spark.catalog.uncacheTable(tableNameAsString) + } + } + + test("RTAS with transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // pre-create the target so REPLACE TABLE (not CREATE OR REPLACE) is valid + sql(s"CREATE TABLE $newTableNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + + val (txn, txnTables) = executeTransactionMultiQE { + sql(s"""REPLACE TABLE $newTableNameAsString + |AS SELECT * FROM $tableNameAsString WHERE dep = 'hr' + |""".stripMargin) + } + + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 1) + + val sourceTxnTable = txnTables(tableNameAsString) + assert(sourceTxnTable.scanEvents.size >= 1) + + checkAnswer( + sql(s"SELECT * FROM $newTableNameAsString"), + Seq( + Row(1, 100, "hr"), + Row(3, 300, "hr"))) + } + + test("RTAS self-reference with transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // source and target are the same table: reads the old 
snapshot via TxnTable, + // replaces the table with a filtered version + val (txn, txnTables) = executeTransactionMultiQE { + sql(s"""CREATE OR REPLACE TABLE $tableNameAsString + |AS SELECT * FROM $tableNameAsString WHERE dep = 'hr' + |""".stripMargin) + } + + assert(txn.currentState == Committed) + assert(txn.isClosed) + + val sourceTxnTable = txnTables(tableNameAsString) + assert(sourceTxnTable.scanEvents.size >= 1) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr"), + Row(3, 300, "hr"))) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala index 73983283bcb8b..23c6a20fdfa7b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala @@ -145,7 +145,18 @@ abstract class RowLevelOperationSuiteBase (catalog.lastTransaction, indexByName(tables)) } - private def indexByName[T <: Table](tables: Seq[T]): Map[String, T] = { + protected def executeTransactionMultiQE(func: => Unit): (Txn, Map[String, TxnTable]) = { + val qes = withQueryExecutionsCaptured(spark)(func) + val tables = qes.flatMap { qe => + collectWithSubqueries(qe.executedPlan) { + case BatchScanExec(_, _, _, _, table: TxnTable, _) => table + case BatchScanExec(_, _, _, _, RowLevelOperationTable(table: TxnTable, _), _) => table + } + } + (catalog.lastTransaction, indexByName(tables)) + } + + protected def indexByName[T <: Table](tables: Seq[T]): Map[String, T] = { tables.groupBy(_.name).map { case (name, sameNameTables) => val Seq(table) = sameNameTables.distinct From 6b5b914ad478707a280dfe823a6307c7a3fb66e9 Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Tue, 7 Apr 2026 12:29:28 +0000 Subject: [PATCH 05/22] Fix comp error + refactor executeTransaction --- 
.../spark/sql/execution/QueryExecution.scala | 4 ++-- .../sql/connector/CTASTransactionSuite.scala | 8 +++---- .../RowLevelOperationSuiteBase.scala | 23 +++---------------- 3 files changed, 9 insertions(+), 26 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 638419f900b34..bd7731a96a3ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -284,8 +284,8 @@ class QueryExecution( // During a transaction, skip cache substitution. This is to avoid replacing relations // loaded by the transactional catalog with potentially stale relations cached before // the transaction was active. - val plan = if (transactionOpt.isDefined) { - plan + if (transactionOpt.isDefined) { + normalized } else { // clone the plan to avoid sharing the plan instance between different stages like diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/CTASTransactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/CTASTransactionSuite.scala index 1643fa6879525..ff055039173c8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/CTASTransactionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/CTASTransactionSuite.scala @@ -30,7 +30,7 @@ class CTASTransactionSuite extends RowLevelOperationSuiteBase { |{ "pk": 2, "salary": 200, "dep": "software" } |""".stripMargin) - val (txn, txnTables) = executeTransactionMultiQE { + val (txn, txnTables) = executeTransaction { sql(s"""CREATE TABLE $newTableNameAsString |AS SELECT * FROM $tableNameAsString WHERE dep = 'hr' |""".stripMargin) @@ -58,7 +58,7 @@ class CTASTransactionSuite extends RowLevelOperationSuiteBase { spark.catalog.cacheTable(tableNameAsString) try { - val (txn, txnTables) = executeTransactionMultiQE { + val (txn, txnTables) 
= executeTransaction { sql(s"""CREATE TABLE $newTableNameAsString |AS SELECT * FROM $tableNameAsString WHERE dep = 'hr' |""".stripMargin) @@ -90,7 +90,7 @@ class CTASTransactionSuite extends RowLevelOperationSuiteBase { // pre-create the target so REPLACE TABLE (not CREATE OR REPLACE) is valid sql(s"CREATE TABLE $newTableNameAsString (pk INT NOT NULL, salary INT, dep STRING)") - val (txn, txnTables) = executeTransactionMultiQE { + val (txn, txnTables) = executeTransaction { sql(s"""REPLACE TABLE $newTableNameAsString |AS SELECT * FROM $tableNameAsString WHERE dep = 'hr' |""".stripMargin) @@ -119,7 +119,7 @@ class CTASTransactionSuite extends RowLevelOperationSuiteBase { // source and target are the same table: reads the old snapshot via TxnTable, // replaces the table with a filtered version - val (txn, txnTables) = executeTransactionMultiQE { + val (txn, txnTables) = executeTransaction { sql(s"""CREATE OR REPLACE TABLE $tableNameAsString |AS SELECT * FROM $tableNameAsString WHERE dep = 'hr' |""".stripMargin) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala index 23c6a20fdfa7b..a8c9a0bc4ab00 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Delete, Id import org.apache.spark.sql.connector.expressions.LogicalExpressions.{identity, reference} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.write.RowLevelOperationTable -import org.apache.spark.sql.execution.{InSubqueryExec, QueryExecution, SparkPlan} +import org.apache.spark.sql.execution.{InSubqueryExec, SparkPlan} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import 
org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation, DataSourceV2ScanRelation} import org.apache.spark.sql.internal.SQLConf @@ -135,19 +135,7 @@ abstract class RowLevelOperationSuiteBase } protected def executeTransaction(func: => Unit): (Txn, Map[String, TxnTable]) = { - val qe = execute(func) - val tables = collectWithSubqueries(qe.executedPlan) { - case BatchScanExec(_, _, _, _, table: TxnTable, _) => - table - case BatchScanExec(_, _, _, _, RowLevelOperationTable(table: TxnTable, _), _) => - table - } - (catalog.lastTransaction, indexByName(tables)) - } - - protected def executeTransactionMultiQE(func: => Unit): (Txn, Map[String, TxnTable]) = { - val qes = withQueryExecutionsCaptured(spark)(func) - val tables = qes.flatMap { qe => + val tables = withQueryExecutionsCaptured(spark)(func).flatMap { qe => collectWithSubqueries(qe.executedPlan) { case BatchScanExec(_, _, _, _, table: TxnTable, _) => table case BatchScanExec(_, _, _, _, RowLevelOperationTable(table: TxnTable, _), _) => table @@ -166,13 +154,8 @@ abstract class RowLevelOperationSuiteBase // executes an operation and keeps the executed plan protected def executeAndKeepPlan(func: => Unit): SparkPlan = { - val qe = execute(func) - stripAQEPlan(qe.executedPlan) - } - - private def execute(func: => Unit): QueryExecution = { withQueryExecutionsCaptured(spark)(func) match { - case Seq(qe) => qe + case Seq(qe) => stripAQEPlan(qe.executedPlan) case other => fail(s"expected only one query execution, but got ${other.size}") } } From 5c1c7215101898997a6774adaf4f36d81d3ce626 Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Wed, 8 Apr 2026 08:22:35 +0000 Subject: [PATCH 06/22] Test improvements --- .../spark/sql/connector/catalog/txns.scala | 6 ++-- .../AppendDataTransactionSuite.scala | 7 ++++ ...e.scala => CTASRTASTransactionSuite.scala} | 5 ++- .../connector/DeleteFromTableSuiteBase.scala | 4 +++ .../connector/MergeIntoDataFrameSuite.scala | 36 +++++++++++++++++-- 
.../connector/MergeIntoTableSuiteBase.scala | 6 ++++ .../RowLevelOperationSuiteBase.scala | 16 ++++----- .../sql/connector/UpdateTableSuiteBase.scala | 10 ++++-- 8 files changed, 74 insertions(+), 16 deletions(-) rename sql/core/src/test/scala/org/apache/spark/sql/connector/{CTASTransactionSuite.scala => CTASRTASTransactionSuite.scala} (96%) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala index e76fdf97c6888..c6339d2099b63 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala @@ -53,8 +53,10 @@ class Txn(override val catalog: TxnTableCatalog) extends Transaction { } override def close(): Unit = { - catalog.clearActiveTransaction() - this.closed = true + if (!closed) { + catalog.clearActiveTransaction() + this.closed = true + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala index 1c8e7fc5a0fd4..379cf6df0f739 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala @@ -46,6 +46,7 @@ class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { assert(txn.currentState == Committed) assert(txn.isClosed) assert(txnTables.size == 1) + assert(table.version() == "2") // check the source scan was tracked via the transaction catalog val targetTxnTable = txnTables(tableNameAsString) @@ -75,6 +76,7 @@ class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { // check txn was properly committed and closed assert(txn.currentState == Committed) assert(txn.isClosed) + assert(table.version() == "2") // check data was inserted correctly 
checkAnswer( @@ -104,6 +106,7 @@ class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { assert(txn.currentState == Committed) assert(txn.isClosed) + assert(table.version() == "2") checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), @@ -131,6 +134,7 @@ class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { assert(txn.currentState == Committed) assert(txn.isClosed) + assert(table.version() == "2") checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), @@ -159,6 +163,7 @@ class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { assert(txn.currentState == Committed) assert(txn.isClosed) + assert(table.version() == "2") checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), @@ -186,6 +191,7 @@ class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { assert(txn.currentState == Committed) assert(txn.isClosed) + assert(table.version() == "2") checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), @@ -214,6 +220,7 @@ class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { assert(txn.currentState == Committed) assert(txn.isClosed) assert(txnTables.size == 1) + assert(table.version() == "2") // check data was inserted correctly checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/CTASTransactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/CTASRTASTransactionSuite.scala similarity index 96% rename from sql/core/src/test/scala/org/apache/spark/sql/connector/CTASTransactionSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/connector/CTASRTASTransactionSuite.scala index ff055039173c8..c58a78498f9f4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/CTASTransactionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/CTASRTASTransactionSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.connector import org.apache.spark.sql.Row import org.apache.spark.sql.connector.catalog.Committed -class 
CTASTransactionSuite extends RowLevelOperationSuiteBase { +class CTASRTASTransactionSuite extends RowLevelOperationSuiteBase { private val newTableNameAsString = "cat.ns1.new_table" @@ -39,6 +39,7 @@ class CTASTransactionSuite extends RowLevelOperationSuiteBase { assert(txn.currentState == Committed) assert(txn.isClosed) assert(txnTables.size == 1) + assert(table.version() == "2") val sourceTxnTable = txnTables(tableNameAsString) assert(sourceTxnTable.scanEvents.size >= 1) @@ -66,6 +67,7 @@ class CTASTransactionSuite extends RowLevelOperationSuiteBase { assert(txn.currentState == Committed) assert(txn.isClosed) + assert(table.version() == "2") // cache miss: TxnTable-based relation is not structurally equal to the cached one, // so the scan goes through the transaction catalog and scan events are captured @@ -99,6 +101,7 @@ class CTASTransactionSuite extends RowLevelOperationSuiteBase { assert(txn.currentState == Committed) assert(txn.isClosed) assert(txnTables.size == 1) + assert(table.version() == "2") val sourceTxnTable = txnTables(tableNameAsString) assert(sourceTxnTable.scanEvents.size >= 1) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala index 3d188af4701a8..378a3af2561b3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala @@ -694,6 +694,7 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { assert(txn.currentState == Committed) assert(txn.isClosed) + assert(table.version() == "2") checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), @@ -720,6 +721,7 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { assert(txn.currentState == Committed) assert(txn.isClosed) assert(txnTables.size == 2) + assert(table.version() == "2") val 
sourceTxnTable = txnTables(sourceNameAsString) val expectedNumSourceScans = if (deltaDelete) 1 else 2 @@ -765,6 +767,7 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { assert(txn.currentState == Committed) assert(txn.isClosed) assert(txnTables.size == 2) + assert(table.version() == "2") val targetTxnTable = txnTables(tableNameAsString) val expectedNumTargetScans = if (deltaDelete) 1 else 2 @@ -810,6 +813,7 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { assert(txn.currentState == Committed) assert(txn.isClosed) assert(txnTables.size == 2) + assert(table.version() == "2") val targetTxnTable = txnTables(tableNameAsString) val expectedNumTargetScans = if (deltaDelete) 1 else 2 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala index 687aae91438da..d58e22e63d71e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.connector import org.apache.spark.sql.{sources, Column, Row} import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.classic.MergeIntoWriter -import org.apache.spark.sql.connector.catalog.Committed +import org.apache.spark.sql.connector.catalog.{Aborted, Committed} import org.apache.spark.sql.connector.catalog.Identifier import org.apache.spark.sql.connector.catalog.TableInfo import org.apache.spark.sql.functions._ @@ -66,6 +66,7 @@ class MergeIntoDataFrameSuite extends RowLevelOperationSuiteBase { assert(txn.currentState == Committed) assert(txn.isClosed) assert(txnTables.size == 1) + assert(table.version() == "2") // check all table scans val targetTxnTable = txnTables(tableNameAsString) @@ -92,8 +93,39 @@ class MergeIntoDataFrameSuite extends 
RowLevelOperationSuiteBase { Row(1, 101, "hr"), // update Row(2, 200, "software"), // unchanged Row(3, 300, "hr"))) // unchanged + } + + for (alterClause <- Seq( + "ADD COLUMN new_col INT", + "DROP COLUMN salary", + "ALTER COLUMN salary TYPE BIGINT", + "ALTER COLUMN pk DROP NOT NULL")) + test(s"self merge fails when source schema changes after analysis - DDL: $alterClause" ) { + withTable(tableNameAsString) { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val sourceDF = spark.table(tableNameAsString).where("salary == 100").as("source") + sourceDF.queryExecution.assertAnalyzed() - // TODO Achatzis check version. + sql(s"ALTER TABLE $tableNameAsString $alterClause") + + val e = intercept[AnalysisException] { + sourceDF + .mergeInto(tableNameAsString, $"source.pk" === targetTableCol("pk")) + .whenMatched() + .update(Map("salary" -> targetTableCol("salary").plus(1))) + .merge() + } + + assert( + e.getCondition == "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMNS_MISMATCH", + alterClause) + assert(catalog.lastTransaction.currentState == Aborted, alterClause) + assert(catalog.lastTransaction.isClosed, alterClause) + } } test("merge into empty table with NOT MATCHED clause") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala index 9889b9b53742d..0423fe66a2067 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala @@ -64,6 +64,7 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase assert(txn.currentState == Committed) assert(txn.isClosed) assert(txnTables.size == 1) + assert(table.version() == "2") // check all table scans val targetTxnTable = 
txnTables(tableNameAsString) @@ -163,6 +164,7 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase assert(txn.currentState == Committed) assert(txn.isClosed) assert(txnTables.size == 2) + assert(table.version() == "2") // check target table was scanned correctly val targetTxnTable = txnTables(tableNameAsString) @@ -249,6 +251,7 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase assert(txn.currentState == Committed) assert(txn.isClosed) assert(txnTables.size == 2) + assert(table.version() == "2") // check target table was scanned correctly val targetTxnTable = txnTables(tableNameAsString) @@ -318,6 +321,7 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase assert(txn.currentState == Committed) assert(txn.isClosed) assert(txnTables.size == 2) + assert(table.version() == "2") // check target table was scanned correctly val targetTxnTable = txnTables(tableNameAsString) @@ -1007,6 +1011,7 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase assert(txn.currentState == Committed) assert(txn.isClosed) assert(txnTables.size == 2) + assert(table.version() == "2") // check target table was scanned correctly val targetTxnTable = txnTables(tableNameAsString) @@ -1078,6 +1083,7 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase assert(txn.isClosed) // both target and source must have been read through the transaction catalog assert(txnTables.size == 2) + assert(table.version() == "2") assert(txnTables(sourceNameAsString).scanEvents.nonEmpty) assert(txnTables(tableNameAsString).scanEvents.nonEmpty) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala index a8c9a0bc4ab00..565e209b021a6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala @@ -134,6 +134,14 @@ abstract class RowLevelOperationSuiteBase } } + // executes an operation and keeps the executed plan + protected def executeAndKeepPlan(func: => Unit): SparkPlan = { + withQueryExecutionsCaptured(spark)(func) match { + case Seq(qe) => stripAQEPlan(qe.executedPlan) + case other => fail(s"expected only one query execution, but got ${other.size}") + } + } + protected def executeTransaction(func: => Unit): (Txn, Map[String, TxnTable]) = { val tables = withQueryExecutionsCaptured(spark)(func).flatMap { qe => collectWithSubqueries(qe.executedPlan) { @@ -152,14 +160,6 @@ abstract class RowLevelOperationSuiteBase } } - // executes an operation and keeps the executed plan - protected def executeAndKeepPlan(func: => Unit): SparkPlan = { - withQueryExecutionsCaptured(spark)(func) match { - case Seq(qe) => stripAQEPlan(qe.executedPlan) - case other => fail(s"expected only one query execution, but got ${other.size}") - } - } - // executes an operation and extracts conditions from ReplaceData or WriteDelta protected def executeAndKeepConditions(func: => Unit): (Expression, Option[Expression]) = { val Seq(qe) = withQueryExecutionsCaptured(spark)(func) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala index 0061856ea5da8..2ca4235f96b55 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala @@ -811,6 +811,7 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { assert(txn.currentState == Committed) assert(txn.isClosed) assert(txnTables.size == 2) + assert(table.version() == "2") // check target table was scanned correctly val targetTxnTable = txnTables(tableNameAsString) @@ -870,6 +871,7 @@ abstract 
class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { assert(txn.currentState == Committed) assert(txn.isClosed) assert(txnTables.size == 2) + assert(table.version() == "2") // check source table was scanned correctly (dep = 'hr' filter in the subquery) val sourceTxnTable = txnTables(sourceNameAsString) @@ -923,6 +925,7 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { assert(txn.currentState == Committed) assert(txn.isClosed) assert(txnTables.size == 2) + assert(table.version() == "2") // check source table was scanned via the transaction catalog val sourceTxnTable = txnTables(sourceNameAsString) @@ -1003,6 +1006,7 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { assert(txn.currentState == Committed) assert(txn.isClosed) assert(txnTables.size == 2) + assert(table.version() == "2") // check target table was scanned correctly val targetTxnTable = txnTables(tableNameAsString) @@ -1046,15 +1050,15 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { |{ "pk": 2, "salary": 200, "dep": "software" } |""".stripMargin) - // NOTE: df.explain() on a DML command actually executes the write. - // TODO(achatzis): This is existing behavior but check why this is OK. Shouldn't sql() be lazy? + // sql() is lazy, but explain() forces executedPlan. sql(s"UPDATE $tableNameAsString SET salary = -1 WHERE dep = 'hr'").explain() assert(catalog.lastTransaction != null) assert(catalog.lastTransaction.currentState == Committed) assert(catalog.lastTransaction.isClosed) + assert(table.version() == "2") - // the UPDATE was actually executed, not just planned + // The UPDATE was actually executed, not just planned. 
checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Seq( From d5ea3474d5416cc2b2cf36daa39628e471e1d7d0 Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Wed, 8 Apr 2026 09:32:58 +0000 Subject: [PATCH 07/22] Append suite improvements pass 1 --- .../AppendDataTransactionSuite.scala | 211 ++++++++++++++---- 1 file changed, 172 insertions(+), 39 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala index 379cf6df0f739..22cdcd1cdfad2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala @@ -17,10 +17,13 @@ package org.apache.spark.sql.connector +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.Row -import org.apache.spark.sql.connector.catalog.Committed +import org.apache.spark.sql.connector.catalog.{Aborted, Committed} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode +import org.apache.spark.sql.sources class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { @@ -88,7 +91,8 @@ class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { Row(4, 400, "finance"))) } - test("SQL INSERT OVERWRITE with transactional checks") { + for (isDynamic <- Seq(false, true)) + test(s"SQL INSERT OVERWRITE with transactional checks - isDynamic: $isDynamic") { // create table with initial data; table is partitioned by dep createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", """{ "pk": 1, "salary": 100, "dep": "hr" } @@ -96,12 +100,22 @@ class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { |{ "pk": 3, "salary": 300, "dep": "hr" } |""".stripMargin) - // INSERT OVERWRITE with static partition predicate -> 
OverwriteByExpression - val (txn, _) = executeTransaction { - sql(s"""INSERT OVERWRITE $tableNameAsString - |PARTITION (dep = 'hr') - |SELECT pk + 10, salary FROM $tableNameAsString WHERE dep = 'hr' - |""".stripMargin) + val insertOverwrite = if (isDynamic) { + // OverwritePartitionsDynamic + s"""INSERT OVERWRITE $tableNameAsString + |SELECT pk + 10, salary, dep FROM $tableNameAsString WHERE dep = 'hr' + |""".stripMargin + } else { + // OverwriteByExpression + s"""INSERT OVERWRITE $tableNameAsString + |PARTITION (dep = 'hr') + |SELECT pk + 10, salary FROM $tableNameAsString WHERE dep = 'hr' + |""".stripMargin + } + + val confValue = if (isDynamic) PartitionOverwriteMode.DYNAMIC else PartitionOverwriteMode.STATIC + val (txn, _) = withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> confValue.toString) { + executeTransaction { sql(insertOverwrite) } } assert(txn.currentState == Committed) @@ -116,35 +130,6 @@ class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { Row(13, 300, "hr"))) // overwritten } - test("SQL INSERT OVERWRITE dynamic partition with transactional checks") { - // create table with initial data; table is partitioned by dep - createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", - """{ "pk": 1, "salary": 100, "dep": "hr" } - |{ "pk": 2, "salary": 200, "dep": "software" } - |{ "pk": 3, "salary": 300, "dep": "hr" } - |""".stripMargin) - - // INSERT OVERWRITE with dynamic partitioning -> OverwritePartitionsDynamic - withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> "dynamic") { - val (txn, _) = executeTransaction { - sql(s"""INSERT OVERWRITE $tableNameAsString - |SELECT pk + 10, salary, dep FROM $tableNameAsString WHERE dep = 'hr' - |""".stripMargin) - } - - assert(txn.currentState == Committed) - assert(txn.isClosed) - assert(table.version() == "2") - - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq( - Row(2, 200, "software"), // unchanged (different partition) - Row(11, 100, "hr"), // overwrote hr partition - 
Row(13, 300, "hr"))) // overwrote hr partition - } - } - test("writeTo overwrite with transactional checks") { // create table with initial data; table is partitioned by dep createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", @@ -168,7 +153,7 @@ class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Seq( - Row(2, 200, "software"), // unchanged (different partition) + Row(2, 200, "software"), // unchanged Row(11, 999, "hr"), // overwrote hr partition Row(12, 888, "hr"))) // overwrote hr partition } @@ -196,7 +181,7 @@ class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Seq( - Row(2, 200, "software"), // unchanged (different partition) + Row(2, 200, "software"), // unchanged Row(11, 999, "hr"), // overwrote hr partition Row(12, 888, "hr"))) // overwrote hr partition } @@ -232,4 +217,152 @@ class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { Row(11, 100, "hr"), // inserted from pk=1 Row(13, 300, "hr"))) // inserted from pk=3 } + + test("SQL INSERT INTO SELECT with subquery on source table and transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + sql(s"INSERT INTO $sourceNameAsString VALUES (1, 500, 'hr'), (3, 600, 'software')") + + // INSERT using a subquery that reads from the target to filter source rows + // both tables are scanned through the transaction catalog + val (txn, txnTables) = executeTransaction { + sql( + s"""INSERT INTO $tableNameAsString + |SELECT pk + 10, salary, dep FROM $sourceNameAsString + |WHERE pk IN (SELECT pk FROM $tableNameAsString WHERE dep = 'hr') + |""".stripMargin) + } + + assert(txn.currentState == Committed) + assert(txn.isClosed) + 
assert(txnTables.size == 2) + assert(table.version() == "2") + + // target was scanned via the transaction catalog (IN subquery) + val targetTxnTable = txnTables(tableNameAsString) + assert(targetTxnTable.scanEvents.flatten.exists { + case sources.EqualTo("dep", "hr") => true + case _ => false + }) + + // source was scanned via the transaction catalog + assert(txnTables(sourceNameAsString).scanEvents.nonEmpty) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr"), + Row(2, 200, "software"), + Row(11, 500, "hr"))) // inserted: source pk=1 matched target hr row + } + + test("SQL INSERT INTO SELECT with CTE and transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + sql(s"INSERT INTO $sourceNameAsString VALUES (1, 500, 'hr'), (3, 600, 'software')") + + // CTE reads from target; INSERT selects from source filtered by the CTE result + // both tables are scanned through the transaction catalog + val (txn, txnTables) = executeTransaction { + sql( + s"""WITH hr_pks AS (SELECT pk FROM $tableNameAsString WHERE dep = 'hr') + |INSERT INTO $tableNameAsString + |SELECT pk + 10, salary, dep FROM $sourceNameAsString + |WHERE pk IN (SELECT pk FROM hr_pks) + |""".stripMargin) + } + + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 2) + assert(table.version() == "2") + + // target was scanned via the transaction catalog (CTE) + val targetTxnTable = txnTables(tableNameAsString) + assert(targetTxnTable.scanEvents.flatten.exists { + case sources.EqualTo("dep", "hr") => true + case _ => false + }) + + // source was scanned via the transaction catalog + assert(txnTables(sourceNameAsString).scanEvents.nonEmpty) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + 
Row(1, 100, "hr"), + Row(2, 200, "software"), + Row(11, 500, "hr"))) // inserted: source pk=1 matched target hr row via CTE + } + + test("SQL INSERT with analysis failure and transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val e = intercept[AnalysisException] { + sql(s"INSERT INTO $tableNameAsString SELECT nonexistent_col FROM $tableNameAsString") + } + + assert(e.getMessage.contains("nonexistent_col")) + assert(catalog.lastTransaction.currentState == Aborted) + assert(catalog.lastTransaction.isClosed) + } + + for (isDynamic <- Seq(false, true)) + test(s"SQL INSERT OVERWRITE with analysis failure and transactional checks" + + s"isDynamic: $isDynamic") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val insertOverwrite = if (isDynamic) { + s"""INSERT OVERWRITE $tableNameAsString + |SELECT nonexistent_col, salary, dep FROM $tableNameAsString WHERE dep = 'hr' + |""".stripMargin + } else { + s"""INSERT OVERWRITE $tableNameAsString + |PARTITION (dep = 'hr') + |SELECT nonexistent_col FROM $tableNameAsString WHERE dep = 'hr' + |""".stripMargin + } + + val confValue = if (isDynamic) PartitionOverwriteMode.DYNAMIC else PartitionOverwriteMode.STATIC + val e = withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> confValue.toString) { + intercept[AnalysisException] { sql(insertOverwrite) } + } + + assert(e.getMessage.contains("nonexistent_col")) + assert(catalog.lastTransaction.currentState == Aborted) + assert(catalog.lastTransaction.isClosed) + } + + test("EXPLAIN INSERT SQL with transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + 
sql(s"EXPLAIN INSERT INTO $tableNameAsString VALUES (3, 300, 'hr')") + + // EXPLAIN should not start a transaction + assert(catalog.transaction === null) + + // INSERT was not executed; data is unchanged + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr"), + Row(2, 200, "software"))) + } } From 8d52c3962f673e78097e0d9d39d8803448936698 Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Wed, 8 Apr 2026 13:35:25 +0000 Subject: [PATCH 08/22] Append suite improvements pass 2 --- .../AppendDataTransactionSuite.scala | 100 ++++++++++++------ 1 file changed, 67 insertions(+), 33 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala index 22cdcd1cdfad2..93da0a95de956 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala @@ -46,14 +46,18 @@ class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { } // check txn was properly committed and closed - assert(txn.currentState == Committed) + assert(txn.currentState === Committed) assert(txn.isClosed) - assert(txnTables.size == 1) - assert(table.version() == "2") + assert(txnTables.size === 1) + assert(table.version() === "2") // check the source scan was tracked via the transaction catalog val targetTxnTable = txnTables(tableNameAsString) - assert(targetTxnTable.scanEvents.size >= 1) + assert(targetTxnTable.scanEvents.size === 1) + assert(targetTxnTable.scanEvents.flatten.exists { + case sources.EqualTo("pk", 1) => true + case _ => false + }) // check data was appended correctly checkAnswer( @@ -72,14 +76,17 @@ class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { |""".stripMargin) // SQL INSERT INTO using VALUES - val (txn, _) = executeTransaction { + val (txn, txnTables) = 
executeTransaction { sql(s"INSERT INTO $tableNameAsString VALUES (3, 300, 'hr'), (4, 400, 'finance')") } // check txn was properly committed and closed - assert(txn.currentState == Committed) + assert(txn.currentState === Committed) assert(txn.isClosed) - assert(table.version() == "2") + assert(table.version() === "2") + + // VALUES literal - No catalog tables were scanned + assert(txnTables.isEmpty) // check data was inserted correctly checkAnswer( @@ -114,13 +121,22 @@ class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { } val confValue = if (isDynamic) PartitionOverwriteMode.DYNAMIC else PartitionOverwriteMode.STATIC - val (txn, _) = withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> confValue.toString) { + val (txn, txnTables) = withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> confValue.toString) { executeTransaction { sql(insertOverwrite) } } - assert(txn.currentState == Committed) + assert(txn.currentState === Committed) assert(txn.isClosed) - assert(table.version() == "2") + assert(table.version() === "2") + + // the SELECT reads from the target table once with a dep='hr' filter + assert(txnTables.size == 1) + val targetTxnTable = txnTables(tableNameAsString) + assert(targetTxnTable.scanEvents.size == 1) + assert(targetTxnTable.scanEvents.flatten.exists { + case sources.EqualTo("dep", "hr") => true + case _ => false + }) checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), @@ -142,13 +158,16 @@ class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { val sourceDF = spark.createDataFrame(Seq((11, 999, "hr"), (12, 888, "hr"))). 
toDF("pk", "salary", "dep") - val (txn, _) = executeTransaction { + val (txn, txnTables) = executeTransaction { sourceDF.writeTo(tableNameAsString).overwrite(col("dep") === "hr") } - assert(txn.currentState == Committed) + assert(txn.currentState === Committed) assert(txn.isClosed) - assert(table.version() == "2") + assert(table.version() === "2") + + // literal DataFrame source - no catalog tables were scanned + assert(txnTables.isEmpty) checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), @@ -170,13 +189,16 @@ class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { val sourceDF = spark.createDataFrame(Seq((11, 999, "hr"), (12, 888, "hr"))). toDF("pk", "salary", "dep") - val (txn, _) = executeTransaction { + val (txn, txnTables) = executeTransaction { sourceDF.writeTo(tableNameAsString).overwritePartitions() } - assert(txn.currentState == Committed) + assert(txn.currentState === Committed) assert(txn.isClosed) - assert(table.version() == "2") + assert(table.version() === "2") + + // literal DataFrame source - no catalog tables were scanned + assert(txnTables.isEmpty) checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), @@ -202,10 +224,18 @@ class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { } // check txn was properly committed and closed - assert(txn.currentState == Committed) + assert(txn.currentState === Committed) assert(txn.isClosed) - assert(txnTables.size == 1) - assert(table.version() == "2") + assert(table.version() === "2") + + // the SELECT reads from the target table once with a dep='hr' filter + assert(txnTables.size === 1) + val targetTxnTable = txnTables(tableNameAsString) + assert(targetTxnTable.scanEvents.size === 1) + assert(targetTxnTable.scanEvents.flatten.exists { + case sources.EqualTo("dep", "hr") => true + case _ => false + }) // check data was inserted correctly checkAnswer( @@ -237,20 +267,22 @@ class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { |""".stripMargin) } - 
assert(txn.currentState == Committed) + assert(txn.currentState === Committed) assert(txn.isClosed) - assert(txnTables.size == 2) - assert(table.version() == "2") + assert(txnTables.size === 2) + assert(table.version() === "2") - // target was scanned via the transaction catalog (IN subquery) + // target was scanned via the transaction catalog (IN subquery) once with dep='hr' filter val targetTxnTable = txnTables(tableNameAsString) + assert(targetTxnTable.scanEvents.size === 1) assert(targetTxnTable.scanEvents.flatten.exists { case sources.EqualTo("dep", "hr") => true case _ => false }) - // source was scanned via the transaction catalog - assert(txnTables(sourceNameAsString).scanEvents.nonEmpty) + // source was scanned via the transaction catalog exactly once (no filter) + val sourceTxnTable = txnTables(sourceNameAsString) + assert(sourceTxnTable.scanEvents.size === 1) checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), @@ -280,20 +312,22 @@ class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { |""".stripMargin) } - assert(txn.currentState == Committed) + assert(txn.currentState === Committed) assert(txn.isClosed) - assert(txnTables.size == 2) - assert(table.version() == "2") + assert(txnTables.size === 2) + assert(table.version() === "2") - // target was scanned via the transaction catalog (CTE) + // target was scanned via the transaction catalog (CTE) once with dep='hr' filter val targetTxnTable = txnTables(tableNameAsString) + assert(targetTxnTable.scanEvents.size === 1) assert(targetTxnTable.scanEvents.flatten.exists { case sources.EqualTo("dep", "hr") => true case _ => false }) - // source was scanned via the transaction catalog - assert(txnTables(sourceNameAsString).scanEvents.nonEmpty) + // source was scanned via the transaction catalog exactly once (no filter) + val sourceTxnTable = txnTables(sourceNameAsString) + assert(sourceTxnTable.scanEvents.size === 1) checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), @@ -314,7 +348,7 @@ class 
AppendDataTransactionSuite extends RowLevelOperationSuiteBase { } assert(e.getMessage.contains("nonexistent_col")) - assert(catalog.lastTransaction.currentState == Aborted) + assert(catalog.lastTransaction.currentState === Aborted) assert(catalog.lastTransaction.isClosed) } @@ -343,7 +377,7 @@ class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { } assert(e.getMessage.contains("nonexistent_col")) - assert(catalog.lastTransaction.currentState == Aborted) + assert(catalog.lastTransaction.currentState === Aborted) assert(catalog.lastTransaction.isClosed) } From fd1a3097de9ed63bb3d297b97756142207b9ba3b Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Thu, 9 Apr 2026 08:52:04 +0000 Subject: [PATCH 09/22] RTAS/CTAS improvements --- .../spark/sql/connector/catalog/txns.scala | 20 ++- .../v2/WriteToDataSourceV2Exec.scala | 25 +--- .../connector/CTASRTASTransactionSuite.scala | 138 +++++++++++++----- 3 files changed, 120 insertions(+), 63 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala index c6339d2099b63..0881daeb7635d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala @@ -74,6 +74,8 @@ class TxnTable(val delegate: InMemoryRowLevelOperationTable) // TODO: Revise schema evolution. alterTableWithData(delegate.data, schema) + private val initialVersion: String = version() + // A tracker of filters used in each scan. val scanEvents = new ArrayBuffer[Array[Filter]]() @@ -82,14 +84,16 @@ class TxnTable(val delegate: InMemoryRowLevelOperationTable) } def commit(): Unit = { - delegate.dataMap.clear() - // TODO: Revise schema evolution. 
- delegate.alterTableWithData(data, delegate.schema) - delegate.replacedPartitions = replacedPartitions - delegate.lastWriteInfo = lastWriteInfo - delegate.lastWriteLog = lastWriteLog - delegate.commits ++= commits - delegate.increaseVersion() + if (version() != initialVersion) { + delegate.dataMap.clear() + // TODO: Revise schema evolution. + delegate.alterTableWithData(data, delegate.schema) + delegate.replacedPartitions = replacedPartitions + delegate.lastWriteInfo = lastWriteInfo + delegate.lastWriteLog = lastWriteLog + delegate.commits ++= commits + delegate.increaseVersion() + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index b860dc3347f9f..07876887788be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -120,13 +120,8 @@ case class AtomicCreateTableAsSelectExec( query: LogicalPlan, tableSpec: TableSpec, writeOptions: Map[String, String], - ifNotExists: Boolean, - transaction: Option[Transaction] = None) - extends V2CreateTableAsSelectBaseExec - with TransactionalExec { - - override def withTransaction(txn: Option[Transaction]): AtomicCreateTableAsSelectExec = - copy(transaction = txn) + ifNotExists: Boolean) + extends V2CreateTableAsSelectBaseExec { val properties = CatalogV2Util.convertTableProperties(tableSpec) @@ -147,9 +142,7 @@ case class AtomicCreateTableAsSelectExec( .build() val stagedTable = Option(catalog.stageCreate(ident, tableInfo) ).getOrElse(catalog.loadTable(ident, Set(TableWritePrivilege.INSERT).asJava)) - val result = writeToTable(catalog, stagedTable, writeOptions, ident, query, overwrite = false) - transaction.foreach(TransactionUtils.commit) - result + writeToTable(catalog, stagedTable, 
writeOptions, ident, query, overwrite = false) } } @@ -240,12 +233,8 @@ case class AtomicReplaceTableAsSelectExec( tableSpec: TableSpec, writeOptions: Map[String, String], orCreate: Boolean, - invalidateCache: (TableCatalog, Identifier) => Unit, - transaction: Option[Transaction] = None) - extends V2CreateTableAsSelectBaseExec with TransactionalExec { - - override def withTransaction(txn: Option[Transaction]): AtomicReplaceTableAsSelectExec = - copy(transaction = txn) + invalidateCache: (TableCatalog, Identifier) => Unit) + extends V2CreateTableAsSelectBaseExec { val properties = CatalogV2Util.convertTableProperties(tableSpec) @@ -286,9 +275,7 @@ case class AtomicReplaceTableAsSelectExec( } val table = Option(staged).getOrElse( catalog.loadTable(ident, Set(TableWritePrivilege.INSERT).asJava)) - val result = writeToTable(catalog, table, writeOptions, ident, query, overwrite = true) - transaction.foreach(TransactionUtils.commit) - result + writeToTable(catalog, table, writeOptions, ident, query, overwrite = true) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/CTASRTASTransactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/CTASRTASTransactionSuite.scala index c58a78498f9f4..8acdd8242ef1f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/CTASRTASTransactionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/CTASRTASTransactionSuite.scala @@ -17,13 +17,19 @@ package org.apache.spark.sql.connector +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.Row -import org.apache.spark.sql.connector.catalog.Committed +import org.apache.spark.sql.connector.catalog.{Aborted, Committed, Identifier, InMemoryRowLevelOperationTable} +import org.apache.spark.sql.sources class CTASRTASTransactionSuite extends RowLevelOperationSuiteBase { private val newTableNameAsString = "cat.ns1.new_table" + private def newTable: InMemoryRowLevelOperationTable = + 
catalog.loadTable(Identifier.of(Array("ns1"), "new_table")) + .asInstanceOf[InMemoryRowLevelOperationTable] + test("CTAS with transactional checks") { createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", """{ "pk": 1, "salary": 100, "dep": "hr" } @@ -36,50 +42,56 @@ class CTASRTASTransactionSuite extends RowLevelOperationSuiteBase { |""".stripMargin) } - assert(txn.currentState == Committed) + assert(txn.currentState === Committed) assert(txn.isClosed) - assert(txnTables.size == 1) - assert(table.version() == "2") + assert(txnTables.size === 1) + assert(table.version() === "1") // source table: read-only, version unchanged + assert(newTable.version() === "1") // target table: newly created and written + // the source table was scanned once through the transaction catalog with a dep='hr' filter val sourceTxnTable = txnTables(tableNameAsString) - assert(sourceTxnTable.scanEvents.size >= 1) + assert(sourceTxnTable.scanEvents.size === 1) + assert(sourceTxnTable.scanEvents.flatten.exists { + case sources.EqualTo("dep", "hr") => true + case _ => false + }) + + checkAnswer( + sql(s"SELECT * FROM $newTableNameAsString"), + Seq(Row(1, 100, "hr"))) + } + + test("CTAS from literal source with transactional checks") { + // no source catalog table involved — the query is a pure literal SELECT + val (txn, txnTables) = executeTransaction { + sql(s"CREATE TABLE $newTableNameAsString AS SELECT 1 AS pk, 100 AS salary, 'hr' AS dep") + } + + assert(txn.currentState === Committed) + assert(txn.isClosed) + + // literal SELECT - no catalog tables were scanned + assert(txnTables.isEmpty) + assert(newTable.version() === "1") // target table: newly created and written checkAnswer( sql(s"SELECT * FROM $newTableNameAsString"), Seq(Row(1, 100, "hr"))) } - test("CTAS with cached source and transactional checks") { + test("CTAS with analysis failure and transactional checks") { createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", """{ "pk": 1, "salary": 100, "dep": "hr" } |{ 
"pk": 2, "salary": 200, "dep": "software" } |""".stripMargin) - // cache the source table before running CTAS - spark.catalog.cacheTable(tableNameAsString) - - try { - val (txn, txnTables) = executeTransaction { - sql(s"""CREATE TABLE $newTableNameAsString - |AS SELECT * FROM $tableNameAsString WHERE dep = 'hr' - |""".stripMargin) - } - - assert(txn.currentState == Committed) - assert(txn.isClosed) - assert(table.version() == "2") - - // cache miss: TxnTable-based relation is not structurally equal to the cached one, - // so the scan goes through the transaction catalog and scan events are captured - val sourceTxnTable = txnTables(tableNameAsString) - assert(sourceTxnTable.scanEvents.size >= 1) - - checkAnswer( - sql(s"SELECT * FROM $newTableNameAsString"), - Seq(Row(1, 100, "hr"))) - } finally { - spark.catalog.uncacheTable(tableNameAsString) + val e = intercept[AnalysisException] { + sql(s"CREATE TABLE $newTableNameAsString AS SELECT nonexistent_col FROM $tableNameAsString") } + + assert(e.getMessage.contains("nonexistent_col")) + assert(catalog.lastTransaction.currentState === Aborted) + assert(catalog.lastTransaction.isClosed) } test("RTAS with transactional checks") { @@ -98,13 +110,19 @@ class CTASRTASTransactionSuite extends RowLevelOperationSuiteBase { |""".stripMargin) } - assert(txn.currentState == Committed) + assert(txn.currentState === Committed) assert(txn.isClosed) - assert(txnTables.size == 1) - assert(table.version() == "2") + assert(txnTables.size === 1) + assert(table.version() === "1") // source table: read-only, version unchanged + assert(newTable.version() === "1") // target table: replaced and written + // the source table was scanned once through the transaction catalog with a dep='hr' filter val sourceTxnTable = txnTables(tableNameAsString) - assert(sourceTxnTable.scanEvents.size >= 1) + assert(sourceTxnTable.scanEvents.size === 1) + assert(sourceTxnTable.scanEvents.flatten.exists { + case sources.EqualTo("dep", "hr") => true + case _ => 
false + }) checkAnswer( sql(s"SELECT * FROM $newTableNameAsString"), @@ -128,11 +146,18 @@ class CTASRTASTransactionSuite extends RowLevelOperationSuiteBase { |""".stripMargin) } - assert(txn.currentState == Committed) + assert(txn.currentState === Committed) assert(txn.isClosed) + assert(txnTables.size === 1) + assert(table.version() === "1") // source/target table: replaced in place, version reset to 1 + // the source/target table was scanned once with a dep='hr' filter val sourceTxnTable = txnTables(tableNameAsString) - assert(sourceTxnTable.scanEvents.size >= 1) + assert(sourceTxnTable.scanEvents.size === 1) + assert(sourceTxnTable.scanEvents.flatten.exists { + case sources.EqualTo("dep", "hr") => true + case _ => false + }) checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), @@ -140,4 +165,45 @@ class CTASRTASTransactionSuite extends RowLevelOperationSuiteBase { Row(1, 100, "hr"), Row(3, 300, "hr"))) } + + test("RTAS with analysis failure and transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val e = intercept[AnalysisException] { + sql(s"""CREATE OR REPLACE TABLE $tableNameAsString + |AS SELECT nonexistent_col FROM $tableNameAsString + |""".stripMargin) + } + + assert(e.getMessage.contains("nonexistent_col")) + assert(catalog.lastTransaction.currentState === Aborted) + assert(catalog.lastTransaction.isClosed) + } + + test("simple CREATE TABLE and DROP TABLE do not create transactions") { + sql(s"CREATE TABLE $newTableNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + assert(catalog.transaction === null) + assert(catalog.lastTransaction === null) + + sql(s"DROP TABLE $newTableNameAsString") + assert(catalog.transaction === null) + assert(catalog.lastTransaction === null) + } + + test("EXPLAIN CTAS with transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ 
"pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + sql(s"""EXPLAIN CREATE TABLE $newTableNameAsString + |AS SELECT * FROM $tableNameAsString WHERE dep = 'hr' + |""".stripMargin) + + // EXPLAIN should not start a transaction + assert(catalog.transaction === null) + } } From 3f45e22e72e23f720b3806708f42e3244fa685fe Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Thu, 9 Apr 2026 14:07:14 +0000 Subject: [PATCH 10/22] Schema evolution --- .../connector/catalog/InMemoryBaseTable.scala | 4 + .../spark/sql/connector/catalog/txns.scala | 12 ++- .../AppendDataTransactionSuite.scala | 98 +++++++++++++++++++ 3 files changed, 110 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala index 422632d074a3b..af46ac006f3c9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala @@ -76,6 +76,10 @@ abstract class InMemoryBaseTable( override def columns(): Array[Column] = tableColumns + private[catalog] def updateColumns(newColumns: Array[Column]): Unit = { + tableColumns = newColumns + } + override def version(): String = tableVersion.toString def setVersion(version: String): Unit = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala index 0881daeb7635d..6a6897c896165 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala @@ -71,9 +71,9 @@ class TxnTable(val delegate: InMemoryRowLevelOperationTable) delegate.properties, delegate.constraints) { - // TODO: Revise 
schema evolution. - alterTableWithData(delegate.data, schema) + withData(delegate.data) + // Keep initial version to detect any changes during the transaction. private val initialVersion: String = version() // A tracker of filters used in each scan. @@ -86,8 +86,8 @@ class TxnTable(val delegate: InMemoryRowLevelOperationTable) def commit(): Unit = { if (version() != initialVersion) { delegate.dataMap.clear() - // TODO: Revise schema evolution. - delegate.alterTableWithData(data, delegate.schema) + delegate.alterTableWithData(data, schema) + delegate.updateColumns(columns()) // Evolve schema if needed. delegate.replacedPartitions = replacedPartitions delegate.lastWriteInfo = lastWriteInfo delegate.lastWriteLog = lastWriteLog @@ -124,6 +124,10 @@ class TxnTableCatalog(delegate: InMemoryRowLevelOperationTableCatalog) extends T } override def alterTable(ident: Identifier, changes: TableChange*): Table = { + // TODO: This evicts the staged TxnTable, losing any in-flight DML changes. The correct + // approach is to apply only the schema change to the existing TxnTable so that the ongoing + // DML can observe the new schema and reconcile at commit time. Concurrent DDL + DML is not + // supported in this test catalog for now. val newDelegateTable = delegate.alterTable(ident, changes: _*) tables.remove(ident) // Load again. 
newDelegateTable diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala index 93da0a95de956..aef9c65550fc4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala @@ -399,4 +399,102 @@ class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { Row(1, 100, "hr"), Row(2, 200, "software"))) } + + test("SQL INSERT WITH SCHEMA EVOLUTION adds new column with transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + sql( + s"""CREATE TABLE $sourceNameAsString + |(pk INT NOT NULL, salary INT, dep STRING, active BOOLEAN)""".stripMargin) + sql(s"INSERT INTO $sourceNameAsString VALUES (3, 300, 'hr', true), (4, 400, 'software', false)") + + val (txn, txnTables) = executeTransaction { + sql(s"INSERT WITH SCHEMA EVOLUTION INTO $tableNameAsString SELECT * FROM $sourceNameAsString") + } + + assert(txn.currentState === Committed) + assert(txn.isClosed) + + // the new column must be visible in the committed delegate's schema + assert(table.schema.fieldNames.toSeq === Seq("pk", "salary", "dep", "active")) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr", null), // pre-existing rows: active is null + Row(2, 200, "software", null), + Row(3, 300, "hr", true), // inserted with active + Row(4, 400, "software", false))) + } + + for (isDynamic <- Seq(false, true)) + test(s"SQL INSERT OVERWRITE WITH SCHEMA EVOLUTION adds new column with transactional checks " + + s"isDynamic: $isDynamic") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, 
"dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + sql( + s"""CREATE TABLE $sourceNameAsString + |(pk INT NOT NULL, salary INT, dep STRING, active BOOLEAN)""".stripMargin) + sql(s"INSERT INTO $sourceNameAsString VALUES (11, 999, 'hr', true), (12, 888, 'hr', false)") + + val insertOverwrite = if (isDynamic) { + s"""INSERT WITH SCHEMA EVOLUTION OVERWRITE TABLE $tableNameAsString + |SELECT * FROM $sourceNameAsString + |""".stripMargin + } else { + s"""INSERT WITH SCHEMA EVOLUTION OVERWRITE TABLE $tableNameAsString + |PARTITION (dep = 'hr') + |SELECT pk, salary, active FROM $sourceNameAsString + |""".stripMargin + } + + val confValue = if (isDynamic) PartitionOverwriteMode.DYNAMIC else PartitionOverwriteMode.STATIC + val (txn, _) = withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> confValue.toString) { + executeTransaction { sql(insertOverwrite) } + } + + assert(txn.currentState === Committed) + assert(txn.isClosed) + + // the new column must be visible in the committed delegate's schema + assert(table.schema.fieldNames.contains("active")) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(2, 200, "software", null), // unchanged (different partition) + Row(11, 999, "hr", true), // overwrote hr partition + Row(12, 888, "hr", false))) + } + + test("SQL INSERT WITH SCHEMA EVOLUTION analysis failure aborts transaction") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + sql( + s"""CREATE TABLE $sourceNameAsString + |(pk INT NOT NULL, salary INT, dep STRING, active BOOLEAN)""".stripMargin) + + val e = intercept[AnalysisException] { + sql( + s"""INSERT WITH SCHEMA EVOLUTION INTO $tableNameAsString + |SELECT nonexistent_col FROM $sourceNameAsString + |""".stripMargin) + } + + assert(e.getMessage.contains("nonexistent_col")) + assert(catalog.lastTransaction.currentState === Aborted) + 
assert(catalog.lastTransaction.isClosed) + // schema must be unchanged after the aborted transaction + assert(table.schema.fieldNames.toSeq === Seq("pk", "salary", "dep")) + } } From 498b4cd2467077b3f9d884da8de5fc7b8caeaa2b Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Fri, 10 Apr 2026 08:23:39 +0000 Subject: [PATCH 11/22] Fix schema evolution --- .../scala/org/apache/spark/sql/connector/catalog/txns.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala index 6a6897c896165..ba6de6be9130a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala @@ -71,7 +71,7 @@ class TxnTable(val delegate: InMemoryRowLevelOperationTable) delegate.properties, delegate.constraints) { - withData(delegate.data) + alterTableWithData(delegate.data, delegate.schema) // Keep initial version to detect any changes during the transaction. 
private val initialVersion: String = version() From 39fe3309e4191b38a69a00a0f5eae5442d34c3cc Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Fri, 10 Apr 2026 14:48:29 +0000 Subject: [PATCH 12/22] Add schema evolution fixme --- .../org/apache/spark/sql/connector/catalog/txns.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala index ba6de6be9130a..94c3bda826b32 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala @@ -124,10 +124,10 @@ class TxnTableCatalog(delegate: InMemoryRowLevelOperationTableCatalog) extends T } override def alterTable(ident: Identifier, changes: TableChange*): Table = { - // TODO: This evicts the staged TxnTable, losing any in-flight DML changes. The correct - // approach is to apply only the schema change to the existing TxnTable so that the ongoing - // DML can observe the new schema and reconcile at commit time. Concurrent DDL + DML is not - // supported in this test catalog for now. + // FIXME: This is not transactional. The schema changes are applied directly to the delegate. + // The correct behavior is to apply the schema changes to the TxnTable and propagate them + // to the delegate only after commit. + // Furthermore, this also evicts the staged TxnTable, losing any in-flight DML changes. val newDelegateTable = delegate.alterTable(ident, changes: _*) tables.remove(ident) // Load again. 
newDelegateTable From 4176e3efe2baefe661f5a671be991a7dd6520f43 Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Fri, 10 Apr 2026 15:27:09 +0000 Subject: [PATCH 13/22] Schema evolution fix 2 --- .../spark/sql/connector/catalog/txns.scala | 32 ++++++++++++------- 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala index 94c3bda826b32..4d11945a06580 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala @@ -24,6 +24,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.connector.catalog.transactions.Transaction import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap sealed trait TransactionState @@ -63,15 +64,15 @@ class Txn(override val catalog: TxnTableCatalog) extends Transaction { // A special table used in row-level operation transactions. It inherits data // from the base table upon construction and propagates staged transaction state // back after an explicit commit. -class TxnTable(val delegate: InMemoryRowLevelOperationTable) +class TxnTable(val delegate: InMemoryRowLevelOperationTable, schema: StructType) extends InMemoryRowLevelOperationTable( delegate.name, - delegate.schema, + schema, delegate.partitioning, delegate.properties, delegate.constraints) { - alterTableWithData(delegate.data, delegate.schema) + alterTableWithData(delegate.data, schema) // Keep initial version to detect any changes during the transaction. 
private val initialVersion: String = version() @@ -86,8 +87,8 @@ class TxnTable(val delegate: InMemoryRowLevelOperationTable) def commit(): Unit = { if (version() != initialVersion) { delegate.dataMap.clear() - delegate.alterTableWithData(data, schema) delegate.updateColumns(columns()) // Evolve schema if needed. + delegate.alterTableWithData(data, schema) delegate.replacedPartitions = replacedPartitions delegate.lastWriteInfo = lastWriteInfo delegate.lastWriteLog = lastWriteLog @@ -114,7 +115,7 @@ class TxnTableCatalog(delegate: InMemoryRowLevelOperationTableCatalog) extends T override def loadTable(ident: Identifier): Table = { tables.computeIfAbsent(ident, _ => { val table = delegate.loadTable(ident).asInstanceOf[InMemoryRowLevelOperationTable] - new TxnTable(table) + new TxnTable(table, table.schema()) }) } @@ -124,13 +125,20 @@ class TxnTableCatalog(delegate: InMemoryRowLevelOperationTableCatalog) extends T } override def alterTable(ident: Identifier, changes: TableChange*): Table = { - // FIXME: This is not transactional. The schema changes are applied directly to the delegate. - // The correct behavior is to apply the schema changes to the TxnTable and propagate them - // to the delegate only after commit. - // Furthermore, this also evicts the staged TxnTable, losing any in-flight DML changes. - val newDelegateTable = delegate.alterTable(ident, changes: _*) - tables.remove(ident) // Load again. - newDelegateTable + // AlterTable may be called by ResolveSchemaEvolution when schema evolution is enabled. Thus, + // it needs to be transactional. The schema changes are only propagated to the delegate at + // commit time. 
+ val txnTable = tables.get(ident) + val schema = CatalogV2Util.applySchemaChanges( + txnTable.schema, changes, tableProvider = Some("in-memory"), statementType = "ALTER TABLE") + + if (schema.fields.isEmpty) { + throw new IllegalArgumentException(s"Cannot drop all fields") + } + + val newTxnTable = new TxnTable(txnTable.delegate, schema) + tables.put(ident, newTxnTable) + newTxnTable } override def dropTable(ident: Identifier): Boolean = { From a3f5574620c90b0edbf9b7049e8736c0d56c6595 Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Mon, 13 Apr 2026 12:54:00 +0000 Subject: [PATCH 14/22] Delegate schema computation changes to the underlying catalog --- ...nMemoryRowLevelOperationTableCatalog.scala | 25 +++++++++++++++---- .../spark/sql/connector/catalog/txns.scala | 7 ++++-- 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala index 78c350b5145a9..bdf19e0e9d355 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.connector.catalog import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException import org.apache.spark.sql.connector.catalog.transactions.{Transaction, TransactionInfo} +import org.apache.spark.sql.types.StructType class InMemoryRowLevelOperationTableCatalog extends InMemoryTableCatalog @@ -55,11 +56,7 @@ class InMemoryRowLevelOperationTableCatalog override def alterTable(ident: Identifier, changes: TableChange*): Table = { val table = loadTable(ident).asInstanceOf[InMemoryRowLevelOperationTable] val properties = CatalogV2Util.applyPropertiesChanges(table.properties, 
changes) - val schema = CatalogV2Util.applySchemaChanges( - table.schema, - changes, - tableProvider = Some("in-memory"), - statementType = "ALTER TABLE") + val schema = computeAlterTableSchema(table.schema, changes.toSeq) val partitioning = CatalogV2Util.applyClusterByChanges(table.partitioning, schema, changes) val constraints = CatalogV2Util.collectConstraintChanges(table, changes) @@ -80,6 +77,16 @@ class InMemoryRowLevelOperationTableCatalog newTable } + + /** + * Computes the schema that would result from applying `changes` to `currentSchema`. + * Overriding this allows subclasses to simulate catalogs that selectively ignore some changes + * (e.g. [[PartialSchemaEvolutionCatalog]]). + */ + def computeAlterTableSchema(currentSchema: StructType, changes: Seq[TableChange]): StructType = { + CatalogV2Util.applySchemaChanges( + currentSchema, changes, tableProvider = Some("in-memory"), statementType = "ALTER TABLE") + } } /** @@ -108,4 +115,12 @@ class PartialSchemaEvolutionCatalog extends InMemoryRowLevelOperationTableCatalo tables.put(ident, newTable) newTable } + + // When used inside a transaction, TxnTableCatalog.alterTable uses this method to compute + // the resulting schema instead of calling CatalogV2Util.applySchemaChanges directly. + // Returning the current schema unchanged mirrors the behaviour of alterTable above (silently + // ignore all column changes), so ResolveSchemaEvolution can still detect pending changes. 
+ override def computeAlterTableSchema( + currentSchema: StructType, + changes: Seq[TableChange]): StructType = currentSchema } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala index 4d11945a06580..1f0d36f925625 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala @@ -128,9 +128,12 @@ class TxnTableCatalog(delegate: InMemoryRowLevelOperationTableCatalog) extends T // AlterTable may be called by ResolveSchemaEvolution when schema evolution is enabled. Thus, // it needs to be transactional. The schema changes are only propagated to the delegate at // commit time. + // + // We delegate schema computation to the underlying catalog so that catalogs that selectively + // ignore some changes (e.g. PartialSchemaEvolutionCatalog) have the same behaviour inside a + // transaction. This lets ResolveSchemaEvolution detect pending changes correctly. 
val txnTable = tables.get(ident) - val schema = CatalogV2Util.applySchemaChanges( - txnTable.schema, changes, tableProvider = Some("in-memory"), statementType = "ALTER TABLE") + val schema = delegate.computeAlterTableSchema(txnTable.schema, changes.toSeq) if (schema.fields.isEmpty) { throw new IllegalArgumentException(s"Cannot drop all fields") From 5079d84f4a7905557d559a33777a2aecb05c7c1b Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Tue, 14 Apr 2026 07:39:27 +0000 Subject: [PATCH 15/22] Improve comments in schema evolution --- .../spark/sql/catalyst/analysis/V2TableReference.scala | 6 ++++++ .../InMemoryRowLevelOperationTableCatalog.scala | 10 ++++------ .../org/apache/spark/sql/connector/catalog/txns.scala | 6 +++--- 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala index 76226056ffe65..f459706a690bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala @@ -92,6 +92,8 @@ private[sql] object V2TableReference { create(relation, TemporaryViewContext(viewName)) } + // V2TableReference nodes in the transaction context are produced by + // UnresolveTransactionRelations which unresolves already resolved relations. def createForTransaction(relation: DataSourceV2Relation): V2TableReference = { create(relation, TransactionContext) } @@ -125,6 +127,10 @@ private[sql] object V2TableReferenceUtils extends SQLConfHelper { } private def validateLoadedTableInTransaction(table: Table, ref: V2TableReference): Unit = { + // Do not allow schema evolution to pre-analysed dataframes that are later used in + // transactional writes. 
This is because the entire plans was built based on the original schema + // and any schema change would make the plan structurally invalid. This is inline with the + // semantics of SPARK-54444. val dataErrors = V2TableUtil.validateCapturedColumns( table = table, originCols = ref.info.columns, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala index bdf19e0e9d355..4a38285b685e8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala @@ -80,7 +80,7 @@ class InMemoryRowLevelOperationTableCatalog /** * Computes the schema that would result from applying `changes` to `currentSchema`. - * Overriding this allows subclasses to simulate catalogs that selectively ignore some changes + * Can be overridden by subclasses to simulate catalogs that selectively ignore changes * (e.g. [[PartialSchemaEvolutionCatalog]]). 
*/ def computeAlterTableSchema(currentSchema: StructType, changes: Seq[TableChange]): StructType = { @@ -105,9 +105,10 @@ class PartialSchemaEvolutionCatalog extends InMemoryRowLevelOperationTableCatalo case _ => false } val properties = CatalogV2Util.applyPropertiesChanges(table.properties, propertyChanges) + val schema = computeAlterTableSchema(table.schema, changes.toSeq) val newTable = new InMemoryRowLevelOperationTable( name = table.name, - schema = table.schema, + schema = schema, partitioning = table.partitioning, properties = properties, constraints = table.constraints) @@ -116,10 +117,7 @@ class PartialSchemaEvolutionCatalog extends InMemoryRowLevelOperationTableCatalo newTable } - // When used inside a transaction, TxnTableCatalog.alterTable uses this method to compute - // the resulting schema instead of calling CatalogV2Util.applySchemaChanges directly. - // Returning the current schema unchanged mirrors the behaviour of alterTable above (silently - // ignore all column changes), so ResolveSchemaEvolution can still detect pending changes. + // Ignores all schema changes and returns the current schema unchanged. override def computeAlterTableSchema( currentSchema: StructType, changes: Seq[TableChange]): StructType = currentSchema diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala index 1f0d36f925625..f19e360047d26 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala @@ -129,9 +129,9 @@ class TxnTableCatalog(delegate: InMemoryRowLevelOperationTableCatalog) extends T // it needs to be transactional. The schema changes are only propagated to the delegate at // commit time. // - // We delegate schema computation to the underlying catalog so that catalogs that selectively - // ignore some changes (e.g. 
PartialSchemaEvolutionCatalog) have the same behaviour inside a - // transaction. This lets ResolveSchemaEvolution detect pending changes correctly. + // We delegate schema computation to the underlying catalog so that catalogs with special + // handling (e.g. PartialSchemaEvolutionCatalog) have the same behaviour inside a + // transaction. val txnTable = tables.get(ident) val schema = delegate.computeAlterTableSchema(txnTable.schema, changes.toSeq) From 5ee35620a66972d5a54df16398169167d7970349 Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Tue, 14 Apr 2026 13:46:48 +0000 Subject: [PATCH 16/22] Mark new APIs as evolving + minor cleanup --- .../sql/connector/catalog/TransactionalCatalogPlugin.java | 2 ++ .../sql/connector/catalog/transactions/Transaction.java | 2 ++ .../sql/connector/catalog/transactions/TransactionInfo.java | 3 +++ .../org/apache/spark/sql/execution/QueryExecution.scala | 5 ++--- 4 files changed, 9 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TransactionalCatalogPlugin.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TransactionalCatalogPlugin.java index 34a4fc68e9649..daa3176dcbba5 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TransactionalCatalogPlugin.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TransactionalCatalogPlugin.java @@ -17,6 +17,7 @@ package org.apache.spark.sql.connector.catalog; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.connector.catalog.transactions.Transaction; import org.apache.spark.sql.connector.catalog.transactions.TransactionInfo; @@ -28,6 +29,7 @@ * * @since 4.2.0 */ +@Evolving public interface TransactionalCatalogPlugin extends CatalogPlugin { /** diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/Transaction.java 
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/Transaction.java index 80513aff31506..77044c6202fbe 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/Transaction.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/Transaction.java @@ -17,6 +17,7 @@ package org.apache.spark.sql.connector.catalog.transactions; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.connector.catalog.CatalogPlugin; import org.apache.spark.sql.connector.catalog.TransactionalCatalogPlugin; @@ -32,6 +33,7 @@ * * @since 4.2.0 */ +@Evolving public interface Transaction extends Closeable { /** diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/TransactionInfo.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/TransactionInfo.java index a9c17d4b88274..3e6979cec469f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/TransactionInfo.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/TransactionInfo.java @@ -17,11 +17,14 @@ package org.apache.spark.sql.connector.catalog.transactions; +import org.apache.spark.annotation.Evolving; + /** * Metadata about a transaction. * * @since 4.2.0 */ +@Evolving public interface TransactionInfo { /** * Returns a unique identifier for this transaction. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index bd7731a96a3ab..e5a71943d9a3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -286,9 +286,8 @@ class QueryExecution( // the transaction was active. 
if (transactionOpt.isDefined) { normalized - } - else { - // clone the plan to avoid sharing the plan instance between different stages like + } else { + // Clone the plan to avoid sharing the plan instance between different stages like // analyzing, optimizing and planning. sparkSession.sharedState.cacheManager.useCachedData(normalized.clone()) } From 9e0e030feee43321d269b2d39758a8163dc94030 Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Wed, 15 Apr 2026 10:53:20 +0000 Subject: [PATCH 17/22] Add TODO plus nit --- .../spark/sql/connector/catalog/txns.scala | 1 + .../spark/sql/execution/QueryExecution.scala | 3 +- .../connector/StreamingTransactionSuite.scala | 213 ++++++++++++++++++ 3 files changed, 215 insertions(+), 2 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/connector/StreamingTransactionSuite.scala diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala index f19e360047d26..157c9a82d6f2f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala @@ -139,6 +139,7 @@ class TxnTableCatalog(delegate: InMemoryRowLevelOperationTableCatalog) extends T throw new IllegalArgumentException(s"Cannot drop all fields") } + // TODO: We need to pass all tracked predicates to the new TXN table. 
val newTxnTable = new TxnTable(txnTable.delegate, schema) tables.put(ident, newTxnTable) newTxnTable diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index e5a71943d9a3e..d15a0dd7b5a65 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -233,8 +233,7 @@ class QueryExecution( // for eagerly executed commands we mark this place as beginning of execution. tracker.setReadyForExecution() val (qe, result) = QueryExecution.runCommand( - sparkSession, p, name, refreshPhaseEnabled, mode, Some(shuffleCleanupMode), - analyzerOpt = Some(analyzer)) + sparkSession, p, name, refreshPhaseEnabled, mode, Some(shuffleCleanupMode), Some(analyzer)) CommandResult( qe.analyzed.output, qe.commandExecuted, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/StreamingTransactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/StreamingTransactionSuite.scala new file mode 100644 index 0000000000000..bd3a8fc1307b2 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/StreamingTransactionSuite.scala @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector + +import java.util + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.sql.Row +import org.apache.spark.sql.connector.catalog.{Committed, Identifier, InMemoryRowLevelOperationTable, InMemoryRowLevelOperationTableCatalog, SupportsWrite, Table, TableCapability} +import org.apache.spark.sql.connector.write.{LogicalWriteInfo, PhysicalWriteInfo, Write, WriteBuilder, WriterCommitMessage} +import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite} +import org.apache.spark.sql.execution.streaming.runtime.MemoryStream +import org.apache.spark.sql.execution.streaming.sources.PackedRowWriterFactory +import org.apache.spark.sql.internal.connector.SimpleTableProvider +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.tags.SlowSQLTest + +/** + * Tests that structured streaming micro-batch writes participate in the DSv2 transaction API. + * + * The V2 streaming path is: + * WriteToMicroBatchDataSource (logical) + * -> V2Writes rule -> WriteToDataSourceV2(MicroBatchWrite) (logical) + * -> WriteToDataSourceV2Exec (physical, implements TransactionalExec) + * + * Each micro-batch runs in its own IncrementalExecution, so transactionOpt is evaluated + * fresh per batch. The transaction is committed inside WriteToDataSourceV2Exec.run() after + * writeWithV2 completes, and aborted if writeWithV2 throws. 
+ */ +@SlowSQLTest +class StreamingTransactionSuite extends StreamTest with BeforeAndAfter { + import testImplicits._ + + private val tableIdent = Identifier.of(Array("ns1"), "test_table") + private val tableNameAsString = "cat.ns1.test_table" + + before { + spark.conf.set("spark.sql.catalog.cat", + classOf[InMemoryRowLevelOperationTableCatalog].getName) + sql("CREATE NAMESPACE cat.ns1") + sql(s"CREATE TABLE $tableNameAsString (value INT) USING foo") + } + + after { + spark.sessionState.catalogManager.reset() + spark.sessionState.conf.clear() + sqlContext.streams.active.foreach(_.stop()) + } + + private def catalog: InMemoryRowLevelOperationTableCatalog = + spark.sessionState.catalogManager.catalog("cat") + .asInstanceOf[InMemoryRowLevelOperationTableCatalog] + + private def delegateTable: InMemoryRowLevelOperationTable = + catalog.loadTable(tableIdent).asInstanceOf[InMemoryRowLevelOperationTable] + + test("streaming micro-batch append commits a transaction") { + val stream = MemoryStream[Int] + + withTempDir { checkpointDir => + val query = stream.toDF() + .writeStream + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .toTable(tableNameAsString) + + try { + stream.addData(1, 2, 3) + query.processAllAvailable() + + val txn = catalog.lastTransaction + assert(txn.currentState === Committed) + assert(txn.isClosed) + assert(delegateTable.version() === "2") + + checkAnswer( + spark.table(tableNameAsString), + Seq(Row(1), Row(2), Row(3))) + } finally { + query.stop() + } + } + } + + test("each micro-batch gets a fresh transaction") { + val stream = MemoryStream[Int] + + withTempDir { checkpointDir => + val query = stream.toDF() + .writeStream + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .toTable(tableNameAsString) + + try { + stream.addData(1, 2, 3) + query.processAllAvailable() + val txn1 = catalog.lastTransaction + assert(txn1.currentState === Committed) + assert(txn1.isClosed) + assert(delegateTable.version() === "2") + + 
stream.addData(4, 5) + query.processAllAvailable() + val txn2 = catalog.lastTransaction + assert(txn2.currentState === Committed) + assert(txn2.isClosed) + assert(txn2 ne txn1, "each batch must open a fresh transaction") + assert(delegateTable.version() === "3") + + checkAnswer( + spark.table(tableNameAsString), + Seq(Row(1), Row(2), Row(3), Row(4), Row(5))) + } finally { + query.stop() + } + } + } + + test("no transaction is started when the catalog is not transactional") { + // Writing to a non-transactional catalog (session catalog / parquet) must not + // open a transaction. Verify by confirming catalog.lastTransaction is untouched. + val initialLastTxn = catalog.lastTransaction // null at start + + withTempDir { dir => + val stream = MemoryStream[Int] + val query = stream.toDF() + .writeStream + .format("parquet") + .option("checkpointLocation", dir.getCanonicalPath + "/checkpoint") + .option("path", dir.getCanonicalPath + "/data") + .start() + + try { + stream.addData(1, 2, 3) + query.processAllAvailable() + + // our TxnTableCatalog was not involved - lastTransaction must be unchanged + assert(catalog.lastTransaction === initialLastTxn) + } finally { + query.stop() + } + } + } + + test("no transaction is started for an anonymous V2 sink (catalog = None)") { + // An anonymous V2 sink has DataSourceV2Relation.catalog == None (no catalog/ident). + // UnresolveTransactionRelations skips it since catalog doesn't match any + // TransactionalCatalogPlugin, so transactionOpt returns None and no transaction is opened. + val initialLastTxn = catalog.lastTransaction // null at start + + withTempDir { dir => + val stream = MemoryStream[Int] + val query = stream.toDF() + .writeStream + .format(classOf[NoOpV2SinkProvider].getName) + .option("checkpointLocation", dir.getCanonicalPath + "/checkpoint") + .start() + + try { + stream.addData(1, 2, 3) + query.processAllAvailable() + + // Anonymous V2 sink: no catalog involved, no transaction must be opened. 
+ assert(catalog.lastTransaction === initialLastTxn) + } finally { + query.stop() + } + } + } +} + +/** + * A no-op V2 streaming sink with no catalog or identifier (anonymous sink). + * Used to verify that anonymous V2 sinks do not open transactions. + */ +class NoOpV2SinkProvider extends SimpleTableProvider { + override def getTable(options: CaseInsensitiveStringMap): Table = { + new Table with SupportsWrite { + override def name(): String = "noop-v2-sink" + override def schema(): StructType = StructType(Nil) + override def capabilities(): util.Set[TableCapability] = + util.EnumSet.of(TableCapability.STREAMING_WRITE) + override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = + new WriteBuilder { + override def build(): Write = new Write { + override def toStreaming: StreamingWrite = new StreamingWrite { + override def createStreamingWriterFactory( + info2: PhysicalWriteInfo): StreamingDataWriterFactory = PackedRowWriterFactory + override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} + override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} + } + } + } + } + } +} From c1fdd95b35ac54ca117333858c384077e2d7017e Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Wed, 15 Apr 2026 11:00:26 +0000 Subject: [PATCH 18/22] Remove StreamingTransactionSuite --- .../connector/StreamingTransactionSuite.scala | 213 ------------------ 1 file changed, 213 deletions(-) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/connector/StreamingTransactionSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/StreamingTransactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/StreamingTransactionSuite.scala deleted file mode 100644 index bd3a8fc1307b2..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/StreamingTransactionSuite.scala +++ /dev/null @@ -1,213 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or 
more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.connector - -import java.util - -import org.scalatest.BeforeAndAfter - -import org.apache.spark.sql.Row -import org.apache.spark.sql.connector.catalog.{Committed, Identifier, InMemoryRowLevelOperationTable, InMemoryRowLevelOperationTableCatalog, SupportsWrite, Table, TableCapability} -import org.apache.spark.sql.connector.write.{LogicalWriteInfo, PhysicalWriteInfo, Write, WriteBuilder, WriterCommitMessage} -import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite} -import org.apache.spark.sql.execution.streaming.runtime.MemoryStream -import org.apache.spark.sql.execution.streaming.sources.PackedRowWriterFactory -import org.apache.spark.sql.internal.connector.SimpleTableProvider -import org.apache.spark.sql.streaming.StreamTest -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.util.CaseInsensitiveStringMap -import org.apache.spark.tags.SlowSQLTest - -/** - * Tests that structured streaming micro-batch writes participate in the DSv2 transaction API. 
- * - * The V2 streaming path is: - * WriteToMicroBatchDataSource (logical) - * -> V2Writes rule -> WriteToDataSourceV2(MicroBatchWrite) (logical) - * -> WriteToDataSourceV2Exec (physical, implements TransactionalExec) - * - * Each micro-batch runs in its own IncrementalExecution, so transactionOpt is evaluated - * fresh per batch. The transaction is committed inside WriteToDataSourceV2Exec.run() after - * writeWithV2 completes, and aborted if writeWithV2 throws. - */ -@SlowSQLTest -class StreamingTransactionSuite extends StreamTest with BeforeAndAfter { - import testImplicits._ - - private val tableIdent = Identifier.of(Array("ns1"), "test_table") - private val tableNameAsString = "cat.ns1.test_table" - - before { - spark.conf.set("spark.sql.catalog.cat", - classOf[InMemoryRowLevelOperationTableCatalog].getName) - sql("CREATE NAMESPACE cat.ns1") - sql(s"CREATE TABLE $tableNameAsString (value INT) USING foo") - } - - after { - spark.sessionState.catalogManager.reset() - spark.sessionState.conf.clear() - sqlContext.streams.active.foreach(_.stop()) - } - - private def catalog: InMemoryRowLevelOperationTableCatalog = - spark.sessionState.catalogManager.catalog("cat") - .asInstanceOf[InMemoryRowLevelOperationTableCatalog] - - private def delegateTable: InMemoryRowLevelOperationTable = - catalog.loadTable(tableIdent).asInstanceOf[InMemoryRowLevelOperationTable] - - test("streaming micro-batch append commits a transaction") { - val stream = MemoryStream[Int] - - withTempDir { checkpointDir => - val query = stream.toDF() - .writeStream - .option("checkpointLocation", checkpointDir.getCanonicalPath) - .toTable(tableNameAsString) - - try { - stream.addData(1, 2, 3) - query.processAllAvailable() - - val txn = catalog.lastTransaction - assert(txn.currentState === Committed) - assert(txn.isClosed) - assert(delegateTable.version() === "2") - - checkAnswer( - spark.table(tableNameAsString), - Seq(Row(1), Row(2), Row(3))) - } finally { - query.stop() - } - } - } - - test("each 
micro-batch gets a fresh transaction") { - val stream = MemoryStream[Int] - - withTempDir { checkpointDir => - val query = stream.toDF() - .writeStream - .option("checkpointLocation", checkpointDir.getCanonicalPath) - .toTable(tableNameAsString) - - try { - stream.addData(1, 2, 3) - query.processAllAvailable() - val txn1 = catalog.lastTransaction - assert(txn1.currentState === Committed) - assert(txn1.isClosed) - assert(delegateTable.version() === "2") - - stream.addData(4, 5) - query.processAllAvailable() - val txn2 = catalog.lastTransaction - assert(txn2.currentState === Committed) - assert(txn2.isClosed) - assert(txn2 ne txn1, "each batch must open a fresh transaction") - assert(delegateTable.version() === "3") - - checkAnswer( - spark.table(tableNameAsString), - Seq(Row(1), Row(2), Row(3), Row(4), Row(5))) - } finally { - query.stop() - } - } - } - - test("no transaction is started when the catalog is not transactional") { - // Writing to a non-transactional catalog (session catalog / parquet) must not - // open a transaction. Verify by confirming catalog.lastTransaction is untouched. - val initialLastTxn = catalog.lastTransaction // null at start - - withTempDir { dir => - val stream = MemoryStream[Int] - val query = stream.toDF() - .writeStream - .format("parquet") - .option("checkpointLocation", dir.getCanonicalPath + "/checkpoint") - .option("path", dir.getCanonicalPath + "/data") - .start() - - try { - stream.addData(1, 2, 3) - query.processAllAvailable() - - // our TxnTableCatalog was not involved - lastTransaction must be unchanged - assert(catalog.lastTransaction === initialLastTxn) - } finally { - query.stop() - } - } - } - - test("no transaction is started for an anonymous V2 sink (catalog = None)") { - // An anonymous V2 sink has DataSourceV2Relation.catalog == None (no catalog/ident). 
- // UnresolveTransactionRelations skips it since catalog doesn't match any - // TransactionalCatalogPlugin, so transactionOpt returns None and no transaction is opened. - val initialLastTxn = catalog.lastTransaction // null at start - - withTempDir { dir => - val stream = MemoryStream[Int] - val query = stream.toDF() - .writeStream - .format(classOf[NoOpV2SinkProvider].getName) - .option("checkpointLocation", dir.getCanonicalPath + "/checkpoint") - .start() - - try { - stream.addData(1, 2, 3) - query.processAllAvailable() - - // Anonymous V2 sink: no catalog involved, no transaction must be opened. - assert(catalog.lastTransaction === initialLastTxn) - } finally { - query.stop() - } - } - } -} - -/** - * A no-op V2 streaming sink with no catalog or identifier (anonymous sink). - * Used to verify that anonymous V2 sinks do not open transactions. - */ -class NoOpV2SinkProvider extends SimpleTableProvider { - override def getTable(options: CaseInsensitiveStringMap): Table = { - new Table with SupportsWrite { - override def name(): String = "noop-v2-sink" - override def schema(): StructType = StructType(Nil) - override def capabilities(): util.Set[TableCapability] = - util.EnumSet.of(TableCapability.STREAMING_WRITE) - override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = - new WriteBuilder { - override def build(): Write = new Write { - override def toStreaming: StreamingWrite = new StreamingWrite { - override def createStreamingWriterFactory( - info2: PhysicalWriteInfo): StreamingDataWriterFactory = PackedRowWriterFactory - override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} - override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} - } - } - } - } - } -} From 8bfb2ae058ca0becb25da40afef5c1448ab0e735 Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Wed, 15 Apr 2026 13:16:24 +0000 Subject: [PATCH 19/22] More comments and renames --- .../spark/sql/connector/catalog/txns.scala | 33 
+++++++++++++++---- .../spark/sql/execution/QueryExecution.scala | 27 +++++++-------- 2 files changed, 38 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala index 157c9a82d6f2f..7b55e20c61676 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala @@ -48,11 +48,13 @@ class Txn(override val catalog: TxnTableCatalog) extends Transaction { this.state = Committed } + // This is idempotent since nested QEs can cause multiple aborts. override def abort(): Unit = { if (state == Committed || state == Aborted) return this.state = Aborted } + // This is idempotent since nested QEs can cause multiple aborts. override def close(): Unit = { if (!closed) { catalog.clearActiveTransaction() @@ -64,6 +66,9 @@ class Txn(override val catalog: TxnTableCatalog) extends Transaction { // A special table used in row-level operation transactions. It inherits data // from the base table upon construction and propagates staged transaction state // back after an explicit commit. +// Note, the in-memory data store does not handle concurrency at the moment. The assumes that the +// underlying delegate table cannot change from concurrent transactions. Data sources need to +// implement isolation semantics and make sure they are enforced. class TxnTable(val delegate: InMemoryRowLevelOperationTable, schema: StructType) extends InMemoryRowLevelOperationTable( delegate.name, @@ -80,10 +85,13 @@ class TxnTable(val delegate: InMemoryRowLevelOperationTable, schema: StructType) // A tracker of filters used in each scan. val scanEvents = new ArrayBuffer[Array[Filter]]() + // Record scan events. This is invoked when building a scan for the particular table. 
override protected def recordScanEvent(filters: Array[Filter]): Unit = { scanEvents += filters } + // Perform commit if there are any changes. This push metadata and data changes to the + // delegate table. def commit(): Unit = { if (version() != initialVersion) { delegate.dataMap.clear() @@ -98,8 +106,11 @@ class TxnTable(val delegate: InMemoryRowLevelOperationTable, schema: StructType) } } -// A special table catalog used in row-level operation transactions. -// Table changes are initially staged in memory and propagated only after an explicit commit. +// A special table catalog used in row-level operation transactions. The lifecycle of this catalog +// is tied to the transaction. A new catalog instance is created at the beginning of a transaction +// and discarded at the end. The catalog is responsible for pinning all tables involved in the +// transaction. Table changes are initially staged in memory and propagated only after an explicit +// commit. class TxnTableCatalog(delegate: InMemoryRowLevelOperationTableCatalog) extends TableCatalog { private val tables: util.Map[Identifier, TxnTable] = new ConcurrentHashMap[Identifier, TxnTable]() @@ -112,6 +123,9 @@ class TxnTableCatalog(delegate: InMemoryRowLevelOperationTableCatalog) extends T throw new UnsupportedOperationException() } + // This is where the table pinning logic should occur. In this implementation, a tables is loaded + // (pinned) the first time is accessed. All subsequent accesses should return the same pinned + // table. 
override def loadTable(ident: Identifier): Table = { tables.computeIfAbsent(ident, _ => { val table = delegate.loadTable(ident).asInstanceOf[InMemoryRowLevelOperationTable] @@ -119,11 +133,6 @@ class TxnTableCatalog(delegate: InMemoryRowLevelOperationTableCatalog) extends T }) } - override def createTable(ident: Identifier, tableInfo: TableInfo): Table = { - delegate.createTable(ident, tableInfo) - loadTable(ident) - } - override def alterTable(ident: Identifier, changes: TableChange*): Table = { // AlterTable may be called by ResolveSchemaEvolution when schema evolution is enabled. Thus, // it needs to be transactional. The schema changes are only propagated to the delegate at @@ -145,6 +154,13 @@ class TxnTableCatalog(delegate: InMemoryRowLevelOperationTableCatalog) extends T newTxnTable } + // TODO: Currently not transactional. Should be revised when Atomic CTAS/RTAS is implemented. + override def createTable(ident: Identifier, tableInfo: TableInfo): Table = { + delegate.createTable(ident, tableInfo) + loadTable(ident) + } + + // TODO: Currently not transactional. Should be revised when Atomic CTAS/RTAS is implemented. override def dropTable(ident: Identifier): Boolean = { tables.remove(ident) delegate.dropTable(ident) @@ -154,10 +170,13 @@ class TxnTableCatalog(delegate: InMemoryRowLevelOperationTableCatalog) extends T throw new UnsupportedOperationException() } + // Invoke commit for all tables participated in the transaction. If a table is read-only + // this is a no-op. def commit(): Unit = { tables.values.forEach(table => table.commit()) } + // Clear transaction context. 
def clearActiveTransaction(): Unit = { delegate.lastTransaction = delegate.transaction delegate.transaction = null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index d15a0dd7b5a65..e07d0147205a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -147,7 +147,7 @@ class QueryExecution( } } - def assertSupported(): Unit = executeWithTransactionContext { + def assertSupported(): Unit = withAbortTransactionOnFailure { if (sparkSession.sessionState.conf.isUnsupportedOperationCheckEnabled) { UnsupportedOperationChecker.checkForBatch(analyzed) } @@ -197,7 +197,7 @@ class QueryExecution( } } - def analyzed: LogicalPlan = executeWithTransactionContext { + def analyzed: LogicalPlan = withAbortTransactionOnFailure { lazyAnalyzed.get } @@ -209,7 +209,7 @@ class QueryExecution( } } - def commandExecuted: LogicalPlan = executeWithTransactionContext { + def commandExecuted: LogicalPlan = withAbortTransactionOnFailure { lazyCommandExecuted.get } @@ -271,7 +271,7 @@ class QueryExecution( } // The plan that has been normalized by custom rules, so that it's more likely to hit cache. 
- def normalized: LogicalPlan = executeWithTransactionContext { + def normalized: LogicalPlan = withAbortTransactionOnFailure { lazyNormalized.get } @@ -293,7 +293,7 @@ class QueryExecution( } } - def withCachedData: LogicalPlan = executeWithTransactionContext { + def withCachedData: LogicalPlan = withAbortTransactionOnFailure { lazyWithCachedData.get } @@ -317,7 +317,7 @@ class QueryExecution( } } - def optimizedPlan: LogicalPlan = executeWithTransactionContext { + def optimizedPlan: LogicalPlan = withAbortTransactionOnFailure { lazyOptimizedPlan.get } @@ -335,7 +335,7 @@ class QueryExecution( attachTransaction(plan) } - def sparkPlan: SparkPlan = executeWithTransactionContext { + def sparkPlan: SparkPlan = withAbortTransactionOnFailure { lazySparkPlan.get } @@ -358,7 +358,7 @@ class QueryExecution( // executedPlan should not be used to initialize any SparkPlan. It should be // only used for execution. - def executedPlan: SparkPlan = executeWithTransactionContext { + def executedPlan: SparkPlan = withAbortTransactionOnFailure { lazyExecutedPlan.get } @@ -378,7 +378,7 @@ class QueryExecution( * Given QueryExecution is not a public class, end users are discouraged to use this: please * use `Dataset.rdd` instead where conversion will be applied. */ - def toRdd: RDD[InternalRow] = executeWithTransactionContext { + def toRdd: RDD[InternalRow] = withAbortTransactionOnFailure { lazyToRdd.get } @@ -583,13 +583,10 @@ class QueryExecution( } /** - * Executes the given block with the transaction context if exists. If there is an exception - * thrown during the execution, the transaction will be aborted. - * - * Note: The transaction is not committed in this method. The caller should commit the - * transaction if the execution is successful. + * Runs the given block, aborting the active transaction if an exception is thrown. + * If no transaction is active, the block is executed as-is. 
*/ - private def executeWithTransactionContext[T](block: => T): T = transactionOpt match { + private def withAbortTransactionOnFailure[T](block: => T): T = transactionOpt match { case Some(transaction) => try block catch { case e: Throwable => TransactionUtils.abort(transaction); throw e } From 5988c4497ecc992d1f3988b6d5d84aedf7aed6b7 Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Fri, 17 Apr 2026 12:02:34 +0000 Subject: [PATCH 20/22] Test coverage for SQL scripting --- ...nMemoryRowLevelOperationTableCatalog.scala | 5 + .../spark/sql/connector/catalog/txns.scala | 6 +- .../sql/scripting/SqlScriptingE2eSuite.scala | 145 ++++++++++++++++++ 3 files changed, 155 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala index 4a38285b685e8..4e5e1e7c8c6e9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.connector.catalog +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException import org.apache.spark.sql.connector.catalog.transactions.{Transaction, TransactionInfo} import org.apache.spark.sql.types.StructType @@ -30,6 +32,9 @@ class InMemoryRowLevelOperationTableCatalog var transaction: Txn = _ // The last completed transaction. var lastTransaction: Txn = _ + // All transactions in order (committed and aborted), allowing per-statement + // validation in SQL scripting tests. 
+ val seenTransactions: ArrayBuffer[Txn] = new ArrayBuffer[Txn]() override def beginTransaction(info: TransactionInfo): Transaction = { assert(transaction == null || transaction.currentState != Active) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala index 7b55e20c61676..49ddeb2c7c809 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala @@ -22,6 +22,8 @@ import java.util.concurrent.ConcurrentHashMap import scala.collection.mutable.ArrayBuffer +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException +import org.apache.spark.sql.catalyst.util.QuotingUtils import org.apache.spark.sql.connector.catalog.transactions.Transaction import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType @@ -178,7 +180,9 @@ class TxnTableCatalog(delegate: InMemoryRowLevelOperationTableCatalog) extends T // Clear transaction context. 
def clearActiveTransaction(): Unit = { - delegate.lastTransaction = delegate.transaction + val txn = delegate.transaction + delegate.lastTransaction = txn + delegate.seenTransactions += txn delegate.transaction = null } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala index 9996bec44e232..01db9b5d07018 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala @@ -21,8 +21,10 @@ import org.apache.spark.SparkConf import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.plans.logical.CompoundBody import org.apache.spark.sql.catalyst.util.QuotingUtils.toSQLConf +import org.apache.spark.sql.connector.catalog.{Aborted, Committed, Identifier, InMemoryRowLevelOperationTableCatalog, Txn, TxnTable, TxnTableCatalog} import org.apache.spark.sql.exceptions.SqlScriptingException import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{IntegerType, StructField, StructType} @@ -47,6 +49,27 @@ class SqlScriptingE2eSuite extends QueryTest with SharedSparkSession { } // Helpers + private def withCatalog( + name: String)( + f: InMemoryRowLevelOperationTableCatalog => Unit): Unit = { + withSQLConf(s"spark.sql.catalog.$name" -> + classOf[InMemoryRowLevelOperationTableCatalog].getName) { + val catalog = spark.sessionState.catalogManager + .catalog(name) + .asInstanceOf[InMemoryRowLevelOperationTableCatalog] + f(catalog) + } + } + + private def loadTxnTable( + txn: Txn, + tableName: String, + namespace: Array[String] = Array("ns1")): TxnTable = + txn.catalog + .asInstanceOf[TxnTableCatalog] + .loadTable(Identifier.of(namespace, tableName)) + .asInstanceOf[TxnTable] + private def 
verifySqlScriptResult( sqlText: String, expected: Seq[Row], @@ -174,6 +197,128 @@ class SqlScriptingE2eSuite extends QueryTest with SharedSparkSession { } } + test("multi statement with transactional checks - insert then delete") { + withCatalog("cat") { catalog => + withTable("cat.ns1.t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE cat.ns1.t (pk INT NOT NULL, salary INT, dep STRING) + | PARTITIONED BY (dep); + | INSERT INTO cat.ns1.t VALUES (1, 100, 'hr'), (2, 200, 'software'); + | DELETE FROM cat.ns1.t + | WHERE pk IN (SELECT pk FROM cat.ns1.t WHERE dep = 'hr'); + | SELECT * FROM cat.ns1.t; + |END + |""".stripMargin + + verifySqlScriptResult(sqlScript, Seq(Row(2, 200, "software"))) + + // Each DML statement in a script runs in its own independent QE and transaction. + assert(catalog.seenTransactions.size === 2) + assert(catalog.seenTransactions.forall(t => + t.currentState === Committed && t.isClosed)) + + // The DELETE subquery scans the table with a dep='hr' predicate; verify it was tracked. 
+ val deleteTxnTable = loadTxnTable(catalog.seenTransactions(1), "t") + assert(deleteTxnTable.scanEvents.flatten.exists { + case sources.EqualTo("dep", "hr") => true + case _ => false + }) + } + } + } + + test("multi statement with transactional checks - second statement fails") { + withCatalog("cat") { catalog => + withTable("cat.ns1.t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE cat.ns1.t (pk INT NOT NULL, salary INT, dep STRING) + | PARTITIONED BY (dep); + | INSERT INTO cat.ns1.t VALUES (1, 100, 'hr'), (2, 200, 'software'); + | DELETE FROM cat.ns1.t WHERE nonexistent_column = 1; + |END + |""".stripMargin + + checkError( + exception = intercept[AnalysisException] { + spark.sql(sqlScript).collect() + }, + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + parameters = Map( + "objectName" -> "`nonexistent_column`", + "proposal" -> ".*"), + matchPVals = true, + queryContext = Array(ExpectedContext("nonexistent_column"))) + + // INSERT committed; DELETE was aborted because analysis failed on the bad column. 
+ assert(catalog.seenTransactions.size === 2) + assert(catalog.seenTransactions(0).currentState === Committed) + assert(catalog.seenTransactions(0).isClosed) + assert(catalog.seenTransactions(1).currentState === Aborted) + assert(catalog.seenTransactions(1).isClosed) + assert(catalog.lastTransaction.currentState === Aborted) + } + } + } + + test("multi statement with transactional checks - insert, merge, update") { + withCatalog("cat") { catalog => + withTable("cat.ns1.t", "cat.ns1.src") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE cat.ns1.t (pk INT NOT NULL, salary INT, dep STRING) + | PARTITIONED BY (dep); + | CREATE TABLE cat.ns1.src (pk INT NOT NULL, salary INT, dep STRING) + | PARTITIONED BY (dep); + | INSERT INTO cat.ns1.t VALUES (1, 100, 'hr'), (2, 200, 'software'), (3, 300, 'hr'); + | INSERT INTO cat.ns1.src VALUES (1, 150, 'hr'), (4, 400, 'finance'); + | MERGE INTO cat.ns1.t AS t + | USING cat.ns1.src AS s + | ON t.pk = s.pk + | WHEN MATCHED THEN UPDATE SET salary = s.salary + | WHEN NOT MATCHED THEN INSERT (pk, salary, dep) + | VALUES (s.pk, s.salary, s.dep); + | UPDATE cat.ns1.t SET salary = salary + 50 WHERE dep = 'software'; + | SELECT * FROM cat.ns1.t ORDER BY pk; + |END + |""".stripMargin + + verifySqlScriptResult( + sqlScript, + Seq( + Row(1, 150, "hr"), + Row(2, 250, "software"), + Row(3, 300, "hr"), + Row(4, 400, "finance"))) + + // INSERT (x2), MERGE, and UPDATE each run in their own independent QE and transaction. + assert(catalog.seenTransactions.size === 4) + assert(catalog.seenTransactions.forall(t => t.currentState === Committed && t.isClosed)) + + def txnTable(txnIdx: Int): TxnTable = + loadTxnTable(catalog.seenTransactions(txnIdx), "t") + + // Both inserts are pure writes - no scan. + assert(txnTable(0).scanEvents.isEmpty) + assert(txnTable(1).scanEvents.isEmpty) + + // MERGE scans the full target table. The join is on pk (not the partition column). 
+ assert(txnTable(2).scanEvents.nonEmpty) + assert(txnTable(2).scanEvents.flatten.isEmpty) + + // UPDATE with WHERE dep='software' pushes an equality predicate on the partition column. + assert(txnTable(3).scanEvents.flatten.exists { + case sources.EqualTo("dep", "software") => true + case _ => false + }) + } + } + } + test("script without result statement") { val sqlScript = """ From e79806ea5a92544d882b1ef90ece5d4a7d880624 Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Fri, 17 Apr 2026 12:32:13 +0000 Subject: [PATCH 21/22] Extra SQL scripting tests --- .../sql/scripting/SqlScriptingE2eSuite.scala | 60 +++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala index 01db9b5d07018..bdf47dc96c3a1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala @@ -319,6 +319,66 @@ class SqlScriptingE2eSuite extends QueryTest with SharedSparkSession { } } + test("loop with transactional checks - each iteration runs in its own transaction") { + withCatalog("cat") { catalog => + withTable("cat.ns1.t") { + val sqlScript = + """ + |BEGIN + | DECLARE i INT = 1; + | CREATE TABLE + | cat.ns1.t (pk INT NOT NULL, salary INT, dep STRING) + | PARTITIONED BY (dep); + | WHILE i <= 3 DO + | INSERT INTO cat.ns1.t VALUES (i, i * 100, 'hr'); + | SET i = i + 1; + | END WHILE; + | SELECT * FROM cat.ns1.t ORDER BY pk; + |END + |""".stripMargin + + verifySqlScriptResult( + sqlScript, + Seq(Row(1, 100, "hr"), Row(2, 200, "hr"), Row(3, 300, "hr"))) + + // Each loop iteration's INSERT runs in its own independent transaction. 
+ assert(catalog.seenTransactions.size === 3) + assert(catalog.seenTransactions.forall(t => t.currentState === Committed && t.isClosed)) + } + } + } + + test("continue handler with transactional checks - handler DML runs in its own transaction") { + withCatalog("cat") { catalog => + withTable("cat.ns1.t") { + val sqlScript = + """ + |BEGIN + | DECLARE CONTINUE HANDLER FOR DIVIDE_BY_ZERO + | BEGIN + | INSERT INTO cat.ns1.t VALUES (-1, -1, 'error'); + | END; + | CREATE TABLE + | cat.ns1.t (pk INT NOT NULL, salary INT, dep STRING) + | PARTITIONED BY (dep); + | INSERT INTO cat.ns1.t VALUES (1, 100, 'hr'); + | SELECT 1/0; + | INSERT INTO cat.ns1.t VALUES (2, 200, 'software'); + | SELECT * FROM cat.ns1.t ORDER BY pk; + |END + |""".stripMargin + + verifySqlScriptResult( + sqlScript, + Seq(Row(-1, -1, "error"), Row(1, 100, "hr"), Row(2, 200, "software"))) + + // INSERT(1), handler INSERT(-1), INSERT(2) - each in its own transaction. + assert(catalog.seenTransactions.size === 3) + assert(catalog.seenTransactions.forall(t => t.currentState === Committed && t.isClosed)) + } + } + } + test("script without result statement") { val sqlScript = """ From 14474cd3a439862c05c9821da83fd9c8fb7a90d0 Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Fri, 17 Apr 2026 13:58:11 +0000 Subject: [PATCH 22/22] Fix lint --- .../scala/org/apache/spark/sql/connector/catalog/txns.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala index 49ddeb2c7c809..232b623174996 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala @@ -22,8 +22,6 @@ import java.util.concurrent.ConcurrentHashMap import scala.collection.mutable.ArrayBuffer -import org.apache.spark.sql.catalyst.analysis.NoSuchTableException -import 
org.apache.spark.sql.catalyst.util.QuotingUtils import org.apache.spark.sql.connector.catalog.transactions.Transaction import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType