diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TransactionalCatalogPlugin.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TransactionalCatalogPlugin.java new file mode 100644 index 0000000000000..daa3176dcbba5 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TransactionalCatalogPlugin.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.catalog.transactions.Transaction; +import org.apache.spark.sql.connector.catalog.transactions.TransactionInfo; + +/** + * A {@link CatalogPlugin} that supports transactions. + *

+ * Catalogs that implement this interface opt in to transactional query execution. A catalog + * implementing this interface is responsible for starting transactions. + * + * @since 4.2.0 + */ +@Evolving +public interface TransactionalCatalogPlugin extends CatalogPlugin { + + /** + * Begins a new transaction and returns a {@link Transaction} representing it. + * + * @param info metadata about the transaction being started. + */ + Transaction beginTransaction(TransactionInfo info); +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/Transaction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/Transaction.java new file mode 100644 index 0000000000000..77044c6202fbe --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/Transaction.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connector.catalog.transactions; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.catalog.CatalogPlugin; +import org.apache.spark.sql.connector.catalog.TransactionalCatalogPlugin; + +import java.io.Closeable; + +/** + * Represents a transaction. + *

+ * Spark begins a transaction with {@link TransactionalCatalogPlugin#beginTransaction} and + * executes read/write operations against the transaction's catalog. On success, Spark + * calls {@link #commit()}; on failure, Spark calls {@link #abort()}. In both cases Spark + * subsequently calls {@link #close()} to release resources. + * + * @since 4.2.0 + */ +@Evolving +public interface Transaction extends Closeable { + + /** + * Returns the catalog associated with this transaction. This catalog is responsible for tracking + * read/write operations that occur within the boundaries of a transaction. This allows + * connectors to perform conflict resolution at commit time. + */ + CatalogPlugin catalog(); + + /** + * Commits the transaction. All writes performed under it become visible to other readers. + *

+ * The connector is responsible for detecting and resolving conflicting commits or throwing + * an exception if resolution is not possible. + *

+ * This method will be called exactly once per transaction. Spark calls {@link #close()} + * immediately after this method returns. + * + * @throws IllegalStateException if the transaction has already been committed or aborted. + */ + void commit(); + + /** + * Aborts the transaction, discarding any staged changes. + *

+ * This method must be idempotent. If the transaction has already been committed or aborted, + * invoking it must have no effect. + *

+ * Spark calls {@link #close()} immediately after this method returns. + */ + void abort(); + + /** + * Releases any resources held by this transaction. + *

+ * Spark always calls this method after {@link #commit()} or {@link #abort()}, regardless of + * whether those methods succeed or not. + *

+ * This method must be idempotent. If the transaction has already been closed, + * invoking it must have no effect. + */ + @Override + void close(); +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/TransactionInfo.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/TransactionInfo.java new file mode 100644 index 0000000000000..3e6979cec469f --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/TransactionInfo.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog.transactions; + +import org.apache.spark.annotation.Evolving; + +/** + * Metadata about a transaction. + * + * @since 4.2.0 + */ +@Evolving +public interface TransactionInfo { + /** + * Returns a unique identifier for this transaction. 
+ */ + String id(); +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 3b4d725840935..09928553839ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -336,6 +336,30 @@ class Analyzer( } } + /** + * Returns a copy of this analyzer that uses the given [[CatalogManager]] for all catalog + * lookups. All other configuration (extended rules, checks, etc.) is preserved. Used by + * [[QueryExecution]] to create a per-query analyzer for transactional operations for + * transaction-aware catalog resolution. + */ + def withCatalogManager(newCatalogManager: CatalogManager): Analyzer = { + val self = this + new Analyzer(newCatalogManager, sharedRelationCache) { + override val hintResolutionRules: Seq[Rule[LogicalPlan]] = self.hintResolutionRules + override val extendedResolutionRules: Seq[Rule[LogicalPlan]] = self.extendedResolutionRules + override val postHocResolutionRules: Seq[Rule[LogicalPlan]] = self.postHocResolutionRules + override val extendedCheckRules: Seq[LogicalPlan => Unit] = self.extendedCheckRules + override val singlePassResolverExtensions: Seq[ResolverExtension] = + self.singlePassResolverExtensions + override val singlePassMetadataResolverExtensions: Seq[ResolverExtension] = + self.singlePassMetadataResolverExtensions + override val singlePassPostHocResolutionRules: Seq[Rule[LogicalPlan]] = + self.singlePassPostHocResolutionRules + override val singlePassExtendedResolutionChecks: Seq[LogicalPlan => Unit] = + self.singlePassExtendedResolutionChecks + } + } + override def execute(plan: LogicalPlan): LogicalPlan = { AnalysisContext.withNewAnalysisContext { executeSameContext(plan) @@ -437,7 +461,9 @@ class Analyzer( Batch("Simple Sanity Check", Once, LookupFunctions), Batch("Keep Legacy 
Outputs", Once, - KeepLegacyOutputs) + KeepLegacyOutputs), + Batch("Unresolve Relations", Once, + new UnresolveTransactionRelations(catalogManager)) ) override def batches: Seq[Batch] = earlyBatches ++ Seq( @@ -1005,9 +1031,10 @@ class Analyzer( } } - // Resolve V2TableReference nodes in a plan. V2TableReference is only created for temp views - // (via V2TableReference.createForTempView), so we only need to resolve it when returning - // the plan of temp views (in resolveViews and unwrapRelationPlan). + // Resolve V2TableReference nodes created for: + // 1 Temp views (via createForTempView). + // 2. Transaction references (via createForTransaction). These are resolved by a + // separate analysis batch in the transaction-aware analyzer instance. private def resolveTableReferences(plan: LogicalPlan): LogicalPlan = { plan.resolveOperatorsUp { case r: V2TableReference => relationResolution.resolveReference(r) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala index e86248febd2eb..fd0394098b0ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala @@ -391,6 +391,11 @@ class RelationResolution( } private def getOrLoadRelation(ref: V2TableReference): LogicalPlan = { + // Skip cache when a transaction is active. + if (catalogManager.transaction.isDefined) { + return loadRelation(ref) + } + val key = toCacheKey(ref.catalog, ref.identifier) relationCache.get(key) match { case Some(cached) => @@ -403,9 +408,18 @@ class RelationResolution( } private def loadRelation(ref: V2TableReference): LogicalPlan = { - val table = ref.catalog.loadTable(ref.identifier) + // Resolve catalog. When a transaction is active we return the transaction + // aware catalog instance. 
+ val resolvedCatalog = catalogManager.catalog(ref.catalog.name).asTableCatalog + val table = resolvedCatalog.loadTable(ref.identifier) V2TableReferenceUtils.validateLoadedTable(table, ref) - ref.toRelation(table) + // Create relation with resolved Catalog. + DataSourceV2Relation( + table = table, + output = ref.output, + catalog = Some(resolvedCatalog), + identifier = Some(ref.identifier), + options = ref.options) } private def adaptCachedRelation(cached: LogicalPlan, ref: V2TableReference): LogicalPlan = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolveTransactionRelations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolveTransactionRelations.scala new file mode 100644 index 0000000000000..0e344173d7892 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolveTransactionRelations.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, TransactionalWrite} +import org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.allowInvokingTransformsInAnalyzer +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, LookupCatalog} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation + +/** + * When a transaction is active, converts resolved [[DataSourceV2Relation]] nodes back to + * [[V2TableReference]] placeholders for all relations loaded by a catalog with the same + * name as the transaction catalog. + * + * This forces re-resolution of those relations against the transaction's catalog, which + * intercepts [[TableCatalog#loadTable]] calls to track which tables are read as part of + * the transaction. + */ +class UnresolveTransactionRelations(val catalogManager: CatalogManager) + extends Rule[LogicalPlan] with LookupCatalog { + + override def apply(plan: LogicalPlan): LogicalPlan = + catalogManager.transaction match { + case Some(transaction) => + allowInvokingTransformsInAnalyzer { + plan.transform { + case tw: TransactionalWrite => + unresolveRelations(tw, transaction.catalog) + } + } + case _ => plan + } + + private def unresolveRelations( + plan: LogicalPlan, + catalog: CatalogPlugin): LogicalPlan = { + plan transform { + case r: DataSourceV2Relation if isLoadedFromCatalog(r, catalog) => + V2TableReference.createForTransaction(r) + } + } + + private def isLoadedFromCatalog( + relation: DataSourceV2Relation, + catalog: CatalogPlugin): Boolean = { + relation.catalog.exists(_.name == catalog.name) && relation.identifier.isDefined + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala index 85c36d452b309..f459706a690bb 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.analysis.V2TableReference.Context import org.apache.spark.sql.catalyst.analysis.V2TableReference.TableInfo import org.apache.spark.sql.catalyst.analysis.V2TableReference.TemporaryViewContext +import org.apache.spark.sql.catalyst.analysis.V2TableReference.TransactionContext import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.LeafNode import org.apache.spark.sql.catalyst.plans.logical.Statistics @@ -37,7 +38,7 @@ import org.apache.spark.sql.connector.catalog.V2TableUtil import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.util.CaseInsensitiveStringMap -import org.apache.spark.sql.util.SchemaValidationMode.ALLOW_NEW_TOP_LEVEL_FIELDS +import org.apache.spark.sql.util.SchemaValidationMode.{ALLOW_NEW_TOP_LEVEL_FIELDS, PROHIBIT_CHANGES} import org.apache.spark.util.ArrayImplicits._ /** @@ -84,11 +85,19 @@ private[sql] object V2TableReference { sealed trait Context case class TemporaryViewContext(viewName: Seq[String]) extends Context + /** Context for relations that are re-resolved through a transaction catalog. */ + case object TransactionContext extends Context def createForTempView(relation: DataSourceV2Relation, viewName: Seq[String]): V2TableReference = { create(relation, TemporaryViewContext(viewName)) } + // V2TableReference nodes in the transaction context are produced by + // UnresolveTransactionRelations which unresolves already resolved relations. 
+ def createForTransaction(relation: DataSourceV2Relation): V2TableReference = { + create(relation, TransactionContext) + } + private def create(relation: DataSourceV2Relation, context: Context): V2TableReference = { val ref = V2TableReference( relation.catalog.get.asTableCatalog, @@ -110,11 +119,32 @@ private[sql] object V2TableReferenceUtils extends SQLConfHelper { ref.context match { case ctx: TemporaryViewContext => validateLoadedTableInTempView(table, ref, ctx) + case TransactionContext => + validateLoadedTableInTransaction(table, ref) case ctx => throw SparkException.internalError(s"Unknown table ref context: ${ctx.getClass.getName}") } } + + private def validateLoadedTableInTransaction(table: Table, ref: V2TableReference): Unit = { + // Do not allow schema evolution for pre-analysed dataframes that are later used in + // transactional writes. This is because the entire plan was built based on the original schema + // and any schema change would make the plan structurally invalid. This is in line with the + // semantics of SPARK-54444. 
+ val dataErrors = V2TableUtil.validateCapturedColumns( + table = table, + originCols = ref.info.columns, + mode = PROHIBIT_CHANGES) + if (dataErrors.nonEmpty) { + throw QueryCompilationErrors.columnsChangedAfterAnalysis(ref.name, dataErrors) + } + + val metaErrors = V2TableUtil.validateCapturedMetadataColumns(table, ref.info.metadataColumns) + if (metaErrors.nonEmpty) { + throw QueryCompilationErrors.metadataColumnsChangedAfterAnalysis(ref.name, metaErrors) + } + } + private def validateLoadedTableInTempView( table: Table, ref: V2TableReference, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala index c38377582c156..774c783ecf8a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala @@ -188,7 +188,11 @@ case class InsertIntoStatement( byName: Boolean = false, replaceCriteriaOpt: Option[InsertReplaceCriteria] = None, withSchemaEvolution: Boolean = false) - extends UnaryParsedStatement { + // Extends TransactionalWrite so that QueryExecution can detect a potential transaction on the + // unresolved logical plan before analysis runs. InsertIntoStatement is shared between V1 and V2 + // inserts, but the LookupCatalog.TransactionalWrite extractor only matches when the target + // catalog implements TransactionalCatalogPlugin, so V1 inserts are never assigned a transaction. 
+ extends UnaryParsedStatement with TransactionalWrite { require(overwrite || !ifPartitionNotExists, "IF NOT EXISTS is only valid in INSERT OVERWRITE") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index 500d648d23acb..f9e7774e1448d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -156,7 +156,7 @@ case class AppendData( isByName: Boolean, withSchemaEvolution: Boolean, write: Option[Write] = None, - analyzedQuery: Option[LogicalPlan] = None) extends V2WriteCommand { + analyzedQuery: Option[LogicalPlan] = None) extends V2WriteCommand with TransactionalWrite { override val writePrivileges: Set[TableWritePrivilege] = Set(TableWritePrivilege.INSERT) override def withNewQuery(newQuery: LogicalPlan): AppendData = copy(query = newQuery) override def withNewTable(newTable: NamedRelation): AppendData = copy(table = newTable) @@ -204,7 +204,7 @@ case class OverwriteByExpression( isByName: Boolean, withSchemaEvolution: Boolean, write: Option[Write] = None, - analyzedQuery: Option[LogicalPlan] = None) extends V2WriteCommand { + analyzedQuery: Option[LogicalPlan] = None) extends V2WriteCommand with TransactionalWrite { override val writePrivileges: Set[TableWritePrivilege] = Set(TableWritePrivilege.INSERT, TableWritePrivilege.DELETE) override lazy val resolved: Boolean = { @@ -264,7 +264,7 @@ case class OverwritePartitionsDynamic( writeOptions: Map[String, String], isByName: Boolean, withSchemaEvolution: Boolean, - write: Option[Write] = None) extends V2WriteCommand { + write: Option[Write] = None) extends V2WriteCommand with TransactionalWrite { override val writePrivileges: Set[TableWritePrivilege] = Set(TableWritePrivilege.INSERT, TableWritePrivilege.DELETE) override def 
withNewQuery(newQuery: LogicalPlan): OverwritePartitionsDynamic = { @@ -520,8 +520,10 @@ case class WriteDelta( trait V2CreateTableAsSelectPlan extends V2CreateTablePlan with AnalysisOnlyCommand - with CTEInChildren { + with CTEInChildren + with TransactionalWrite { def query: LogicalPlan + override def table: LogicalPlan = name override def withCTEDefs(cteDefs: Seq[CTERelationDef]): LogicalPlan = { withNameAndQuery(newName = name, newQuery = WithCTE(query, cteDefs)) @@ -943,7 +945,8 @@ object DescribeColumn { */ case class DeleteFromTable( table: LogicalPlan, - condition: Expression) extends UnaryCommand with SupportsSubquery { + condition: Expression) + extends UnaryCommand with TransactionalWrite with SupportsSubquery { override def child: LogicalPlan = table override protected def withNewChildInternal(newChild: LogicalPlan): DeleteFromTable = copy(table = newChild) @@ -965,7 +968,8 @@ case class DeleteFromTableWithFilters( case class UpdateTable( table: LogicalPlan, assignments: Seq[Assignment], - condition: Option[Expression]) extends UnaryCommand with SupportsSubquery { + condition: Option[Expression]) + extends UnaryCommand with TransactionalWrite with SupportsSubquery { lazy val aligned: Boolean = AssignmentUtils.aligned(table.output, assignments) @@ -998,8 +1002,13 @@ case class MergeIntoTable( notMatchedActions: Seq[MergeAction], notMatchedBySourceActions: Seq[MergeAction], withSchemaEvolution: Boolean) - extends BinaryCommand with WriteWithSchemaEvolution with SupportsSubquery { + extends BinaryCommand + with WriteWithSchemaEvolution + with SupportsSubquery + with TransactionalWrite { + // Implements WriteWithSchemaEvolution.table. + // Implements TransactionalWrite.table. 
override val table: LogicalPlan = EliminateSubqueryAliases(targetTable) override def withNewTable(newTable: NamedRelation): MergeIntoTable = { @@ -1270,6 +1279,16 @@ case class Assignment(key: Expression, value: Expression) extends Expression newLeft: Expression, newRight: Expression): Assignment = copy(key = newLeft, value = newRight) } +/** + * Marker trait for write operations that participate in a DSv2 transaction. + * + * Implementations are expected to target a DSv2 catalog backed by a + * [[org.apache.spark.sql.connector.catalog.TransactionalCatalogPlugin]]. + */ +trait TransactionalWrite extends LogicalPlan { + def table: LogicalPlan +} + /** * The logical plan of the DROP TABLE command. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtils.scala new file mode 100644 index 0000000000000..d160aafdea34e --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtils.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.transactions + +import java.util.UUID + +import org.apache.spark.sql.connector.catalog.TransactionalCatalogPlugin +import org.apache.spark.sql.connector.catalog.transactions.{Transaction, TransactionInfoImpl} +import org.apache.spark.util.Utils + +object TransactionUtils { + def commit(transaction: Transaction): Unit = { + Utils.tryWithSafeFinally { + transaction.commit() + } { + transaction.close() + } + } + + def abort(transaction: Transaction): Unit = { + Utils.tryWithSafeFinally { + transaction.abort() + } { + transaction.close() + } + } + + def beginTransaction(catalog: TransactionalCatalogPlugin): Transaction = { + val info = TransactionInfoImpl(id = UUID.randomUUID.toString) + val transaction = catalog.beginTransaction(info) + if (transaction.catalog.name != catalog.name) { + abort(transaction) + throw new IllegalStateException( + s"""Transaction catalog name (${transaction.catalog.name}) + |must match original catalog name (${catalog.name}). 
+ |""".stripMargin) + } + transaction + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala index 8663bb65b6a88..152334fdfc600 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala @@ -23,6 +23,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.catalog.{SessionCatalog, TempVariableManager} import org.apache.spark.sql.catalyst.util.StringUtils +import org.apache.spark.sql.connector.catalog.transactions.Transaction import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf @@ -39,7 +40,7 @@ import org.apache.spark.sql.internal.SQLConf // need to track current database at all. private[sql] class CatalogManager( - defaultSessionCatalog: CatalogPlugin, + val defaultSessionCatalog: CatalogPlugin, val v1SessionCatalog: SessionCatalog) extends SQLConfHelper with Logging { import CatalogManager.SESSION_CATALOG_NAME import CatalogV2Util._ @@ -57,6 +58,11 @@ class CatalogManager( } } + def transaction: Option[Transaction] = None + + def withTransaction(transaction: Transaction): CatalogManager = + new TransactionAwareCatalogManager(this, transaction) + def isCatalogRegistered(name: String): Boolean = { try { catalog(name) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala index 203cfc23452a8..dd5be45bfc5f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala @@ -19,6 +19,8 @@ package 
org.apache.spark.sql.connector.catalog import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedIdentifier, UnresolvedRelation} +import org.apache.spark.sql.catalyst.plans.logical.TransactionalWrite import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} @@ -163,4 +165,17 @@ private[sql] trait LookupCatalog extends Logging { } } } + + object TransactionalWrite { + def unapply(write: TransactionalWrite): Option[TransactionalCatalogPlugin] = { + EliminateSubqueryAliases(write.table) match { + case UnresolvedRelation(CatalogAndIdentifier(c: TransactionalCatalogPlugin, _), _, _) => + Some(c) + case UnresolvedIdentifier(CatalogAndIdentifier(c: TransactionalCatalogPlugin, _), _) => + Some(c) + case _ => + None + } + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/TransactionAwareCatalogManager.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/TransactionAwareCatalogManager.scala new file mode 100644 index 0000000000000..aaeef4c2dea76 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/TransactionAwareCatalogManager.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog + +import org.apache.spark.sql.catalyst.catalog.TempVariableManager +import org.apache.spark.sql.connector.catalog.transactions.Transaction + +/** + * A [[CatalogManager]] decorator that redirects catalog lookups to the transaction's catalog + * instance when names match, ensuring table loads during analysis are scoped to the transaction. + * All mutable state (current catalog, current namespace, loaded catalogs) is delegated to the + * wrapped [[CatalogManager]]. + */ +// TODO: Extracting a CatalogManager trait (so this class can implement it instead of extending +// CatalogManager) would eliminate the inherited mutable state that this decorator doesn't use. 
+private[sql] class TransactionAwareCatalogManager( + delegate: CatalogManager, + txn: Transaction) + extends CatalogManager(delegate.defaultSessionCatalog, delegate.v1SessionCatalog) { + + override val tempVariableManager: TempVariableManager = delegate.tempVariableManager + + override def transaction: Option[Transaction] = Some(txn) + + override def catalog(name: String): CatalogPlugin = { + val resolved = delegate.catalog(name) + if (txn.catalog.name() == resolved.name()) txn.catalog else resolved + } + + override def currentCatalog: CatalogPlugin = { + val c = delegate.currentCatalog + if (txn.catalog.name() == c.name()) txn.catalog else c + } + + override def currentNamespace: Array[String] = delegate.currentNamespace + + override def setCurrentNamespace(namespace: Array[String]): Unit = + delegate.setCurrentNamespace(namespace) + + override def setCurrentCatalog(catalogName: String): Unit = + delegate.setCurrentCatalog(catalogName) + + override def listCatalogs(pattern: Option[String]): Seq[String] = + delegate.listCatalogs(pattern) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/transactions/TransactionInfoImpl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/transactions/TransactionInfoImpl.scala new file mode 100644 index 0000000000000..4cb53da0a59e2 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/transactions/TransactionInfoImpl.scala @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog.transactions + +case class TransactionInfoImpl(id: String) extends TransactionInfo diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtilsSuite.scala new file mode 100644 index 0000000000000..d409316e667b1 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtilsSuite.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.transactions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.connector.catalog.{CatalogPlugin, TransactionalCatalogPlugin} +import org.apache.spark.sql.connector.catalog.transactions.{Transaction, TransactionInfo} +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +class TransactionUtilsSuite extends SparkFunSuite { + val testCatalogName = "test_catalog" + + // --- Helpers --------------------------------------------------------------- + private def mockCatalog(catalogName: String): CatalogPlugin = new CatalogPlugin { + override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = () + override def name(): String = catalogName + } + + private val emptyFunction = () => () + private class TestTransaction( + catalogName: String, + onCommit: () => Unit = emptyFunction, + onAbort: () => Unit = emptyFunction, + onClose: () => Unit = emptyFunction) extends Transaction { + var committed = false + var aborted = false + var closed = false + + override def catalog(): CatalogPlugin = mockCatalog(catalogName) + override def commit(): Unit = { committed = true; onCommit() } + override def abort(): Unit = { aborted = true; onAbort() } + override def close(): Unit = { closed = true; onClose() } + } + + private def mockTransactionalCatalog( + catalogName: String, + txnCatalogName: String = null): TransactionalCatalogPlugin = { + val resolvedTxnCatalogName = Option(txnCatalogName).getOrElse(catalogName) + new TransactionalCatalogPlugin { + override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = () + override def name(): String = catalogName + override def beginTransaction(info: TransactionInfo): Transaction = + new TestTransaction(resolvedTxnCatalogName) + } + } + + // --- Commit ---------------------------------------------------------------- + test("commit: calls commit then close") { + val txn = new TestTransaction(testCatalogName) + 
TransactionUtils.commit(txn) + assert(txn.committed) + assert(txn.closed) + } + + test("commit: close is called even if commit fails") { + val txn = new TestTransaction( + testCatalogName, onCommit = () => throw new RuntimeException("commit failed")) + intercept[RuntimeException] { TransactionUtils.commit(txn) } + assert(txn.closed) + } + + // --- Abort ----------------------------------------------------------------- + test("abort: calls abort then close") { + val txn = new TestTransaction(testCatalogName) + TransactionUtils.abort(txn) + assert(txn.aborted) + assert(txn.closed) + } + + test("abort: close is called even if abort fails") { + val txn = new TestTransaction(testCatalogName, + onAbort = () => throw new RuntimeException("abort failed")) + intercept[RuntimeException] { TransactionUtils.abort(txn) } + assert(txn.closed) + } + + // --- Begin Transaction ----------------------------------------------------- + test("beginTransaction: returns transaction when catalog names match") { + val catalog = mockTransactionalCatalog(testCatalogName) + val txn = TransactionUtils.beginTransaction(catalog) + assert(txn.catalog().name() == testCatalogName) + } + + test("beginTransaction: fails when transaction catalog name does not match") { + val catalog = mockTransactionalCatalog(catalogName = testCatalogName, txnCatalogName = "other") + val e = intercept[IllegalStateException] { + TransactionUtils.beginTransaction(catalog) + } + assert(e.getMessage.contains("other")) + assert(e.getMessage.contains(testCatalogName)) + } + + test("beginTransaction: aborts and closes transaction on catalog name mismatch") { + var aborted = false + var closed = false + val catalog = new TransactionalCatalogPlugin { + override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = () + override def name(): String = testCatalogName + override def beginTransaction(info: TransactionInfo): Transaction = + new TestTransaction( + "other", + onAbort = () => { aborted = true }, + 
onClose = () => { closed = true }) + } + intercept[IllegalStateException] { TransactionUtils.beginTransaction(catalog) } + assert(aborted) + assert(closed) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala index 8791677999810..af46ac006f3c9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala @@ -76,6 +76,10 @@ abstract class InMemoryBaseTable( override def columns(): Array[Column] = tableColumns + private[catalog] def updateColumns(newColumns: Array[Column]): Unit = { + tableColumns = newColumns + } + override def version(): String = tableVersion.toString def setVersion(version: String): Unit = { @@ -94,6 +98,8 @@ abstract class InMemoryBaseTable( validatedTableVersion = version } + protected def recordScanEvent(filters: Array[Filter]): Unit = {} + protected object PartitionKeyColumn extends MetadataColumn { override def name: String = "_partition" override def dataType: DataType = StringType @@ -455,6 +461,7 @@ abstract class InMemoryBaseTable( if (evaluableFilters.nonEmpty) { scan.filter(evaluableFilters) } + recordScanEvent(_pushedFilters) scan } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala index bbb9041bab37c..4e5e1e7c8c6e9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala @@ -17,11 +17,31 @@ package org.apache.spark.sql.connector.catalog +import scala.collection.mutable.ArrayBuffer + 
import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException +import org.apache.spark.sql.connector.catalog.transactions.{Transaction, TransactionInfo} +import org.apache.spark.sql.types.StructType -class InMemoryRowLevelOperationTableCatalog extends InMemoryTableCatalog { +class InMemoryRowLevelOperationTableCatalog + extends InMemoryTableCatalog + with TransactionalCatalogPlugin { import CatalogV2Implicits._ + // The current active transaction. + var transaction: Txn = _ + // The last completed transaction. + var lastTransaction: Txn = _ + // All transactions in order (committed and aborted), allowing per-statement + // validation in SQL scripting tests. + val seenTransactions: ArrayBuffer[Txn] = new ArrayBuffer[Txn]() + + override def beginTransaction(info: TransactionInfo): Transaction = { + assert(transaction == null || transaction.currentState != Active) + this.transaction = new Txn(new TxnTableCatalog(this)) + transaction + } + override def createTable(ident: Identifier, tableInfo: TableInfo): Table = { if (tables.containsKey(ident)) { throw new TableAlreadyExistsException(ident.asMultipartIdentifier) @@ -41,11 +61,7 @@ class InMemoryRowLevelOperationTableCatalog extends InMemoryTableCatalog { override def alterTable(ident: Identifier, changes: TableChange*): Table = { val table = loadTable(ident).asInstanceOf[InMemoryRowLevelOperationTable] val properties = CatalogV2Util.applyPropertiesChanges(table.properties, changes) - val schema = CatalogV2Util.applySchemaChanges( - table.schema, - changes, - tableProvider = Some("in-memory"), - statementType = "ALTER TABLE") + val schema = computeAlterTableSchema(table.schema, changes.toSeq) val partitioning = CatalogV2Util.applyClusterByChanges(table.partitioning, schema, changes) val constraints = CatalogV2Util.collectConstraintChanges(table, changes) @@ -66,6 +82,16 @@ class InMemoryRowLevelOperationTableCatalog extends InMemoryTableCatalog { newTable } + + /** + * Computes the schema that would result 
from applying `changes` to `currentSchema`. + * Can be overridden by subclasses to simulate catalogs that selectively ignore changes + * (e.g. [[PartialSchemaEvolutionCatalog]]). + */ + def computeAlterTableSchema(currentSchema: StructType, changes: Seq[TableChange]): StructType = { + CatalogV2Util.applySchemaChanges( + currentSchema, changes, tableProvider = Some("in-memory"), statementType = "ALTER TABLE") + } } /** @@ -84,9 +110,10 @@ class PartialSchemaEvolutionCatalog extends InMemoryRowLevelOperationTableCatalo case _ => false } val properties = CatalogV2Util.applyPropertiesChanges(table.properties, propertyChanges) + val schema = computeAlterTableSchema(table.schema, changes.toSeq) val newTable = new InMemoryRowLevelOperationTable( name = table.name, - schema = table.schema, + schema = schema, partitioning = table.partitioning, properties = properties, constraints = table.constraints) @@ -94,4 +121,9 @@ class PartialSchemaEvolutionCatalog extends InMemoryRowLevelOperationTableCatalo tables.put(ident, newTable) newTable } + + // Ignores all schema changes and returns the current schema unchanged. 
+ override def computeAlterTableSchema( + currentSchema: StructType, + changes: Seq[TableChange]): StructType = currentSchema } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala index d5738475031dc..15ed4136dbda8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.connector.write.{LogicalWriteInfo, SupportsOverwrite import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{LongType, StructField, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.ArrayImplicits._ /** @@ -215,6 +216,15 @@ class InMemoryTable( object InMemoryTable { + // Convert UTF8String to string to make sure equality checks between filters and partitions + // work correctly. 
+ private def valuesEqual(filterValue: Any, partitionValue: Any): Boolean = + (filterValue, partitionValue) match { + case (s: String, u: UTF8String) => u.toString == s + case (u: UTF8String, s: String) => u.toString == s + case _ => filterValue == partitionValue + } + def filtersToKeys( keys: Iterable[Seq[Any]], partitionNames: Seq[String], @@ -222,7 +232,7 @@ object InMemoryTable { keys.filter { partValues => filters.flatMap(splitAnd).forall { case EqualTo(attr, value) => - value == InMemoryBaseTable.extractValue(attr, partitionNames, partValues) + valuesEqual(value, InMemoryBaseTable.extractValue(attr, partitionNames, partValues)) case EqualNullSafe(attr, value) => val attrVal = InMemoryBaseTable.extractValue(attr, partitionNames, partValues) if (attrVal == null && value == null) { @@ -230,7 +240,7 @@ object InMemoryTable { } else if (attrVal == null || value == null) { false } else { - value == attrVal + valuesEqual(value, attrVal) } case IsNull(attr) => null == InMemoryBaseTable.extractValue(attr, partitionNames, partValues) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala new file mode 100644 index 0000000000000..232b623174996 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog + +import java.util +import java.util.concurrent.ConcurrentHashMap + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.sql.connector.catalog.transactions.Transaction +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +sealed trait TransactionState +case object Active extends TransactionState +case object Committed extends TransactionState +case object Aborted extends TransactionState + +class Txn(override val catalog: TxnTableCatalog) extends Transaction { + + private[this] var state: TransactionState = Active + private[this] var closed: Boolean = false + + def currentState: TransactionState = state + + def isClosed: Boolean = closed + + override def commit(): Unit = { + if (closed) throw new IllegalStateException("Can't commit, already closed") + if (state == Aborted) throw new IllegalStateException("Can't commit, already aborted") + catalog.commit() + this.state = Committed + } + + // This is idempotent since nested QEs can cause multiple aborts. + override def abort(): Unit = { + if (state == Committed || state == Aborted) return + this.state = Aborted + } + + // This is idempotent since nested QEs can cause multiple aborts. + override def close(): Unit = { + if (!closed) { + catalog.clearActiveTransaction() + this.closed = true + } + } +} + +// A special table used in row-level operation transactions. 
It inherits data +// from the base table upon construction and propagates staged transaction state +// back after an explicit commit. +// Note, the in-memory data store does not handle concurrency at the moment. This assumes that the +// underlying delegate table cannot change from concurrent transactions. Data sources need to +// implement isolation semantics and make sure they are enforced. +class TxnTable(val delegate: InMemoryRowLevelOperationTable, schema: StructType) + extends InMemoryRowLevelOperationTable( + delegate.name, + schema, + delegate.partitioning, + delegate.properties, + delegate.constraints) { + + alterTableWithData(delegate.data, schema) + + // Keep initial version to detect any changes during the transaction. + private val initialVersion: String = version() + + // A tracker of filters used in each scan. + val scanEvents = new ArrayBuffer[Array[Filter]]() + + // Record scan events. This is invoked when building a scan for the particular table. + override protected def recordScanEvent(filters: Array[Filter]): Unit = { + scanEvents += filters + } + + // Perform commit if there are any changes. This pushes metadata and data changes to the + // delegate table. + def commit(): Unit = { + if (version() != initialVersion) { + delegate.dataMap.clear() + delegate.updateColumns(columns()) // Evolve schema if needed. + delegate.alterTableWithData(data, schema) + delegate.replacedPartitions = replacedPartitions + delegate.lastWriteInfo = lastWriteInfo + delegate.lastWriteLog = lastWriteLog + delegate.commits ++= commits + delegate.increaseVersion() + } + } +} + +// A special table catalog used in row-level operation transactions. The lifecycle of this catalog +// is tied to the transaction. A new catalog instance is created at the beginning of a transaction +// and discarded at the end. The catalog is responsible for pinning all tables involved in the +// transaction. 
Table changes are initially staged in memory and propagated only after an explicit +// commit. +class TxnTableCatalog(delegate: InMemoryRowLevelOperationTableCatalog) extends TableCatalog { + + private val tables: util.Map[Identifier, TxnTable] = new ConcurrentHashMap[Identifier, TxnTable]() + + override def name: String = delegate.name + + override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = {} + + override def listTables(namespace: Array[String]): Array[Identifier] = { + throw new UnsupportedOperationException() + } + + // This is where the table pinning logic should occur. In this implementation, a table is loaded + // (pinned) the first time it is accessed. All subsequent accesses should return the same pinned + // table. + override def loadTable(ident: Identifier): Table = { + tables.computeIfAbsent(ident, _ => { + val table = delegate.loadTable(ident).asInstanceOf[InMemoryRowLevelOperationTable] + new TxnTable(table, table.schema()) + }) + } + + override def alterTable(ident: Identifier, changes: TableChange*): Table = { + // AlterTable may be called by ResolveSchemaEvolution when schema evolution is enabled. Thus, + // it needs to be transactional. The schema changes are only propagated to the delegate at + // commit time. + // + // We delegate schema computation to the underlying catalog so that catalogs with special + // handling (e.g. PartialSchemaEvolutionCatalog) have the same behaviour inside a + // transaction. + val txnTable = tables.get(ident) + val schema = delegate.computeAlterTableSchema(txnTable.schema, changes.toSeq) + + if (schema.fields.isEmpty) { + throw new IllegalArgumentException(s"Cannot drop all fields") + } + + // TODO: We need to pass all tracked predicates to the new TXN table. + val newTxnTable = new TxnTable(txnTable.delegate, schema) + tables.put(ident, newTxnTable) + newTxnTable + } + + // TODO: Currently not transactional. Should be revised when Atomic CTAS/RTAS is implemented. 
+ override def createTable(ident: Identifier, tableInfo: TableInfo): Table = { + delegate.createTable(ident, tableInfo) + loadTable(ident) + } + + // TODO: Currently not transactional. Should be revised when Atomic CTAS/RTAS is implemented. + override def dropTable(ident: Identifier): Boolean = { + tables.remove(ident) + delegate.dropTable(ident) + } + + override def renameTable(oldIdent: Identifier, newIdent: Identifier): Unit = { + throw new UnsupportedOperationException() + } + + // Invoke commit for all tables that participated in the transaction. If a table is read-only, + // this is a no-op. + def commit(): Unit = { + tables.values.forEach(table => table.commit()) + } + + // Clear transaction context. + def clearActiveTransaction(): Unit = { + val txn = delegate.transaction + delegate.lastTransaction = txn + delegate.seenTransactions += txn + delegate.transaction = null + } + + override def equals(obj: Any): Boolean = { + obj match { + case that: CatalogPlugin => this.name == that.name + case _ => false + } + } + + override def hashCode(): Int = name.hashCode() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index f08b561d6ef9a..e07d0147205a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -27,23 +27,25 @@ import scala.util.control.NonFatal import org.apache.hadoop.fs.Path import org.apache.spark.SparkException -import org.apache.spark.internal.Logging import org.apache.spark.internal.LogKeys.EXTENDED_EXPLAIN_GENERATOR import org.apache.spark.rdd.RDD import org.apache.spark.sql.{AnalysisException, ExtendedExplainGenerator, Row} import org.apache.spark.sql.catalyst.{InternalRow, QueryPlanningTracker} -import org.apache.spark.sql.catalyst.analysis.{LazyExpression, NameParameterizedQuery, UnsupportedOperationChecker} +import 
org.apache.spark.sql.catalyst.analysis.{Analyzer, LazyExpression, NameParameterizedQuery, UnsupportedOperationChecker} import org.apache.spark.sql.catalyst.expressions.codegen.ByteCodeStats import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.logical.{AppendData, Command, CommandResult, CompoundBody, CreateTableAsSelect, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect, ReturnAnswer, Union, WithCTE} +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, Command, CommandResult, CompoundBody, CreateTableAsSelect, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect, ReturnAnswer, Union, UnresolvedWith, WithCTE} import org.apache.spark.sql.catalyst.rules.{PlanChangeLogger, Rule} +import org.apache.spark.sql.catalyst.transactions.TransactionUtils import org.apache.spark.sql.catalyst.util.StringUtils.PlanStringConcat import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.classic.SparkSession +import org.apache.spark.sql.connector.catalog.LookupCatalog +import org.apache.spark.sql.connector.catalog.transactions.Transaction import org.apache.spark.sql.execution.SQLExecution.EXECUTION_ROOT_ID_KEY import org.apache.spark.sql.execution.adaptive.{AdaptiveExecutionContext, InsertAdaptiveSparkPlan} import org.apache.spark.sql.execution.bucketing.{CoalesceBucketsInJoin, DisableUnnecessaryBucketedScan} -import org.apache.spark.sql.execution.datasources.v2.V2TableRefreshUtil +import org.apache.spark.sql.execution.datasources.v2.{TransactionalExec, V2TableRefreshUtil} import org.apache.spark.sql.execution.dynamicpruning.PlanDynamicPruningFilters import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.execution.reuse.ReuseExchangeAndSubquery @@ -69,7 +71,8 @@ class QueryExecution( val mode: CommandExecutionMode.Value = CommandExecutionMode.ALL, val shuffleCleanupModeOpt: 
Option[ShuffleCleanupMode] = None, val refreshPhaseEnabled: Boolean = true, - val queryId: UUID = UUIDv7Generator.generate()) extends Logging { + val queryId: UUID = UUIDv7Generator.generate(), + val analyzerOpt: Option[Analyzer] = None) extends LookupCatalog { val id: Long = QueryExecution.nextExecutionId @@ -79,6 +82,8 @@ class QueryExecution( // TODO: Move the planner an optimizer into here from SessionState. protected def planner = sparkSession.sessionState.planner + protected val catalogManager = sparkSession.sessionState.catalogManager + /** * Check whether the query represented by this QueryExecution is a SQL script. * @return True if the query is a SQL script, False otherwise. @@ -90,6 +95,46 @@ class QueryExecution( logical.exists(_.expressions.exists(_.exists(_.isInstanceOf[LazyExpression]))) } + + // 1. At the pre-Analyzed plan we look for nodes that implement the TransactionalWrite trait. + // When a plan contains such a node we initiate a transaction. Note, we should never start + // a transaction for operations that are not executed, e.g. EXPLAIN. + // 2. Create an analyzer clone with a transaction aware Catalog Manager. The latter is the + // narrow waist of all catalog accesses, and it is also the transaction context carrier. + // This is then passed to all rules during analysis that need to check the catalog. Rules + // that are specifically interested in transactionality can access the transaction directly + // from the Catalog Manager. The transaction catalog, is potentially the place where connectors + // should keep state about the reads (tables+predicates) that occurred during the transaction. + // 3. The analyzer instance is passed to nested Query Execution instances. These need to respect + // the open transaction instead of creating their own. + private lazy val transactionOpt: Option[Transaction] = + // Always inherit an active transaction from the outer analyzer, regardless of mode. 
+ analyzerOpt.flatMap(_.catalogManager.transaction).orElse { + // Only begin a new transaction for outer QEs that lead to execution. + if (mode != CommandExecutionMode.SKIP) { + val catalog = logical match { + case UnresolvedWith(TransactionalWrite(c), _, _) => Some(c) + case TransactionalWrite(c) => Some(c) + case _ => None + } + catalog.map(TransactionUtils.beginTransaction) + } else { + None + } + } + + // Per-query analyzer: uses a transaction-aware CatalogManager when a transaction is active, + // so that all catalog lookups and rule applications during analysis see the correct state + // without relying on thread-local context. + private lazy val analyzer: Analyzer = analyzerOpt.getOrElse { + transactionOpt match { + case Some(txn) => + sparkSession.sessionState.analyzer.withCatalogManager(catalogManager.withTransaction(txn)) + case None => + sparkSession.sessionState.analyzer + } + } + def assertAnalyzed(): Unit = { try { analyzed @@ -102,7 +147,7 @@ class QueryExecution( } } - def assertSupported(): Unit = { + def assertSupported(): Unit = withAbortTransactionOnFailure { if (sparkSession.sessionState.conf.isUnsupportedOperationCheckEnabled) { UnsupportedOperationChecker.checkForBatch(analyzed) } @@ -141,7 +186,7 @@ class QueryExecution( try { val plan = executePhase(QueryPlanningTracker.ANALYSIS) { // We can't clone `logical` here, which will reset the `_analyzed` flag. 
- sparkSession.sessionState.analyzer.executeAndCheck(sqlScriptExecuted, tracker) + analyzer.executeAndCheck(sqlScriptExecuted, tracker) } tracker.setAnalyzed(plan) plan @@ -152,7 +197,9 @@ class QueryExecution( } } - def analyzed: LogicalPlan = lazyAnalyzed.get + def analyzed: LogicalPlan = withAbortTransactionOnFailure { + lazyAnalyzed.get + } private val lazyCommandExecuted = LazyTry { mode match { @@ -162,7 +209,9 @@ class QueryExecution( } } - def commandExecuted: LogicalPlan = lazyCommandExecuted.get + def commandExecuted: LogicalPlan = withAbortTransactionOnFailure { + lazyCommandExecuted.get + } private def commandExecutionName(command: Command): String = command match { case _: CreateTableAsSelect => "create" @@ -184,7 +233,7 @@ class QueryExecution( // for eagerly executed commands we mark this place as beginning of execution. tracker.setReadyForExecution() val (qe, result) = QueryExecution.runCommand( - sparkSession, p, name, refreshPhaseEnabled, mode, Some(shuffleCleanupMode)) + sparkSession, p, name, refreshPhaseEnabled, mode, Some(shuffleCleanupMode), Some(analyzer)) CommandResult( qe.analyzed.output, qe.commandExecuted, @@ -222,19 +271,31 @@ class QueryExecution( } // The plan that has been normalized by custom rules, so that it's more likely to hit cache. - def normalized: LogicalPlan = lazyNormalized.get + def normalized: LogicalPlan = withAbortTransactionOnFailure { + lazyNormalized.get + } private val lazyWithCachedData = LazyTry { sparkSession.withActive { assertAnalyzed() assertSupported() - // clone the plan to avoid sharing the plan instance between different stages like analyzing, - // optimizing and planning. - sparkSession.sharedState.cacheManager.useCachedData(normalized.clone()) + + // During a transaction, skip cache substitution. This is to avoid replacing relations + // loaded by the transactional catalog with potentially stale relations cached before + // the transaction was active. 
+ if (transactionOpt.isDefined) { + normalized + } else { + // Clone the plan to avoid sharing the plan instance between different stages like + // analyzing, optimizing and planning. + sparkSession.sharedState.cacheManager.useCachedData(normalized.clone()) + } } } - def withCachedData: LogicalPlan = lazyWithCachedData.get + def withCachedData: LogicalPlan = withAbortTransactionOnFailure { + lazyWithCachedData.get + } def assertCommandExecuted(): Unit = commandExecuted @@ -256,7 +317,9 @@ class QueryExecution( } } - def optimizedPlan: LogicalPlan = lazyOptimizedPlan.get + def optimizedPlan: LogicalPlan = withAbortTransactionOnFailure { + lazyOptimizedPlan.get + } def assertOptimized(): Unit = optimizedPlan @@ -264,14 +327,17 @@ class QueryExecution( // We need to materialize the optimizedPlan here because sparkPlan is also tracked under // the planning phase assertOptimized() - executePhase(QueryPlanningTracker.PLANNING) { + val plan = executePhase(QueryPlanningTracker.PLANNING) { // Clone the logical plan here, in case the planner rules change the states of the logical // plan. QueryExecution.createSparkPlan(planner, optimizedPlan.clone()) } + attachTransaction(plan) } - def sparkPlan: SparkPlan = lazySparkPlan.get + def sparkPlan: SparkPlan = withAbortTransactionOnFailure { + lazySparkPlan.get + } def assertSparkPlanPrepared(): Unit = sparkPlan @@ -292,7 +358,9 @@ class QueryExecution( // executedPlan should not be used to initialize any SparkPlan. It should be // only used for execution. - def executedPlan: SparkPlan = lazyExecutedPlan.get + def executedPlan: SparkPlan = withAbortTransactionOnFailure { + lazyExecutedPlan.get + } def assertExecutedPlanPrepared(): Unit = executedPlan @@ -310,7 +378,9 @@ class QueryExecution( * Given QueryExecution is not a public class, end users are discouraged to use this: please * use `Dataset.rdd` instead where conversion will be applied. 
*/ - def toRdd: RDD[InternalRow] = lazyToRdd.get + def toRdd: RDD[InternalRow] = withAbortTransactionOnFailure { + lazyToRdd.get + } private val observedMetricsLock = new Object @@ -512,6 +582,26 @@ class QueryExecution( } } + /** + * Runs the given block, aborting the active transaction if an exception is thrown. + * If no transaction is active, the block is executed as-is. + */ + private def withAbortTransactionOnFailure[T](block: => T): T = transactionOpt match { + case Some(transaction) => + try block + catch { case e: Throwable => TransactionUtils.abort(transaction); throw e } + case None => block + } + + + /** Attaches a transaction to the given SparkPlan to the transactional execution nodes. */ + private def attachTransaction(plan: SparkPlan): SparkPlan = transactionOpt match { + case Some(txn) => plan.transformDown { + case w: TransactionalExec => w.withTransaction(Some(txn)) + } + case None => plan + } + /** A special namespace for commands that can be used to debug query execution. 
*/ // scalastyle:off object debug { @@ -796,14 +886,16 @@ object QueryExecution { name: String, refreshPhaseEnabled: Boolean = true, mode: CommandExecutionMode.Value = CommandExecutionMode.SKIP, - shuffleCleanupModeOpt: Option[ShuffleCleanupMode] = None) + shuffleCleanupModeOpt: Option[ShuffleCleanupMode] = None, + analyzerOpt: Option[Analyzer] = None) : (QueryExecution, Array[InternalRow]) = { val qe = new QueryExecution( sparkSession, command, mode = mode, shuffleCleanupModeOpt = shuffleCleanupModeOpt, - refreshPhaseEnabled = refreshPhaseEnabled) + refreshPhaseEnabled = refreshPhaseEnabled, + analyzerOpt = analyzerOpt) val result = QueryExecution.withInternalError(s"Executed $name failed.") { SQLExecution.withNewExecutionId(qe, Some(name)) { qe.executedPlan.executeCollect() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DeleteFromTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DeleteFromTableExec.scala index 8d5ee6038e80f..c6b1bae89b156 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DeleteFromTableExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DeleteFromTableExec.scala @@ -19,16 +19,23 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.transactions.TransactionUtils import org.apache.spark.sql.connector.catalog.SupportsDeleteV2 +import org.apache.spark.sql.connector.catalog.transactions.Transaction import org.apache.spark.sql.connector.expressions.filter.Predicate case class DeleteFromTableExec( table: SupportsDeleteV2, condition: Array[Predicate], - refreshCache: () => Unit) extends LeafV2CommandExec { + refreshCache: () => Unit, + transaction: Option[Transaction] = None) extends LeafV2CommandExec with TransactionalExec { + + override def withTransaction(txn: 
Option[Transaction]): DeleteFromTableExec = + copy(transaction = txn) override protected def run(): Seq[InternalRow] = { table.deleteWhere(condition) + transaction.foreach(TransactionUtils.commit) refreshCache() Seq.empty } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 804e694f92e9c..07876887788be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -26,9 +26,11 @@ import org.apache.spark.sql.catalyst.{InternalRow, ProjectingInternalRow} import org.apache.spark.sql.catalyst.analysis.NoSuchTableException import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal} import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, TableSpec, UnaryNode} +import org.apache.spark.sql.catalyst.transactions.TransactionUtils import org.apache.spark.sql.catalyst.util.{removeInternalMetadata, CharVarcharUtils, ReplaceDataProjections, WriteDeltaProjections} import org.apache.spark.sql.catalyst.util.RowDeltaUtils.{DELETE_OPERATION, INSERT_OPERATION, REINSERT_OPERATION, UPDATE_OPERATION, WRITE_OPERATION, WRITE_WITH_METADATA_OPERATION} import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, StagedTable, StagingTableCatalog, Table, TableCatalog, TableInfo, TableWritePrivilege} +import org.apache.spark.sql.connector.catalog.transactions.Transaction import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.metric.CustomMetric import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, DeltaWrite, DeltaWriter, MergeSummaryImpl, PhysicalWriteInfoImpl, Write, WriterCommitMessage, WriteSummary} @@ -73,7 +75,12 @@ 
case class CreateTableAsSelectExec( query: LogicalPlan, tableSpec: TableSpec, writeOptions: Map[String, String], - ifNotExists: Boolean) extends V2CreateTableAsSelectBaseExec { + ifNotExists: Boolean, + transaction: Option[Transaction] = None) + extends V2CreateTableAsSelectBaseExec with TransactionalExec { + + override def withTransaction(txn: Option[Transaction]): CreateTableAsSelectExec = + copy(transaction = txn) val properties = CatalogV2Util.convertTableProperties(tableSpec) @@ -91,7 +98,9 @@ case class CreateTableAsSelectExec( .build() val table = Option(catalog.createTable(ident, tableInfo)) .getOrElse(catalog.loadTable(ident, Set(TableWritePrivilege.INSERT).asJava)) - writeToTable(catalog, table, writeOptions, ident, query, overwrite = false) + val result = writeToTable(catalog, table, writeOptions, ident, query, overwrite = false) + transaction.foreach(TransactionUtils.commit) + result } } @@ -111,7 +120,8 @@ case class AtomicCreateTableAsSelectExec( query: LogicalPlan, tableSpec: TableSpec, writeOptions: Map[String, String], - ifNotExists: Boolean) extends V2CreateTableAsSelectBaseExec { + ifNotExists: Boolean) + extends V2CreateTableAsSelectBaseExec { val properties = CatalogV2Util.convertTableProperties(tableSpec) @@ -154,8 +164,12 @@ case class ReplaceTableAsSelectExec( tableSpec: TableSpec, writeOptions: Map[String, String], orCreate: Boolean, - invalidateCache: (TableCatalog, Identifier) => Unit) - extends V2CreateTableAsSelectBaseExec { + invalidateCache: (TableCatalog, Identifier) => Unit, + transaction: Option[Transaction] = None) + extends V2CreateTableAsSelectBaseExec with TransactionalExec { + + override def withTransaction(txn: Option[Transaction]): ReplaceTableAsSelectExec = + copy(transaction = txn) val properties = CatalogV2Util.convertTableProperties(tableSpec) @@ -191,9 +205,11 @@ case class ReplaceTableAsSelectExec( .build() val table = Option(catalog.createTable(ident, tableInfo)) .getOrElse(catalog.loadTable(ident, 
Set(TableWritePrivilege.INSERT).asJava)) - writeToTable( + val result = writeToTable( catalog, table, writeOptions, ident, refreshedQuery, overwrite = true, refreshPhaseEnabled = false) + transaction.foreach(TransactionUtils.commit) + result } } @@ -271,7 +287,9 @@ case class AtomicReplaceTableAsSelectExec( case class AppendDataExec( query: SparkPlan, refreshCache: () => Unit, - write: Write) extends V2ExistingTableWriteExec { + write: Write, + transaction: Option[Transaction] = None) extends V2ExistingTableWriteExec { + override def withTransaction(txn: Option[Transaction]): AppendDataExec = copy(transaction = txn) override protected def withNewChildInternal(newChild: SparkPlan): AppendDataExec = copy(query = newChild) } @@ -289,7 +307,10 @@ case class AppendDataExec( case class OverwriteByExpressionExec( query: SparkPlan, refreshCache: () => Unit, - write: Write) extends V2ExistingTableWriteExec { + write: Write, + transaction: Option[Transaction] = None) extends V2ExistingTableWriteExec { + override def withTransaction(txn: Option[Transaction]): OverwriteByExpressionExec = + copy(transaction = txn) override protected def withNewChildInternal(newChild: SparkPlan): OverwriteByExpressionExec = copy(query = newChild) } @@ -306,7 +327,10 @@ case class OverwriteByExpressionExec( case class OverwritePartitionsDynamicExec( query: SparkPlan, refreshCache: () => Unit, - write: Write) extends V2ExistingTableWriteExec { + write: Write, + transaction: Option[Transaction] = None) extends V2ExistingTableWriteExec { + override def withTransaction(txn: Option[Transaction]): OverwritePartitionsDynamicExec = + copy(transaction = txn) override protected def withNewChildInternal(newChild: SparkPlan): OverwritePartitionsDynamicExec = copy(query = newChild) } @@ -318,7 +342,8 @@ case class ReplaceDataExec( query: SparkPlan, refreshCache: () => Unit, projections: ReplaceDataProjections, - write: Write) extends V2ExistingTableWriteExec { + write: Write, + transaction: 
Option[Transaction] = None) extends V2ExistingTableWriteExec { override def writingTask: WritingSparkTask[_] = { projections match { @@ -329,6 +354,7 @@ case class ReplaceDataExec( } } + override def withTransaction(txn: Option[Transaction]): ReplaceDataExec = copy(transaction = txn) override protected def withNewChildInternal(newChild: SparkPlan): ReplaceDataExec = { copy(query = newChild) } @@ -341,7 +367,8 @@ case class WriteDeltaExec( query: SparkPlan, refreshCache: () => Unit, projections: WriteDeltaProjections, - write: DeltaWrite) extends V2ExistingTableWriteExec { + write: DeltaWrite, + transaction: Option[Transaction] = None) extends V2ExistingTableWriteExec { override lazy val writingTask: WritingSparkTask[_] = { if (projections.metadataProjection.isDefined) { @@ -351,6 +378,7 @@ case class WriteDeltaExec( } } + override def withTransaction(txn: Option[Transaction]): WriteDeltaExec = copy(transaction = txn) override protected def withNewChildInternal(newChild: SparkPlan): WriteDeltaExec = { copy(query = newChild) } @@ -378,7 +406,16 @@ case class WriteToDataSourceV2Exec( copy(query = newChild) } -trait V2ExistingTableWriteExec extends V2TableWriteExec { +/** + * Trait for physical plan nodes that write to an existing table as part of a transaction. + * The [[transaction]] is injected post-planning by [[QueryExecution]]. 
+ */ +trait TransactionalExec extends SparkPlan { + def transaction: Option[Transaction] + def withTransaction(txn: Option[Transaction]): SparkPlan +} + +trait V2ExistingTableWriteExec extends V2TableWriteExec with TransactionalExec { def refreshCache: () => Unit def write: Write @@ -395,6 +432,7 @@ trait V2ExistingTableWriteExec extends V2TableWriteExec { } finally { postDriverMetrics() } + transaction.foreach(TransactionUtils.commit) refreshCache() writtenRows } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala new file mode 100644 index 0000000000000..aef9c65550fc4 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala @@ -0,0 +1,500 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connector + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.Row +import org.apache.spark.sql.connector.catalog.{Aborted, Committed} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode +import org.apache.spark.sql.sources + +class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { + + test("writeTo append with transactional checks") { + // create table with initial data + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + // create a source on top of itself that will be fully resolved and analyzed + val sourceDF = spark.table(tableNameAsString) + .where("pk == 1") + .select(col("pk") + 10 as "pk", col("salary"), col("dep")) + sourceDF.queryExecution.assertAnalyzed() + + // append data using the DataFrame API + val (txn, txnTables) = executeTransaction { + sourceDF.writeTo(tableNameAsString).append() + } + + // check txn was properly committed and closed + assert(txn.currentState === Committed) + assert(txn.isClosed) + assert(txnTables.size === 1) + assert(table.version() === "2") + + // check the source scan was tracked via the transaction catalog + val targetTxnTable = txnTables(tableNameAsString) + assert(targetTxnTable.scanEvents.size === 1) + assert(targetTxnTable.scanEvents.flatten.exists { + case sources.EqualTo("pk", 1) => true + case _ => false + }) + + // check data was appended correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr"), + Row(2, 200, "software"), + Row(11, 100, "hr"))) // appended + } + + test("SQL INSERT INTO with transactional checks") { + // create table with initial data + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, 
"salary": 200, "dep": "software" } + |""".stripMargin) + + // SQL INSERT INTO using VALUES + val (txn, txnTables) = executeTransaction { + sql(s"INSERT INTO $tableNameAsString VALUES (3, 300, 'hr'), (4, 400, 'finance')") + } + + // check txn was properly committed and closed + assert(txn.currentState === Committed) + assert(txn.isClosed) + assert(table.version() === "2") + + // VALUES literal - No catalog tables were scanned + assert(txnTables.isEmpty) + + // check data was inserted correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr"), + Row(2, 200, "software"), + Row(3, 300, "hr"), + Row(4, 400, "finance"))) + } + + for (isDynamic <- Seq(false, true)) + test(s"SQL INSERT OVERWRITE with transactional checks - isDynamic: $isDynamic") { + // create table with initial data; table is partitioned by dep + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + val insertOverwrite = if (isDynamic) { + // OverwritePartitionsDynamic + s"""INSERT OVERWRITE $tableNameAsString + |SELECT pk + 10, salary, dep FROM $tableNameAsString WHERE dep = 'hr' + |""".stripMargin + } else { + // OverwriteByExpression + s"""INSERT OVERWRITE $tableNameAsString + |PARTITION (dep = 'hr') + |SELECT pk + 10, salary FROM $tableNameAsString WHERE dep = 'hr' + |""".stripMargin + } + + val confValue = if (isDynamic) PartitionOverwriteMode.DYNAMIC else PartitionOverwriteMode.STATIC + val (txn, txnTables) = withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> confValue.toString) { + executeTransaction { sql(insertOverwrite) } + } + + assert(txn.currentState === Committed) + assert(txn.isClosed) + assert(table.version() === "2") + + // the SELECT reads from the target table once with a dep='hr' filter + assert(txnTables.size == 1) + val targetTxnTable = txnTables(tableNameAsString) + 
assert(targetTxnTable.scanEvents.size == 1) + assert(targetTxnTable.scanEvents.flatten.exists { + case sources.EqualTo("dep", "hr") => true + case _ => false + }) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(2, 200, "software"), // unchanged + Row(11, 100, "hr"), // overwritten + Row(13, 300, "hr"))) // overwritten + } + + test("writeTo overwrite with transactional checks") { + // create table with initial data; table is partitioned by dep + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // overwrite using a condition that covers the hr partition -> OverwriteByExpression + val sourceDF = spark.createDataFrame(Seq((11, 999, "hr"), (12, 888, "hr"))). + toDF("pk", "salary", "dep") + + val (txn, txnTables) = executeTransaction { + sourceDF.writeTo(tableNameAsString).overwrite(col("dep") === "hr") + } + + assert(txn.currentState === Committed) + assert(txn.isClosed) + assert(table.version() === "2") + + // literal DataFrame source - no catalog tables were scanned + assert(txnTables.isEmpty) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(2, 200, "software"), // unchanged + Row(11, 999, "hr"), // overwrote hr partition + Row(12, 888, "hr"))) // overwrote hr partition + } + + test("writeTo overwritePartitions with transactional checks") { + // create table with initial data; table is partitioned by dep + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // overwrite partitions dynamically -> OverwritePartitionsDynamic + val sourceDF = spark.createDataFrame(Seq((11, 999, "hr"), (12, 888, "hr"))). 
+ toDF("pk", "salary", "dep") + + val (txn, txnTables) = executeTransaction { + sourceDF.writeTo(tableNameAsString).overwritePartitions() + } + + assert(txn.currentState === Committed) + assert(txn.isClosed) + assert(table.version() === "2") + + // literal DataFrame source - no catalog tables were scanned + assert(txnTables.isEmpty) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(2, 200, "software"), // unchanged + Row(11, 999, "hr"), // overwrote hr partition + Row(12, 888, "hr"))) // overwrote hr partition + } + + test("SQL INSERT INTO SELECT with transactional checks") { + // create table with initial data + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // SQL INSERT INTO using SELECT from the same table (self-insert) + val (txn, txnTables) = executeTransaction { + sql(s"""INSERT INTO $tableNameAsString + |SELECT pk + 10, salary, dep FROM $tableNameAsString WHERE dep = 'hr' + |""".stripMargin) + } + + // check txn was properly committed and closed + assert(txn.currentState === Committed) + assert(txn.isClosed) + assert(table.version() === "2") + + // the SELECT reads from the target table once with a dep='hr' filter + assert(txnTables.size === 1) + val targetTxnTable = txnTables(tableNameAsString) + assert(targetTxnTable.scanEvents.size === 1) + assert(targetTxnTable.scanEvents.flatten.exists { + case sources.EqualTo("dep", "hr") => true + case _ => false + }) + + // check data was inserted correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr"), + Row(2, 200, "software"), + Row(3, 300, "hr"), + Row(11, 100, "hr"), // inserted from pk=1 + Row(13, 300, "hr"))) // inserted from pk=3 + } + + test("SQL INSERT INTO SELECT with subquery on source table and transactional checks") { + createAndInitTable("pk INT NOT NULL, salary 
INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + sql(s"INSERT INTO $sourceNameAsString VALUES (1, 500, 'hr'), (3, 600, 'software')") + + // INSERT using a subquery that reads from the target to filter source rows + // both tables are scanned through the transaction catalog + val (txn, txnTables) = executeTransaction { + sql( + s"""INSERT INTO $tableNameAsString + |SELECT pk + 10, salary, dep FROM $sourceNameAsString + |WHERE pk IN (SELECT pk FROM $tableNameAsString WHERE dep = 'hr') + |""".stripMargin) + } + + assert(txn.currentState === Committed) + assert(txn.isClosed) + assert(txnTables.size === 2) + assert(table.version() === "2") + + // target was scanned via the transaction catalog (IN subquery) once with dep='hr' filter + val targetTxnTable = txnTables(tableNameAsString) + assert(targetTxnTable.scanEvents.size === 1) + assert(targetTxnTable.scanEvents.flatten.exists { + case sources.EqualTo("dep", "hr") => true + case _ => false + }) + + // source was scanned via the transaction catalog exactly once (no filter) + val sourceTxnTable = txnTables(sourceNameAsString) + assert(sourceTxnTable.scanEvents.size === 1) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr"), + Row(2, 200, "software"), + Row(11, 500, "hr"))) // inserted: source pk=1 matched target hr row + } + + test("SQL INSERT INTO SELECT with CTE and transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + sql(s"INSERT INTO $sourceNameAsString VALUES (1, 500, 'hr'), (3, 600, 'software')") + + // CTE reads from target; INSERT selects from source filtered by the CTE 
result + // both tables are scanned through the transaction catalog + val (txn, txnTables) = executeTransaction { + sql( + s"""WITH hr_pks AS (SELECT pk FROM $tableNameAsString WHERE dep = 'hr') + |INSERT INTO $tableNameAsString + |SELECT pk + 10, salary, dep FROM $sourceNameAsString + |WHERE pk IN (SELECT pk FROM hr_pks) + |""".stripMargin) + } + + assert(txn.currentState === Committed) + assert(txn.isClosed) + assert(txnTables.size === 2) + assert(table.version() === "2") + + // target was scanned via the transaction catalog (CTE) once with dep='hr' filter + val targetTxnTable = txnTables(tableNameAsString) + assert(targetTxnTable.scanEvents.size === 1) + assert(targetTxnTable.scanEvents.flatten.exists { + case sources.EqualTo("dep", "hr") => true + case _ => false + }) + + // source was scanned via the transaction catalog exactly once (no filter) + val sourceTxnTable = txnTables(sourceNameAsString) + assert(sourceTxnTable.scanEvents.size === 1) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr"), + Row(2, 200, "software"), + Row(11, 500, "hr"))) // inserted: source pk=1 matched target hr row via CTE + } + + test("SQL INSERT with analysis failure and transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val e = intercept[AnalysisException] { + sql(s"INSERT INTO $tableNameAsString SELECT nonexistent_col FROM $tableNameAsString") + } + + assert(e.getMessage.contains("nonexistent_col")) + assert(catalog.lastTransaction.currentState === Aborted) + assert(catalog.lastTransaction.isClosed) + } + + for (isDynamic <- Seq(false, true)) + test(s"SQL INSERT OVERWRITE with analysis failure and transactional checks" + + s"isDynamic: $isDynamic") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": 
"software" } + |""".stripMargin) + + val insertOverwrite = if (isDynamic) { + s"""INSERT OVERWRITE $tableNameAsString + |SELECT nonexistent_col, salary, dep FROM $tableNameAsString WHERE dep = 'hr' + |""".stripMargin + } else { + s"""INSERT OVERWRITE $tableNameAsString + |PARTITION (dep = 'hr') + |SELECT nonexistent_col FROM $tableNameAsString WHERE dep = 'hr' + |""".stripMargin + } + + val confValue = if (isDynamic) PartitionOverwriteMode.DYNAMIC else PartitionOverwriteMode.STATIC + val e = withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> confValue.toString) { + intercept[AnalysisException] { sql(insertOverwrite) } + } + + assert(e.getMessage.contains("nonexistent_col")) + assert(catalog.lastTransaction.currentState === Aborted) + assert(catalog.lastTransaction.isClosed) + } + + test("EXPLAIN INSERT SQL with transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + sql(s"EXPLAIN INSERT INTO $tableNameAsString VALUES (3, 300, 'hr')") + + // EXPLAIN should not start a transaction + assert(catalog.transaction === null) + + // INSERT was not executed; data is unchanged + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr"), + Row(2, 200, "software"))) + } + + test("SQL INSERT WITH SCHEMA EVOLUTION adds new column with transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + sql( + s"""CREATE TABLE $sourceNameAsString + |(pk INT NOT NULL, salary INT, dep STRING, active BOOLEAN)""".stripMargin) + sql(s"INSERT INTO $sourceNameAsString VALUES (3, 300, 'hr', true), (4, 400, 'software', false)") + + val (txn, txnTables) = executeTransaction { + sql(s"INSERT WITH SCHEMA EVOLUTION INTO $tableNameAsString SELECT * FROM $sourceNameAsString") + } + + 
assert(txn.currentState === Committed) + assert(txn.isClosed) + + // the new column must be visible in the committed delegate's schema + assert(table.schema.fieldNames.toSeq === Seq("pk", "salary", "dep", "active")) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr", null), // pre-existing rows: active is null + Row(2, 200, "software", null), + Row(3, 300, "hr", true), // inserted with active + Row(4, 400, "software", false))) + } + + for (isDynamic <- Seq(false, true)) + test(s"SQL INSERT OVERWRITE WITH SCHEMA EVOLUTION adds new column with transactional checks " + + s"isDynamic: $isDynamic") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + sql( + s"""CREATE TABLE $sourceNameAsString + |(pk INT NOT NULL, salary INT, dep STRING, active BOOLEAN)""".stripMargin) + sql(s"INSERT INTO $sourceNameAsString VALUES (11, 999, 'hr', true), (12, 888, 'hr', false)") + + val insertOverwrite = if (isDynamic) { + s"""INSERT WITH SCHEMA EVOLUTION OVERWRITE TABLE $tableNameAsString + |SELECT * FROM $sourceNameAsString + |""".stripMargin + } else { + s"""INSERT WITH SCHEMA EVOLUTION OVERWRITE TABLE $tableNameAsString + |PARTITION (dep = 'hr') + |SELECT pk, salary, active FROM $sourceNameAsString + |""".stripMargin + } + + val confValue = if (isDynamic) PartitionOverwriteMode.DYNAMIC else PartitionOverwriteMode.STATIC + val (txn, _) = withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> confValue.toString) { + executeTransaction { sql(insertOverwrite) } + } + + assert(txn.currentState === Committed) + assert(txn.isClosed) + + // the new column must be visible in the committed delegate's schema + assert(table.schema.fieldNames.contains("active")) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(2, 200, "software", null), // unchanged (different 
partition) + Row(11, 999, "hr", true), // overwrote hr partition + Row(12, 888, "hr", false))) + } + + test("SQL INSERT WITH SCHEMA EVOLUTION analysis failure aborts transaction") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + sql( + s"""CREATE TABLE $sourceNameAsString + |(pk INT NOT NULL, salary INT, dep STRING, active BOOLEAN)""".stripMargin) + + val e = intercept[AnalysisException] { + sql( + s"""INSERT WITH SCHEMA EVOLUTION INTO $tableNameAsString + |SELECT nonexistent_col FROM $sourceNameAsString + |""".stripMargin) + } + + assert(e.getMessage.contains("nonexistent_col")) + assert(catalog.lastTransaction.currentState === Aborted) + assert(catalog.lastTransaction.isClosed) + // schema must be unchanged after the aborted transaction + assert(table.schema.fieldNames.toSeq === Seq("pk", "salary", "dep")) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/CTASRTASTransactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/CTASRTASTransactionSuite.scala new file mode 100644 index 0000000000000..8acdd8242ef1f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/CTASRTASTransactionSuite.scala @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.Row +import org.apache.spark.sql.connector.catalog.{Aborted, Committed, Identifier, InMemoryRowLevelOperationTable} +import org.apache.spark.sql.sources + +class CTASRTASTransactionSuite extends RowLevelOperationSuiteBase { + + private val newTableNameAsString = "cat.ns1.new_table" + + private def newTable: InMemoryRowLevelOperationTable = + catalog.loadTable(Identifier.of(Array("ns1"), "new_table")) + .asInstanceOf[InMemoryRowLevelOperationTable] + + test("CTAS with transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val (txn, txnTables) = executeTransaction { + sql(s"""CREATE TABLE $newTableNameAsString + |AS SELECT * FROM $tableNameAsString WHERE dep = 'hr' + |""".stripMargin) + } + + assert(txn.currentState === Committed) + assert(txn.isClosed) + assert(txnTables.size === 1) + assert(table.version() === "1") // source table: read-only, version unchanged + assert(newTable.version() === "1") // target table: newly created and written + + // the source table was scanned once through the transaction catalog with a dep='hr' filter + val sourceTxnTable = txnTables(tableNameAsString) + assert(sourceTxnTable.scanEvents.size === 1) + assert(sourceTxnTable.scanEvents.flatten.exists { + case sources.EqualTo("dep", "hr") => true + case _ => false + }) + + checkAnswer( + sql(s"SELECT * 
FROM $newTableNameAsString"), + Seq(Row(1, 100, "hr"))) + } + + test("CTAS from literal source with transactional checks") { + // no source catalog table involved — the query is a pure literal SELECT + val (txn, txnTables) = executeTransaction { + sql(s"CREATE TABLE $newTableNameAsString AS SELECT 1 AS pk, 100 AS salary, 'hr' AS dep") + } + + assert(txn.currentState === Committed) + assert(txn.isClosed) + + // literal SELECT - no catalog tables were scanned + assert(txnTables.isEmpty) + assert(newTable.version() === "1") // target table: newly created and written + + checkAnswer( + sql(s"SELECT * FROM $newTableNameAsString"), + Seq(Row(1, 100, "hr"))) + } + + test("CTAS with analysis failure and transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val e = intercept[AnalysisException] { + sql(s"CREATE TABLE $newTableNameAsString AS SELECT nonexistent_col FROM $tableNameAsString") + } + + assert(e.getMessage.contains("nonexistent_col")) + assert(catalog.lastTransaction.currentState === Aborted) + assert(catalog.lastTransaction.isClosed) + } + + test("RTAS with transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // pre-create the target so REPLACE TABLE (not CREATE OR REPLACE) is valid + sql(s"CREATE TABLE $newTableNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + + val (txn, txnTables) = executeTransaction { + sql(s"""REPLACE TABLE $newTableNameAsString + |AS SELECT * FROM $tableNameAsString WHERE dep = 'hr' + |""".stripMargin) + } + + assert(txn.currentState === Committed) + assert(txn.isClosed) + assert(txnTables.size === 1) + assert(table.version() === "1") // source table: read-only, version unchanged + 
assert(newTable.version() === "1") // target table: replaced and written + + // the source table was scanned once through the transaction catalog with a dep='hr' filter + val sourceTxnTable = txnTables(tableNameAsString) + assert(sourceTxnTable.scanEvents.size === 1) + assert(sourceTxnTable.scanEvents.flatten.exists { + case sources.EqualTo("dep", "hr") => true + case _ => false + }) + + checkAnswer( + sql(s"SELECT * FROM $newTableNameAsString"), + Seq( + Row(1, 100, "hr"), + Row(3, 300, "hr"))) + } + + test("RTAS self-reference with transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // source and target are the same table: reads the old snapshot via TxnTable, + // replaces the table with a filtered version + val (txn, txnTables) = executeTransaction { + sql(s"""CREATE OR REPLACE TABLE $tableNameAsString + |AS SELECT * FROM $tableNameAsString WHERE dep = 'hr' + |""".stripMargin) + } + + assert(txn.currentState === Committed) + assert(txn.isClosed) + assert(txnTables.size === 1) + assert(table.version() === "1") // source/target table: replaced in place, version reset to 1 + + // the source/target table was scanned once with a dep='hr' filter + val sourceTxnTable = txnTables(tableNameAsString) + assert(sourceTxnTable.scanEvents.size === 1) + assert(sourceTxnTable.scanEvents.flatten.exists { + case sources.EqualTo("dep", "hr") => true + case _ => false + }) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr"), + Row(3, 300, "hr"))) + } + + test("RTAS with analysis failure and transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val e = intercept[AnalysisException] { + sql(s"""CREATE OR 
REPLACE TABLE $tableNameAsString + |AS SELECT nonexistent_col FROM $tableNameAsString + |""".stripMargin) + } + + assert(e.getMessage.contains("nonexistent_col")) + assert(catalog.lastTransaction.currentState === Aborted) + assert(catalog.lastTransaction.isClosed) + } + + test("simple CREATE TABLE and DROP TABLE do not create transactions") { + sql(s"CREATE TABLE $newTableNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + assert(catalog.transaction === null) + assert(catalog.lastTransaction === null) + + sql(s"DROP TABLE $newTableNameAsString") + assert(catalog.transaction === null) + assert(catalog.lastTransaction === null) + } + + test("EXPLAIN CTAS with transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + sql(s"""EXPLAIN CREATE TABLE $newTableNameAsString + |AS SELECT * FROM $tableNameAsString WHERE dep = 'hr' + |""".stripMargin) + + // EXPLAIN should not start a transaction + assert(catalog.transaction === null) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2OptionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2OptionSuite.scala index 30890200df79d..fbcfdfb20c6ec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2OptionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2OptionSuite.scala @@ -109,7 +109,7 @@ class DataSourceV2OptionSuite extends DatasourceV2SQLBase { collected = df.queryExecution.executedPlan.collect { case CommandResultExec( - _, AppendDataExec(_, _, write), + _, AppendDataExec(_, _, write, _), _) => val append = write.toBatch.asInstanceOf[InMemoryBaseTable#Append] assert(append.info.options.get("write.split-size") === "10") @@ -141,7 +141,7 @@ class DataSourceV2OptionSuite extends DatasourceV2SQLBase { assert (collected.size == 1) collected = 
qe.executedPlan.collect { - case AppendDataExec(_, _, write) => + case AppendDataExec(_, _, write, _) => val append = write.toBatch.asInstanceOf[InMemoryBaseTable#Append] assert(append.info.options.get("write.split-size") === "10") } @@ -168,7 +168,7 @@ class DataSourceV2OptionSuite extends DatasourceV2SQLBase { assert (collected.size == 1) collected = qe.executedPlan.collect { - case AppendDataExec(_, _, write) => + case AppendDataExec(_, _, write, _) => val append = write.toBatch.asInstanceOf[InMemoryBaseTable#Append] assert(append.info.options.get("write.split-size") === "10") } @@ -194,7 +194,7 @@ class DataSourceV2OptionSuite extends DatasourceV2SQLBase { collected = df.queryExecution.executedPlan.collect { case CommandResultExec( - _, OverwriteByExpressionExec(_, _, write), + _, OverwriteByExpressionExec(_, _, write, _), _) => val append = write.toBatch.asInstanceOf[InMemoryBaseTable#TruncateAndAppend] assert(append.info.options.get("write.split-size") === "10") @@ -227,7 +227,7 @@ class DataSourceV2OptionSuite extends DatasourceV2SQLBase { assert (collected.size == 1) collected = qe.executedPlan.collect { - case OverwritePartitionsDynamicExec(_, _, write) => + case OverwritePartitionsDynamicExec(_, _, write, _) => val dynOverwrite = write.toBatch.asInstanceOf[InMemoryBaseTable#DynamicOverwrite] assert(dynOverwrite.info.options.get("write.split-size") === "10") } @@ -254,7 +254,7 @@ class DataSourceV2OptionSuite extends DatasourceV2SQLBase { collected = df.queryExecution.executedPlan.collect { case CommandResultExec( - _, OverwriteByExpressionExec(_, _, write), + _, OverwriteByExpressionExec(_, _, write, _), _) => val append = write.toBatch.asInstanceOf[InMemoryBaseTable#TruncateAndAppend] assert(append.info.options.get("write.split-size") === "10") @@ -287,7 +287,7 @@ class DataSourceV2OptionSuite extends DatasourceV2SQLBase { assert (collected.size == 1) collected = qe.executedPlan.collect { - case OverwriteByExpressionExec(_, _, write) => + case 
OverwriteByExpressionExec(_, _, write, _) => val append = write.toBatch.asInstanceOf[InMemoryBaseTable#TruncateAndAppend] assert(append.info.options.get("write.split-size") === "10") } @@ -317,7 +317,7 @@ class DataSourceV2OptionSuite extends DatasourceV2SQLBase { assert (collected.size == 1) collected = qe.executedPlan.collect { - case OverwriteByExpressionExec(_, _, write) => + case OverwriteByExpressionExec(_, _, write, _) => val append = write.toBatch.asInstanceOf[InMemoryBaseTable#TruncateAndAppend] assert(append.info.options.get("write.split-size") === "10") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala index 0f7f4cefe2feb..378a3af2561b3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.connector -import org.apache.spark.sql.Row +import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.catalyst.expressions.CheckInvariant import org.apache.spark.sql.catalyst.plans.logical.Filter +import org.apache.spark.sql.connector.catalog.{Aborted, Committed} import org.apache.spark.sql.execution.datasources.v2.{DeleteFromTableExec, ReplaceDataExec, WriteDeltaExec} +import org.apache.spark.sql.sources abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { @@ -28,6 +30,10 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { protected def enforceCheckConstraintOnDelete: Boolean = true + // true when the table engine uses delta-based deletes (WriteDeltaExec), false for group-based + // (ReplaceDataExec); controls expected scan counts in transactional tests + protected def deltaDelete: Boolean = false + test("delete from table containing added column with default value") { 
createAndInitTable("pk INT NOT NULL, dep STRING", """{ "pk": 1, "dep": "hr" }""") @@ -659,6 +665,190 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { } } + test("delete with analysis failure and transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val exception = intercept[AnalysisException] { + sql(s"DELETE FROM $tableNameAsString WHERE invalid_column = 1") + } + + assert(exception.getMessage.contains("invalid_column")) + assert(catalog.lastTransaction.currentState == Aborted) + assert(catalog.lastTransaction.isClosed) + } + + test("delete with transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // simple predicate delete: goes through SupportsDelete.deleteWhere (no Spark-side scan) + val (txn, _) = executeTransaction { + sql(s"DELETE FROM $tableNameAsString WHERE dep = 'hr'") + } + + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(table.version() == "2") + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq(Row(2, 200, "software"))) + } + + test("delete with subquery on source table and transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'hr'), (4, 400, 'finance')") + + val (txn, txnTables) = executeTransaction { + sql( + s"""DELETE FROM $tableNameAsString + |WHERE pk IN (SELECT pk FROM $sourceNameAsString WHERE dep = 
'hr') + |""".stripMargin) + } + + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 2) + assert(table.version() == "2") + + val sourceTxnTable = txnTables(sourceNameAsString) + val expectedNumSourceScans = if (deltaDelete) 1 else 2 + assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans) + + val numSubquerySourceScans = sourceTxnTable.scanEvents.flatten.count { + case sources.EqualTo("dep", "hr") => true + case _ => false + } + assert(numSubquerySourceScans == expectedNumSourceScans) + + val targetTxnTable = txnTables(tableNameAsString) + val expectedNumTargetScans = if (deltaDelete) 1 else 2 + assert(targetTxnTable.scanEvents.size == expectedNumTargetScans) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(2, 200, "software"), // unchanged + Row(3, 300, "hr"))) // unchanged (pk 3 not in subquery result) + } + + test("delete with CTE and transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'hr'), (4, 400, 'finance')") + + val (txn, txnTables) = executeTransaction { + sql( + s"""WITH cte AS ( + | SELECT pk FROM $sourceNameAsString WHERE dep = 'hr' + |) + |DELETE FROM $tableNameAsString + |WHERE pk IN (SELECT pk FROM cte) + |""".stripMargin) + } + + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 2) + assert(table.version() == "2") + + val targetTxnTable = txnTables(tableNameAsString) + val expectedNumTargetScans = if (deltaDelete) 1 else 2 + assert(targetTxnTable.scanEvents.size == expectedNumTargetScans) + + val sourceTxnTable = txnTables(sourceNameAsString) + val expectedNumSourceScans = if (deltaDelete) 1 else 2 + 
assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans) + + val numCteSourceScans = sourceTxnTable.scanEvents.flatten.count { + case sources.EqualTo("dep", "hr") => true + case _ => false + } + assert(numCteSourceScans == expectedNumSourceScans) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(2, 200, "software"), // unchanged + Row(3, 300, "hr"))) // unchanged (pk 3 not in source) + } + + test("delete using view with transactional checks") { + withView("temp_view") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'hr'), (4, 400, 'finance')") + + sql( + s"""CREATE VIEW temp_view AS + |SELECT pk FROM $sourceNameAsString WHERE dep = 'hr' + |""".stripMargin) + + val (txn, txnTables) = executeTransaction { + sql(s"DELETE FROM $tableNameAsString WHERE pk IN (SELECT pk FROM temp_view)") + } + + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 2) + assert(table.version() == "2") + + val targetTxnTable = txnTables(tableNameAsString) + val expectedNumTargetScans = if (deltaDelete) 1 else 2 + assert(targetTxnTable.scanEvents.size == expectedNumTargetScans) + + val sourceTxnTable = txnTables(sourceNameAsString) + val expectedNumSourceScans = if (deltaDelete) 1 else 2 + assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(2, 200, "software"), // unchanged + Row(3, 300, "hr"))) // unchanged (pk 3 not in source) + } + } + + test("EXPLAIN DELETE SQL with transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 
200, "dep": "software" } + |""".stripMargin) + + sql(s"EXPLAIN DELETE FROM $tableNameAsString WHERE dep = 'hr'") + + // EXPLAIN should not start a new transaction + assert(catalog.transaction === null) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr"), + Row(2, 200, "software"))) + } + private def executeDeleteWithFilters(query: String): Unit = { val executedPlan = executeAndKeepPlan { sql(query) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedDeleteFromTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedDeleteFromTableSuite.scala index 9046123ddbd3f..56a689364d1c0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedDeleteFromTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedDeleteFromTableSuite.scala @@ -33,6 +33,8 @@ class DeltaBasedDeleteFromTableSuite extends DeleteFromTableSuiteBase { override def enforceCheckConstraintOnDelete: Boolean = false + override protected def deltaDelete: Boolean = true + test("delete handles metadata columns correctly") { createAndInitTable("pk INT NOT NULL, id INT, dep STRING", """{ "pk": 1, "id": 1, "dep": "hr" } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuiteBase.scala index 89b42b5e6db7b..e821fc3f660da 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuiteBase.scala @@ -21,6 +21,8 @@ import org.apache.spark.sql.{AnalysisException, Row} abstract class DeltaBasedUpdateTableSuiteBase extends UpdateTableSuiteBase { + override protected def deltaUpdate: Boolean = true + test("nullable row ID attrs") { createAndInitTable("pk INT, salary INT, dep STRING", """{ "pk": 1, "salary": 300, "dep": 'hr' } 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala index e1c574ec7ba65..d58e22e63d71e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.connector +import org.apache.spark.sql.{sources, Column, Row} import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.Row import org.apache.spark.sql.classic.MergeIntoWriter -import org.apache.spark.sql.connector.catalog.Column +import org.apache.spark.sql.connector.catalog.{Aborted, Committed} import org.apache.spark.sql.connector.catalog.Identifier import org.apache.spark.sql.connector.catalog.TableInfo import org.apache.spark.sql.functions._ @@ -31,6 +31,103 @@ class MergeIntoDataFrameSuite extends RowLevelOperationSuiteBase { import testImplicits._ + private def targetTableCol(colName: String): Column = { + col(tableNameAsString + "." 
+ colName) + } + + test("self merge with transactional checks") { + // create table + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // create a source on top of itself that will be fully resolved and analyzed + val sourceDF = spark.table(tableNameAsString) + .where("salary == 100") + .as("source") + sourceDF.queryExecution.assertAnalyzed() + + // merge into table using the source on top of itself + val (txn, txnTables) = executeTransaction { + sourceDF + .mergeInto( + tableNameAsString, + $"source.pk" === targetTableCol("pk") && targetTableCol("dep") === "hr") + .whenMatched() + .update(Map("salary" -> targetTableCol("salary").plus(1))) + .whenNotMatched() + .insertAll() + .merge() + } + + // check txn was properly committed and closed + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 1) + assert(table.version() == "2") + + // check all table scans + val targetTxnTable = txnTables(tableNameAsString) + assert(targetTxnTable.scanEvents.size == 4) + + // check table scans as MERGE target + val numTargetScans = targetTxnTable.scanEvents.flatten.count { + case sources.EqualTo("dep", "hr") => true + case _ => false + } + assert(numTargetScans == 2) + + // check table scans as MERGE source + val numSourceScans = targetTxnTable.scanEvents.flatten.count { + case sources.EqualTo("salary", 100) => true + case _ => false + } + assert(numSourceScans == 2) + + // check txn state was propagated correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 101, "hr"), // update + Row(2, 200, "software"), // unchanged + Row(3, 300, "hr"))) // unchanged + } + + for (alterClause <- Seq( + "ADD COLUMN new_col INT", + "DROP COLUMN salary", + "ALTER COLUMN salary TYPE BIGINT", + "ALTER COLUMN pk DROP NOT NULL")) + test(s"self merge fails when source 
schema changes after analysis - DDL: $alterClause" ) { + withTable(tableNameAsString) { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val sourceDF = spark.table(tableNameAsString).where("salary == 100").as("source") + sourceDF.queryExecution.assertAnalyzed() + + sql(s"ALTER TABLE $tableNameAsString $alterClause") + + val e = intercept[AnalysisException] { + sourceDF + .mergeInto(tableNameAsString, $"source.pk" === targetTableCol("pk")) + .whenMatched() + .update(Map("salary" -> targetTableCol("salary").plus(1))) + .merge() + } + + assert( + e.getCondition == "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMNS_MISMATCH", + alterClause) + assert(catalog.lastTransaction.currentState == Aborted, alterClause) + assert(catalog.lastTransaction.isClosed, alterClause) + } + } + test("merge into empty table with NOT MATCHED clause") { withTempView("source") { createTable("pk INT NOT NULL, salary INT, dep STRING") @@ -979,6 +1076,7 @@ class MergeIntoDataFrameSuite extends RowLevelOperationSuiteBase { } test("SPARK-54157: version is refreshed when source is V2 table") { + import org.apache.spark.sql.connector.catalog.Column val sourceTable = "cat.ns1.source_table" withTable(sourceTable) { createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", @@ -1026,6 +1124,7 @@ class MergeIntoDataFrameSuite extends RowLevelOperationSuiteBase { } test("SPARK-54444: any schema changes after analysis are prohibited") { + import org.apache.spark.sql.connector.catalog.Column val sourceTable = "cat.ns1.source_table" withTable(sourceTable) { createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala index 7c0e503705c7e..0423fe66a2067 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkRuntimeException import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, In, Not} import org.apache.spark.sql.catalyst.optimizer.BuildLeft -import org.apache.spark.sql.connector.catalog.{Column, ColumnDefaultValue, InMemoryTable, TableInfo} +import org.apache.spark.sql.connector.catalog.{Aborted, Column, ColumnDefaultValue, Committed, InMemoryTable, TableInfo} import org.apache.spark.sql.connector.expressions.{GeneralScalarExpression, LiteralValue} import org.apache.spark.sql.connector.write.MergeSummary import org.apache.spark.sql.execution.SparkPlan @@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.datasources.v2.MergeRowsExec import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, CartesianProductExec} import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources import org.apache.spark.sql.types.{IntegerType, StringType} abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase @@ -38,6 +39,309 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase protected def deltaMerge: Boolean = false + test("self merge with transactional checks") { + // create table + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // merge into table using a subquery on top of itself + val (txn, txnTables) = executeTransaction { + sql( + s"""MERGE INTO $tableNameAsString t + |USING (SELECT * FROM $tableNameAsString WHERE salary = 100) s + |ON t.pk 
= s.pk AND t.dep = 'hr' + |WHEN MATCHED THEN + | UPDATE SET salary = t.salary + 1 + |WHEN NOT MATCHED THEN + | INSERT * + |""".stripMargin) + } + + // check txn was properly committed and closed + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 1) + assert(table.version() == "2") + + // check all table scans + val targetTxnTable = txnTables(tableNameAsString) + val expectedNumScans = if (deltaMerge) 2 else 4 + assert(targetTxnTable.scanEvents.size == expectedNumScans) + + // check table scans as MERGE target + val numTargetScans = targetTxnTable.scanEvents.flatten.count { + case sources.EqualTo("dep", "hr") => true + case _ => false + } + val expectedNumTargetScans = if (deltaMerge) 1 else 2 + assert(numTargetScans == expectedNumTargetScans) + + // check table scans as MERGE source + val numSourceScans = targetTxnTable.scanEvents.flatten.count { + case sources.EqualTo("salary", 100) => true + case _ => false + } + val expectedNumSourceScans = if (deltaMerge) 1 else 2 + assert(numSourceScans == expectedNumSourceScans) + + // check txn state was propagated correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 101, "hr"), // update + Row(2, 200, "software"), // unchanged + Row(3, 300, "hr"))) // unchanged + } + + test("merge into table with analysis failure and transactional checks") { + createAndInitTable( + "pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'support'), (4, 400, 'finance')") + + val exception = intercept[AnalysisException] { + sql( + s"""MERGE INTO $tableNameAsString t + |USING $sourceNameAsString s + |ON t.pk = s.pk + |WHEN MATCHED THEN + | UPDATE SET salary = s.invalid_column + |WHEN NOT 
MATCHED THEN + | INSERT (pk, salary, dep) VALUES (s.pk, s.salary, 'pending') + |""".stripMargin) + } + + assert(exception.getMessage.contains("invalid_column")) + assert(catalog.lastTransaction.currentState == Aborted) + assert(catalog.lastTransaction.isClosed) + } + + test("merge into table using view with transactional checks") { + withView("temp_view") { + // create target table + createAndInitTable( + "pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // create source table + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT)") + sql(s"INSERT INTO $sourceNameAsString (pk, salary) VALUES (1, 150), (4, 400)") + + // create view on top of source and target tables + sql( + s"""CREATE VIEW temp_view AS + |SELECT s.pk, s.salary, t.dep + |FROM $sourceNameAsString s + |LEFT JOIN ( + | SELECT * FROM $tableNameAsString WHERE pk < 10 + |) t ON s.pk = t.pk + |""".stripMargin) + + // merge into target table using view + val (txn, txnTables) = executeTransaction { + sql(s"""MERGE INTO $tableNameAsString t + |USING temp_view s + |ON t.pk = s.pk AND t.dep = 'hr' + |WHEN MATCHED THEN + | UPDATE SET salary = s.salary, dep = s.dep + |WHEN NOT MATCHED THEN + | INSERT (pk, salary, dep) VALUES (s.pk, s.salary, 'pending') + |""".stripMargin) + } + + // check txn covers both tables and was properly committed and closed + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 2) + assert(table.version() == "2") + + // check target table was scanned correctly + val targetTxnTable = txnTables(tableNameAsString) + val expectedNumTargetScans = if (deltaMerge) 2 else 4 + assert(targetTxnTable.scanEvents.size == expectedNumTargetScans) + + // check target table scans as MERGE target (dep = 'hr') + val numMergeTargetScans = targetTxnTable.scanEvents.flatten.count { + case 
sources.EqualTo("dep", "hr") => true + case _ => false + } + val expectedNumMergeTargetScans = if (deltaMerge) 1 else 2 + assert(numMergeTargetScans == expectedNumMergeTargetScans) + + // check target table scans in view as MERGE source (pk < 10) + val numViewTargetScans = targetTxnTable.scanEvents.flatten.count { + case sources.LessThan("pk", 10L) => true + case _ => false + } + val expectedNumViewTargetScans = if (deltaMerge) 1 else 2 + assert(numViewTargetScans == expectedNumViewTargetScans) + + // check source table scans in view as MERGE source (no predicate) + val sourceTxnTable = txnTables(sourceNameAsString) + val expectedNumSourceScans = if (deltaMerge) 1 else 2 + assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans) + + // check txn state was propagated correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 150, "hr"), // update + Row(2, 200, "software"), // unchanged + Row(3, 300, "hr"), // unchanged + Row(4, 400, "pending"))) // new + } + } + + test("merge into table using nested view with transactional checks") { + withView("base_view", "nested_view") { + withTable(sourceNameAsString) { + // create target table + createAndInitTable( + "pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // create source table + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT)") + sql(s"INSERT INTO $sourceNameAsString (pk, salary) VALUES (1, 150), (4, 400)") + + // create base view + sql( + s"""CREATE VIEW base_view AS + |SELECT s.pk, s.salary, t.dep + |FROM $sourceNameAsString s + |LEFT JOIN ( + | SELECT * FROM $tableNameAsString WHERE pk < 10 + |) t ON s.pk = t.pk + |""".stripMargin) + + // create nested view on top of base view + sql( + s"""CREATE VIEW nested_view AS + |SELECT * FROM base_view + |""".stripMargin) + + // merge into target table using nested 
view + val (txn, txnTables) = executeTransaction { + sql( + s"""MERGE INTO $tableNameAsString t + |USING nested_view s + |ON t.pk = s.pk AND t.dep = 'hr' + |WHEN MATCHED THEN + | UPDATE SET salary = s.salary, dep = s.dep + |WHEN NOT MATCHED THEN + | INSERT (pk, salary, dep) VALUES (s.pk, s.salary, 'pending') + |""".stripMargin) + } + + // check txn covers both tables and was properly committed and closed + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 2) + assert(table.version() == "2") + + // check target table was scanned correctly + val targetTxnTable = txnTables(tableNameAsString) + val expectedNumTargetScans = if (deltaMerge) 2 else 4 + assert(targetTxnTable.scanEvents.size == expectedNumTargetScans) + + // check target table scans as MERGE target (dep = 'hr') + val numMergeTargetScans = targetTxnTable.scanEvents.flatten.count { + case sources.EqualTo("dep", "hr") => true + case _ => false + } + val expectedNumMergeTargetScans = if (deltaMerge) 1 else 2 + assert(numMergeTargetScans == expectedNumMergeTargetScans) + + // check target table scans in view as MERGE source (pk < 10) + val numViewTargetScans = targetTxnTable.scanEvents.flatten.count { + case sources.LessThan("pk", 10L) => true + case _ => false + } + val expectedNumViewTargetScans = if (deltaMerge) 1 else 2 + assert(numViewTargetScans == expectedNumViewTargetScans) + + // check source table scans in view as MERGE source (no predicate) + val sourceTxnTable = txnTables(sourceNameAsString) + val expectedNumSourceScans = if (deltaMerge) 1 else 2 + assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans) + + // check txn state was propagated correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 150, "hr"), // update + Row(2, 200, "software"), // unchanged + Row(3, 300, "hr"), // unchanged + Row(4, 400, "pending"))) // new + } + } + } + + test("merge into table rewritten as INSERT with transactional checks") { + 
withTable(sourceNameAsString) { + // create target table + createAndInitTable( + "pk INT, value STRING, dep STRING", + """{ "pk": 1, "value": "a", "dep": "hr" } + |{ "pk": 2, "value": "b", "dep": "finance" } + |""".stripMargin) + + // create source table + sql(s"CREATE TABLE $sourceNameAsString (pk INT, value STRING, dep STRING)") + sql(s"INSERT INTO $sourceNameAsString VALUES (3, 'c', 'hr'), (4, 'd', 'software')") + + // merge into target with only WHEN NOT MATCHED clauses (rewritten as insert) + val (txn, txnTables) = executeTransaction { + sql( + s"""MERGE INTO $tableNameAsString t + |USING $sourceNameAsString s + |ON t.pk = s.pk + |WHEN NOT MATCHED AND s.pk < 4 THEN + | INSERT (pk, value, dep) VALUES (s.pk, concat(s.value, '_low'), s.dep) + |WHEN NOT MATCHED AND s.pk >= 4 THEN + | INSERT (pk, value, dep) VALUES (s.pk, concat(s.value, '_high'), s.dep) + |""".stripMargin) + } + + // check txn covers both tables and was properly committed and closed + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 2) + assert(table.version() == "2") + + // check target table was scanned correctly + val targetTxnTable = txnTables(tableNameAsString) + assert(targetTxnTable.scanEvents.size == 1) + + // check source table was scanned correctly + val sourceTxnTable = txnTables(sourceNameAsString) + assert(sourceTxnTable.scanEvents.size == 1) + + // check txn state was propagated correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, "a", "hr"), // unchanged + Row(2, "b", "finance"), // unchanged + Row(3, "c_low", "hr"), // inserted via first NOT MATCHED clause + Row(4, "d_high", "software"))) // inserted via second NOT MATCHED clause + } + } + test("merge into table with expression-based default values") { val columns = Array( Column.create("pk", IntegerType), @@ -671,6 +975,131 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase } } + test("merge with CTE with transactional checks") { + 
withTable(sourceNameAsString) { + // create target table + createAndInitTable( + "pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // create source table + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'hr'), (4, 400, 'finance')") + + // merge into target table using CTE + val (txn, txnTables) = executeTransaction { + sql( + s"""WITH cte AS ( + | SELECT pk, salary + 50 AS salary, dep + | FROM $sourceNameAsString + | WHERE salary > 100 + |) + |MERGE INTO $tableNameAsString t + |USING cte s + |ON t.pk = s.pk AND t.dep = 'hr' + |WHEN MATCHED THEN + | UPDATE SET salary = s.salary + |WHEN NOT MATCHED THEN + | INSERT (pk, salary, dep) VALUES (s.pk, s.salary, 'pending') + |""".stripMargin) + } + + // check txn covers both tables and was properly committed and closed + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 2) + assert(table.version() == "2") + + // check target table was scanned correctly + val targetTxnTable = txnTables(tableNameAsString) + val expectedNumTargetScans = if (deltaMerge) 1 else 2 + assert(targetTxnTable.scanEvents.size == expectedNumTargetScans) + + // check target table scans as MERGE target (dep = 'hr') + val numMergeTargetScans = targetTxnTable.scanEvents.flatten.count { + case sources.EqualTo("dep", "hr") => true + case _ => false + } + assert(numMergeTargetScans == expectedNumTargetScans) + + // check source table was scanned correctly + val sourceTxnTable = txnTables(sourceNameAsString) + val expectedNumSourceScans = if (deltaMerge) 1 else 2 + assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans) + + // check source table scans in CTE (salary > 100) + val numCteSourceScans = sourceTxnTable.scanEvents.flatten.count { + case 
sources.GreaterThan("salary", 100) => true + case _ => false + } + assert(numCteSourceScans == expectedNumSourceScans) + + // check txn state was propagated correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 200, "hr"), // updated (150 + 50) + Row(2, 200, "software"), // unchanged + Row(3, 300, "hr"), // unchanged + Row(4, 450, "pending"))) // inserted (400 + 50) + } + } + + test("merge with cached source and transactional checks") { + withTable(sourceNameAsString) { + // create target table + createAndInitTable( + "pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // create and populate source table + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'support'), (4, 400, 'finance')") + + // Cache source table before the transaction. Make sure when the transaction is active the + catalog still creates a transaction table. 
+ spark.table(sourceNameAsString).cache() + + try { + val (txn, txnTables) = executeTransaction { + sql( + s"""MERGE INTO $tableNameAsString t + |USING $sourceNameAsString s + |ON t.pk = s.pk AND t.dep = 'hr' + |WHEN MATCHED THEN + | UPDATE SET salary = s.salary, dep = s.dep + |WHEN NOT MATCHED THEN + | INSERT (pk, salary, dep) VALUES (s.pk, s.salary, 'pending') + |""".stripMargin) + } + + assert(txn.currentState == Committed) + assert(txn.isClosed) + // both target and source must have been read through the transaction catalog + assert(txnTables.size == 2) + assert(table.version() == "2") + assert(txnTables(sourceNameAsString).scanEvents.nonEmpty) + assert(txnTables(tableNameAsString).scanEvents.nonEmpty) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 150, "support"), // matched and updated + Row(2, 200, "software"), // unchanged + Row(3, 300, "hr"), // unchanged (no match in source) + Row(4, 400, "pending"))) // not matched, inserted + } finally { + spark.catalog.uncacheTable(sourceNameAsString) + } + } + } + test("merge with subquery as source") { withTempView("source") { createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", @@ -2223,6 +2652,8 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase sql(query) } assert(e.getMessage.contains("ON search condition of the MERGE statement")) + assert(catalog.lastTransaction.currentState == Aborted) + assert(catalog.lastTransaction.isClosed) } private def assertMetric( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala index 8c51fb17b2cf4..565e209b021a6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala @@ -28,16 +28,16 @@ import 
org.apache.spark.sql.catalyst.expressions.{DynamicPruningExpression, Expr import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReplaceData, WriteDelta} import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.METADATA_COL_ATTR_KEY -import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Delete, Identifier, InMemoryRowLevelOperationTable, InMemoryRowLevelOperationTableCatalog, Insert, MetadataColumn, Operation, Reinsert, TableInfo, Update, Write} +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Delete, Identifier, InMemoryRowLevelOperationTable, InMemoryRowLevelOperationTableCatalog, Insert, MetadataColumn, Operation, Reinsert, Table, TableInfo, Txn, TxnTable, Update, Write} import org.apache.spark.sql.connector.expressions.LogicalExpressions.{identity, reference} import org.apache.spark.sql.connector.expressions.Transform -import org.apache.spark.sql.execution.{InSubqueryExec, QueryExecution, SparkPlan} +import org.apache.spark.sql.connector.write.RowLevelOperationTable +import org.apache.spark.sql.execution.{InSubqueryExec, SparkPlan} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation, DataSourceV2ScanRelation} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{IntegerType, MetadataBuilder, StringType, StructField, StructType} -import org.apache.spark.sql.util.QueryExecutionListener import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.ArrayImplicits._ @@ -82,6 +82,7 @@ abstract class RowLevelOperationSuiteBase protected val namespace: Array[String] = Array("ns1") protected val ident: Identifier = Identifier.of(namespace, "test_table") protected val tableNameAsString: String = "cat." 
+ ident.toString + protected val sourceNameAsString: String = "cat.ns1.source_table" protected def extraTableProps: java.util.Map[String, String] = { Collections.emptyMap[String, String] @@ -135,22 +136,28 @@ abstract class RowLevelOperationSuiteBase // executes an operation and keeps the executed plan protected def executeAndKeepPlan(func: => Unit): SparkPlan = { - var executedPlan: SparkPlan = null + withQueryExecutionsCaptured(spark)(func) match { + case Seq(qe) => stripAQEPlan(qe.executedPlan) + case other => fail(s"expected only one query execution, but got ${other.size}") + } + } - val listener = new QueryExecutionListener { - override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { - executedPlan = qe.executedPlan - } - override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = { + protected def executeTransaction(func: => Unit): (Txn, Map[String, TxnTable]) = { + val tables = withQueryExecutionsCaptured(spark)(func).flatMap { qe => + collectWithSubqueries(qe.executedPlan) { + case BatchScanExec(_, _, _, _, table: TxnTable, _) => table + case BatchScanExec(_, _, _, _, RowLevelOperationTable(table: TxnTable, _), _) => table } } - spark.listenerManager.register(listener) - - func - - sparkContext.listenerBus.waitUntilEmpty() + (catalog.lastTransaction, indexByName(tables)) + } - stripAQEPlan(executedPlan) + protected def indexByName[T <: Table](tables: Seq[T]): Map[String, T] = { + tables.groupBy(_.name).map { + case (name, sameNameTables) => + val Seq(table) = sameNameTables.distinct + name -> table + } } // executes an operation and extracts conditions from ReplaceData or WriteDelta diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala index ac0bf3bdba9ce..2ca4235f96b55 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.connector import org.apache.spark.SparkRuntimeException -import org.apache.spark.sql.Row -import org.apache.spark.sql.connector.catalog.{Column, ColumnDefaultValue, TableChange, TableInfo} +import org.apache.spark.sql.{sources, AnalysisException, Row} +import org.apache.spark.sql.connector.catalog.{Aborted, Column, ColumnDefaultValue, Committed, TableChange, TableInfo} import org.apache.spark.sql.connector.expressions.{GeneralScalarExpression, LiteralValue} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{IntegerType, StringType} @@ -28,6 +28,8 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { import testImplicits._ + protected def deltaUpdate: Boolean = false + test("update table containing added column with default value") { createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", """{ "pk": 1, "salary": 100, "dep": "hr" } @@ -762,4 +764,325 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { Row(5))) } } + + test("update with analysis failure and transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + val exception = intercept[AnalysisException] { + sql(s"UPDATE $tableNameAsString SET invalid_column = -1") + } + + assert(exception.getMessage.contains("invalid_column")) + assert(catalog.lastTransaction.currentState == Aborted) + assert(catalog.lastTransaction.isClosed) + } + + test("update with CTE and transactional checks") { + // create table + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + 
|""".stripMargin) + + // create source table + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'hr'), (4, 400, 'finance')") + + // update using CTE + val (txn, txnTables) = executeTransaction { + sql( + s"""WITH cte AS ( + | SELECT pk, salary + 50 AS adjusted_salary, dep + | FROM $sourceNameAsString + | WHERE salary > 100 + |) + |UPDATE $tableNameAsString t + |SET salary = -1 + |WHERE t.dep = 'hr' AND EXISTS (SELECT 1 FROM cte WHERE cte.pk = t.pk) + |""".stripMargin) + } + + // check txn was properly committed and closed + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 2) + assert(table.version() == "2") + + // check target table was scanned correctly + val targetTxnTable = txnTables(tableNameAsString) + val expectedNumTargetScans = if (deltaUpdate) 1 else 3 + assert(targetTxnTable.scanEvents.size == expectedNumTargetScans) + + // check target table scans for UPDATE condition (dep = 'hr') + val numUpdateTargetScans = targetTxnTable.scanEvents.flatten.count { + case sources.EqualTo("dep", "hr") => true + case _ => false + } + assert(numUpdateTargetScans == expectedNumTargetScans) + + // check source table was scanned correctly + val sourceTxnTable = txnTables(sourceNameAsString) + val expectedNumSourceScans = if (deltaUpdate) 1 else 4 + assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans) + + // check source table scans in CTE (salary > 100) + val numCteSourceScans = sourceTxnTable.scanEvents.flatten.count { + case sources.GreaterThan("salary", 100) => true + case _ => false + } + assert(numCteSourceScans == expectedNumSourceScans) + + // check txn state was propagated correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, -1, "hr"), // updated + Row(2, 200, "software"), // unchanged + Row(3, 300, "hr"))) // unchanged (no matching pk in source) + } + + test("update with subquery on source table 
and transactional checks") { + // create target table + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // create source table + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'hr'), (4, 400, 'finance')") + + // update using an uncorrelated IN subquery that reads from a transactional catalog table + val (txn, txnTables) = executeTransaction { + sql( + s"""UPDATE $tableNameAsString + |SET salary = -1 + |WHERE pk IN (SELECT pk FROM $sourceNameAsString WHERE dep = 'hr') + |""".stripMargin) + } + + // check txn was properly committed and closed + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 2) + assert(table.version() == "2") + + // check source table was scanned correctly (dep = 'hr' filter in the subquery) + val sourceTxnTable = txnTables(sourceNameAsString) + val expectedNumSourceScans = if (deltaUpdate) 1 else 4 + assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans) + + val numSubquerySourceScans = sourceTxnTable.scanEvents.flatten.count { + case sources.EqualTo("dep", "hr") => true + case _ => false + } + assert(numSubquerySourceScans == expectedNumSourceScans) + + // check target table was scanned correctly + val targetTxnTable = txnTables(tableNameAsString) + val expectedNumTargetScans = if (deltaUpdate) 1 else 3 + assert(targetTxnTable.scanEvents.size == expectedNumTargetScans) + + // check txn state was propagated correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, -1, "hr"), // updated (pk 1 is in subquery result) + Row(2, 200, "software"), // unchanged + Row(3, 300, "hr"))) // unchanged (pk 3 not in subquery result) + } + + test("update with uncorrelated scalar subquery on source table and transactional 
checks") { + // create target table + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // create source table + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'hr'), (4, 400, 'finance')") + + // update using an uncorrelated scalar subquery in the SET clause that reads from a + // transactional catalog table; scalar subqueries are executed as SubqueryExec at runtime + // and cannot be rewritten as joins + val (txn, txnTables) = executeTransaction { + sql( + s"""UPDATE $tableNameAsString + |SET salary = (SELECT max(salary) FROM $sourceNameAsString WHERE dep = 'hr') + |WHERE dep = 'hr' + |""".stripMargin) + } + + // check txn was properly committed and closed + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 2) + assert(table.version() == "2") + + // check source table was scanned via the transaction catalog + val sourceTxnTable = txnTables(sourceNameAsString) + assert(sourceTxnTable.scanEvents.nonEmpty) + assert(sourceTxnTable.scanEvents.flatten.exists { + case sources.EqualTo("dep", "hr") => true + case _ => false + }) + + // check target table was scanned via the transaction catalog + val targetTxnTable = txnTables(tableNameAsString) + assert(targetTxnTable.scanEvents.nonEmpty) + + // check txn state was propagated correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 150, "hr"), // updated (max salary in source for 'hr' is 150) + Row(2, 200, "software"), // unchanged + Row(3, 150, "hr"))) // updated + } + + test("update with constraint violation and transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ 
"pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + val exception = intercept[SparkRuntimeException] { + executeTransaction { + sql( + s"""UPDATE $tableNameAsString + |SET pk = NULL + |WHERE dep = 'hr' + |""".stripMargin) // NULL violates NOT NULL constraint + } + } + + assert(exception.getMessage.contains("NOT_NULL_ASSERT_VIOLATION")) + assert(catalog.lastTransaction.currentState == Aborted) + assert(catalog.lastTransaction.isClosed) + } + + test("update using view with transactional checks") { + withView("temp_view") { + // create target table + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // create source table + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT)") + sql(s"INSERT INTO $sourceNameAsString (pk, salary) VALUES (1, 150), (4, 400)") + + // create view on top of source and target tables + sql( + s"""CREATE VIEW temp_view AS + |SELECT s.pk, s.salary, t.dep + |FROM $sourceNameAsString s + |LEFT JOIN ( + | SELECT * FROM $tableNameAsString WHERE pk < 10 + |) t ON s.pk = t.pk + |""".stripMargin) + + // update target table using view + val (txn, txnTables) = executeTransaction { + sql( + s"""UPDATE $tableNameAsString t + |SET salary = -1 + |WHERE t.dep = 'hr' AND EXISTS (SELECT 1 FROM temp_view v WHERE v.pk = t.pk) + |""".stripMargin) + } + + // check txn covers both tables and was properly committed and closed + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 2) + assert(table.version() == "2") + + // check target table was scanned correctly + val targetTxnTable = txnTables(tableNameAsString) + val expectedNumTargetScans = if (deltaUpdate) 2 else 7 + assert(targetTxnTable.scanEvents.size == expectedNumTargetScans) + + // check target table scans as UPDATE target (dep = 'hr') + val numUpdateTargetScans = 
targetTxnTable.scanEvents.flatten.count { + case sources.EqualTo("dep", "hr") => true + case _ => false + } + val expectedNumUpdateTargetScans = if (deltaUpdate) 1 else 3 + assert(numUpdateTargetScans == expectedNumUpdateTargetScans) + + // check target table scans in view as source (pk < 10) + val numViewTargetScans = targetTxnTable.scanEvents.flatten.count { + case sources.LessThan("pk", 10L) => true + case _ => false + } + val expectedNumViewTargetScans = if (deltaUpdate) 1 else 4 + assert(numViewTargetScans == expectedNumViewTargetScans) + + // check source table scans in view + val sourceTxnTable = txnTables(sourceNameAsString) + val expectedNumSourceScans = if (deltaUpdate) 1 else 4 + assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans) + + // check txn state was propagated correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, -1, "hr"), // updated from view + Row(2, 200, "software"), // unchanged + Row(3, 300, "hr"))) // unchanged (no matching pk in source) + } + } + + test("df.explain() on update with transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + // sql() is lazy, but explain() forces executedPlan. + sql(s"UPDATE $tableNameAsString SET salary = -1 WHERE dep = 'hr'").explain() + + assert(catalog.lastTransaction != null) + assert(catalog.lastTransaction.currentState == Committed) + assert(catalog.lastTransaction.isClosed) + assert(table.version() == "2") + + // The UPDATE was actually executed, not just planned. 
+ checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, -1, "hr"), // updated + Row(2, 200, "software"))) // unchanged + } + + test("EXPLAIN UPDATE SQL with transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + // EXPLAIN UPDATE only plans the command, it does not execute the write. + sql(s"EXPLAIN UPDATE $tableNameAsString SET salary = -1 WHERE dep = 'hr'") + + // A transaction should not have started at all. + assert(catalog.transaction === null) + + // The UPDATE was not executed. Data is unchanged. + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr"), + Row(2, 200, "software"))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AnalyzerBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AnalyzerBenchmark.scala new file mode 100644 index 0000000000000..141d5966b4b6c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AnalyzerBenchmark.scala @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.benchmark + +import scala.concurrent.duration._ + +import org.apache.spark.benchmark.Benchmark +import org.apache.spark.sql.classic +import org.apache.spark.sql.execution.QueryExecution + +/** + * Benchmark to measure the overhead of cloning the analyzer for transactional query execution. + * Each transactional query creates a new [[Analyzer]] instance via + * [[Analyzer.withCatalogManager]], which shares all rules with the original but carries a + * per-query [[org.apache.spark.sql.connector.catalog.CatalogManager]]. This benchmark checks + * whether the cloning introduces measurable overhead. + * + * To run this benchmark: + * {{{ + * build/sbt "sql/Test/runMain <this class>" + * }}} + */ +object AnalyzerBenchmark extends SqlBasedBenchmark { + + private val numRows = 100 + private val queries = Seq( + "simple select" -> "SELECT id, val FROM t1", + "join" -> "SELECT t1.id, t2.val FROM t1 JOIN t2 ON t1.id = t2.id", + "wide schema" -> s"SELECT ${(1 to 100).map(i => s"col_$i").mkString(", ")} FROM wide_t" + ) + + private def setupTables(): Unit = { + spark.range(numRows).selectExpr("id", "id * 2 as val").createOrReplaceTempView("t1") + spark.range(numRows).selectExpr("id", "id * 3 as val").createOrReplaceTempView("t2") + spark.range(numRows) + .selectExpr((1 to numRows).map(i => s"id as col_$i"): _*) + .createOrReplaceTempView("wide_t") + } + + /** + * Measures analysis time for a pre-parsed plan, comparing the session analyzer against a + * cloned analyzer created via [[Analyzer.withCatalogManager]]. + * + * Two cases: + * - "session analyzer" : baseline, uses the session analyzer directly. + * - "cloned analyzer (per query)": analyzer cloned every iteration; reflects the full + * per-transactional-query cost (clone + analysis). 
+ */ + def analysisBenchmark(): Unit = { + for ((name, sql) <- queries) { + runBenchmark(s"analysis overhead $name") { + val benchmark = new Benchmark( + name = s"analysis overhead $name", + // Per row measurements are not meaningful here. + valuesPerIteration = numRows, + minTime = 10.seconds, + output = output) + val catalogManager = spark.sessionState.catalogManager + + benchmark.addCase("session analyzer") { _ => + val plan = spark.sessionState.sqlParser.parsePlan(sql) + new QueryExecution(spark.asInstanceOf[classic.SparkSession], plan).analyzed + } + + benchmark.addCase("cloned analyzer (per query)") { _ => + val cloned = spark.sessionState.analyzer.withCatalogManager(catalogManager) + val plan = spark.sessionState.sqlParser.parsePlan(sql) + new QueryExecution(spark.asInstanceOf[classic.SparkSession], + plan, analyzerOpt = Some(cloned)).analyzed + } + + benchmark.run() + } + } + } + + /** + * Micro-benchmark for [[Analyzer.withCatalogManager]] in isolation: measures the cost of + * instantiating the anonymous [[Analyzer]] subclass, independent of analysis work. + */ + def cloneCostBenchmark(): Unit = { + runBenchmark("analyzer clone cost") { + val numRows = 1 // Per row measurements are not meaningful here. 
+ val benchmark = new Benchmark( + name = "analyzer clone cost", + valuesPerIteration = numRows, + output = output) + val catalogManager = spark.sessionState.catalogManager + + benchmark.addCase("withCatalogManager") { _ => + spark.sessionState.analyzer.withCatalogManager(catalogManager) + } + + benchmark.run() + } + } + + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + setupTables() + cloneCostBenchmark() + analysisBenchmark() + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala index 9996bec44e232..bdf47dc96c3a1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala @@ -21,8 +21,10 @@ import org.apache.spark.SparkConf import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.plans.logical.CompoundBody import org.apache.spark.sql.catalyst.util.QuotingUtils.toSQLConf +import org.apache.spark.sql.connector.catalog.{Aborted, Committed, Identifier, InMemoryRowLevelOperationTableCatalog, Txn, TxnTable, TxnTableCatalog} import org.apache.spark.sql.exceptions.SqlScriptingException import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{IntegerType, StructField, StructType} @@ -47,6 +49,27 @@ class SqlScriptingE2eSuite extends QueryTest with SharedSparkSession { } // Helpers + private def withCatalog( + name: String)( + f: InMemoryRowLevelOperationTableCatalog => Unit): Unit = { + withSQLConf(s"spark.sql.catalog.$name" -> + classOf[InMemoryRowLevelOperationTableCatalog].getName) { + val catalog = spark.sessionState.catalogManager + .catalog(name) + .asInstanceOf[InMemoryRowLevelOperationTableCatalog] + f(catalog) + } + } + + private def 
loadTxnTable( + txn: Txn, + tableName: String, + namespace: Array[String] = Array("ns1")): TxnTable = + txn.catalog + .asInstanceOf[TxnTableCatalog] + .loadTable(Identifier.of(namespace, tableName)) + .asInstanceOf[TxnTable] + private def verifySqlScriptResult( sqlText: String, expected: Seq[Row], @@ -174,6 +197,188 @@ class SqlScriptingE2eSuite extends QueryTest with SharedSparkSession { } } + test("multi statement with transactional checks - insert then delete") { + withCatalog("cat") { catalog => + withTable("cat.ns1.t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE cat.ns1.t (pk INT NOT NULL, salary INT, dep STRING) + | PARTITIONED BY (dep); + | INSERT INTO cat.ns1.t VALUES (1, 100, 'hr'), (2, 200, 'software'); + | DELETE FROM cat.ns1.t + | WHERE pk IN (SELECT pk FROM cat.ns1.t WHERE dep = 'hr'); + | SELECT * FROM cat.ns1.t; + |END + |""".stripMargin + + verifySqlScriptResult(sqlScript, Seq(Row(2, 200, "software"))) + + // Each DML statement in a script runs in its own independent QE and transaction. + assert(catalog.seenTransactions.size === 2) + assert(catalog.seenTransactions.forall(t => + t.currentState === Committed && t.isClosed)) + + // The DELETE subquery scans the table with a dep='hr' predicate; verify it was tracked. 
+ val deleteTxnTable = loadTxnTable(catalog.seenTransactions(1), "t") + assert(deleteTxnTable.scanEvents.flatten.exists { + case sources.EqualTo("dep", "hr") => true + case _ => false + }) + } + } + } + + test("multi statement with transactional checks - second statement fails") { + withCatalog("cat") { catalog => + withTable("cat.ns1.t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE cat.ns1.t (pk INT NOT NULL, salary INT, dep STRING) + | PARTITIONED BY (dep); + | INSERT INTO cat.ns1.t VALUES (1, 100, 'hr'), (2, 200, 'software'); + | DELETE FROM cat.ns1.t WHERE nonexistent_column = 1; + |END + |""".stripMargin + + checkError( + exception = intercept[AnalysisException] { + spark.sql(sqlScript).collect() + }, + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + parameters = Map( + "objectName" -> "`nonexistent_column`", + "proposal" -> ".*"), + matchPVals = true, + queryContext = Array(ExpectedContext("nonexistent_column"))) + + // INSERT committed; DELETE was aborted because analysis failed on the bad column. 
+ assert(catalog.seenTransactions.size === 2) + assert(catalog.seenTransactions(0).currentState === Committed) + assert(catalog.seenTransactions(0).isClosed) + assert(catalog.seenTransactions(1).currentState === Aborted) + assert(catalog.seenTransactions(1).isClosed) + assert(catalog.lastTransaction.currentState === Aborted) + } + } + } + + test("multi statement with transactional checks - insert, merge, update") { + withCatalog("cat") { catalog => + withTable("cat.ns1.t", "cat.ns1.src") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE cat.ns1.t (pk INT NOT NULL, salary INT, dep STRING) + | PARTITIONED BY (dep); + | CREATE TABLE cat.ns1.src (pk INT NOT NULL, salary INT, dep STRING) + | PARTITIONED BY (dep); + | INSERT INTO cat.ns1.t VALUES (1, 100, 'hr'), (2, 200, 'software'), (3, 300, 'hr'); + | INSERT INTO cat.ns1.src VALUES (1, 150, 'hr'), (4, 400, 'finance'); + | MERGE INTO cat.ns1.t AS t + | USING cat.ns1.src AS s + | ON t.pk = s.pk + | WHEN MATCHED THEN UPDATE SET salary = s.salary + | WHEN NOT MATCHED THEN INSERT (pk, salary, dep) + | VALUES (s.pk, s.salary, s.dep); + | UPDATE cat.ns1.t SET salary = salary + 50 WHERE dep = 'software'; + | SELECT * FROM cat.ns1.t ORDER BY pk; + |END + |""".stripMargin + + verifySqlScriptResult( + sqlScript, + Seq( + Row(1, 150, "hr"), + Row(2, 250, "software"), + Row(3, 300, "hr"), + Row(4, 400, "finance"))) + + // INSERT (x2), MERGE, and UPDATE each run in their own independent QE and transaction. + assert(catalog.seenTransactions.size === 4) + assert(catalog.seenTransactions.forall(t => t.currentState === Committed && t.isClosed)) + + def txnTable(txnIdx: Int): TxnTable = + loadTxnTable(catalog.seenTransactions(txnIdx), "t") + + // Both inserts are pure writes - no scan. + assert(txnTable(0).scanEvents.isEmpty) + assert(txnTable(1).scanEvents.isEmpty) + + // MERGE scans the full target table. The join is on pk (not the partition column). 
+ assert(txnTable(2).scanEvents.nonEmpty) + assert(txnTable(2).scanEvents.flatten.isEmpty) + + // UPDATE with WHERE dep='software' pushes an equality predicate on the partition column. + assert(txnTable(3).scanEvents.flatten.exists { + case sources.EqualTo("dep", "software") => true + case _ => false + }) + } + } + } + + test("loop with transactional checks - each iteration runs in its own transaction") { + withCatalog("cat") { catalog => + withTable("cat.ns1.t") { + val sqlScript = + """ + |BEGIN + | DECLARE i INT = 1; + | CREATE TABLE + | cat.ns1.t (pk INT NOT NULL, salary INT, dep STRING) + | PARTITIONED BY (dep); + | WHILE i <= 3 DO + | INSERT INTO cat.ns1.t VALUES (i, i * 100, 'hr'); + | SET i = i + 1; + | END WHILE; + | SELECT * FROM cat.ns1.t ORDER BY pk; + |END + |""".stripMargin + + verifySqlScriptResult( + sqlScript, + Seq(Row(1, 100, "hr"), Row(2, 200, "hr"), Row(3, 300, "hr"))) + + // Each loop iteration's INSERT runs in its own independent transaction. + assert(catalog.seenTransactions.size === 3) + assert(catalog.seenTransactions.forall(t => t.currentState === Committed && t.isClosed)) + } + } + } + + test("continue handler with transactional checks - handler DML runs in its own transaction") { + withCatalog("cat") { catalog => + withTable("cat.ns1.t") { + val sqlScript = + """ + |BEGIN + | DECLARE CONTINUE HANDLER FOR DIVIDE_BY_ZERO + | BEGIN + | INSERT INTO cat.ns1.t VALUES (-1, -1, 'error'); + | END; + | CREATE TABLE + | cat.ns1.t (pk INT NOT NULL, salary INT, dep STRING) + | PARTITIONED BY (dep); + | INSERT INTO cat.ns1.t VALUES (1, 100, 'hr'); + | SELECT 1/0; + | INSERT INTO cat.ns1.t VALUES (2, 200, 'software'); + | SELECT * FROM cat.ns1.t ORDER BY pk; + |END + |""".stripMargin + + verifySqlScriptResult( + sqlScript, + Seq(Row(-1, -1, "error"), Row(1, 100, "hr"), Row(2, 200, "software"))) + + // INSERT(1), handler INSERT(-1), INSERT(2) - each in its own transaction. 
+ assert(catalog.seenTransactions.size === 3) + assert(catalog.seenTransactions.forall(t => t.currentState === Committed && t.isClosed)) + } + } + } + test("script without result statement") { val sqlScript = """