diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TransactionalCatalogPlugin.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TransactionalCatalogPlugin.java new file mode 100644 index 0000000000000..daa3176dcbba5 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TransactionalCatalogPlugin.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.catalog.transactions.Transaction; +import org.apache.spark.sql.connector.catalog.transactions.TransactionInfo; + +/** + * A {@link CatalogPlugin} that supports transactions. + *
+ * Catalogs that implement this interface opt in to transactional query execution. A catalog + * implementing this interface is responsible for starting transactions. + * + * @since 4.2.0 + */ +@Evolving +public interface TransactionalCatalogPlugin extends CatalogPlugin { + + /** + * Begins a new transaction and returns a {@link Transaction} representing it. + * + * @param info metadata about the transaction being started. + */ + Transaction beginTransaction(TransactionInfo info); +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/Transaction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/Transaction.java new file mode 100644 index 0000000000000..77044c6202fbe --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/Transaction.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connector.catalog.transactions; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.catalog.CatalogPlugin; +import org.apache.spark.sql.connector.catalog.TransactionalCatalogPlugin; + +import java.io.Closeable; + +/** + * Represents a transaction. + *
+ * Spark begins a transaction with {@link TransactionalCatalogPlugin#beginTransaction} and + * executes read/write operations against the transaction's catalog. On success, Spark + * calls {@link #commit()}; on failure, Spark calls {@link #abort()}. In both cases Spark + * subsequently calls {@link #close()} to release resources. + * + * @since 4.2.0 + */ +@Evolving +public interface Transaction extends Closeable { + + /** + * Returns the catalog associated with this transaction. This catalog is responsible for tracking + * read/write operations that occur within the boundaries of a transaction. This allows + * connectors to perform conflict resolution at commit time. + */ + CatalogPlugin catalog(); + + /** + * Commits the transaction. All writes performed under it become visible to other readers. + *
+ * The connector is responsible for detecting and resolving conflicting commits or throwing + * an exception if resolution is not possible. + *
+ * This method will be called exactly once per transaction. Spark calls {@link #close()} + * immediately after this method returns. + * + * @throws IllegalStateException if the transaction has already been committed or aborted. + */ + void commit(); + + /** + * Aborts the transaction, discarding any staged changes. + *
+ * This method must be idempotent. If the transaction has already been committed or aborted, + * invoking it must have no effect. + *
+ * Spark calls {@link #close()} immediately after this method returns. + */ + void abort(); + + /** + * Releases any resources held by this transaction. + *
+ * Spark always calls this method after {@link #commit()} or {@link #abort()}, regardless of + * whether those methods succeed or not. + *
+ * This method must be idempotent. If the transaction has already been closed,
+ * invoking it must have no effect.
+ */
+ @Override
+ void close();
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/TransactionInfo.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/TransactionInfo.java
new file mode 100644
index 0000000000000..3e6979cec469f
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/TransactionInfo.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog.transactions;
+
+import org.apache.spark.annotation.Evolving;
+
+/**
+ * Metadata about a transaction.
+ *
+ * @since 4.2.0
+ */
+@Evolving
+public interface TransactionInfo {
+ /**
+ * Returns a unique identifier for this transaction.
+ */
+ String id();
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 3b4d725840935..09928553839ff 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -336,6 +336,30 @@ class Analyzer(
}
}
+ /**
+ * Returns a copy of this analyzer that uses the given [[CatalogManager]] for all catalog
+ * lookups. All other configuration (extended rules, checks, etc.) is preserved. Used by
+ * [[QueryExecution]] to create a per-query analyzer for transactional operations with
+ * transaction-aware catalog resolution.
+ */
+ def withCatalogManager(newCatalogManager: CatalogManager): Analyzer = {
+ val self = this
+ new Analyzer(newCatalogManager, sharedRelationCache) {
+ override val hintResolutionRules: Seq[Rule[LogicalPlan]] = self.hintResolutionRules
+ override val extendedResolutionRules: Seq[Rule[LogicalPlan]] = self.extendedResolutionRules
+ override val postHocResolutionRules: Seq[Rule[LogicalPlan]] = self.postHocResolutionRules
+ override val extendedCheckRules: Seq[LogicalPlan => Unit] = self.extendedCheckRules
+ override val singlePassResolverExtensions: Seq[ResolverExtension] =
+ self.singlePassResolverExtensions
+ override val singlePassMetadataResolverExtensions: Seq[ResolverExtension] =
+ self.singlePassMetadataResolverExtensions
+ override val singlePassPostHocResolutionRules: Seq[Rule[LogicalPlan]] =
+ self.singlePassPostHocResolutionRules
+ override val singlePassExtendedResolutionChecks: Seq[LogicalPlan => Unit] =
+ self.singlePassExtendedResolutionChecks
+ }
+ }
+
override def execute(plan: LogicalPlan): LogicalPlan = {
AnalysisContext.withNewAnalysisContext {
executeSameContext(plan)
@@ -437,7 +461,9 @@ class Analyzer(
Batch("Simple Sanity Check", Once,
LookupFunctions),
Batch("Keep Legacy Outputs", Once,
- KeepLegacyOutputs)
+ KeepLegacyOutputs),
+ Batch("Unresolve Relations", Once,
+ new UnresolveTransactionRelations(catalogManager))
)
override def batches: Seq[Batch] = earlyBatches ++ Seq(
@@ -1005,9 +1031,10 @@ class Analyzer(
}
}
- // Resolve V2TableReference nodes in a plan. V2TableReference is only created for temp views
- // (via V2TableReference.createForTempView), so we only need to resolve it when returning
- // the plan of temp views (in resolveViews and unwrapRelationPlan).
+ // Resolve V2TableReference nodes created for:
+ // 1. Temp views (via createForTempView).
+ // 2. Transaction references (via createForTransaction). These are resolved by a
+ // separate analysis batch in the transaction-aware analyzer instance.
private def resolveTableReferences(plan: LogicalPlan): LogicalPlan = {
plan.resolveOperatorsUp {
case r: V2TableReference => relationResolution.resolveReference(r)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala
index e86248febd2eb..fd0394098b0ca 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala
@@ -391,6 +391,11 @@ class RelationResolution(
}
private def getOrLoadRelation(ref: V2TableReference): LogicalPlan = {
+ // Skip cache when a transaction is active.
+ if (catalogManager.transaction.isDefined) {
+ return loadRelation(ref)
+ }
+
val key = toCacheKey(ref.catalog, ref.identifier)
relationCache.get(key) match {
case Some(cached) =>
@@ -403,9 +408,18 @@ class RelationResolution(
}
private def loadRelation(ref: V2TableReference): LogicalPlan = {
- val table = ref.catalog.loadTable(ref.identifier)
+ // Resolve the catalog. When a transaction is active we return the transaction-
+ // aware catalog instance.
+ val resolvedCatalog = catalogManager.catalog(ref.catalog.name).asTableCatalog
+ val table = resolvedCatalog.loadTable(ref.identifier)
V2TableReferenceUtils.validateLoadedTable(table, ref)
- ref.toRelation(table)
+ // Create the relation with the resolved catalog.
+ DataSourceV2Relation(
+ table = table,
+ output = ref.output,
+ catalog = Some(resolvedCatalog),
+ identifier = Some(ref.identifier),
+ options = ref.options)
}
private def adaptCachedRelation(cached: LogicalPlan, ref: V2TableReference): LogicalPlan = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolveTransactionRelations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolveTransactionRelations.scala
new file mode 100644
index 0000000000000..0e344173d7892
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolveTransactionRelations.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, TransactionalWrite}
+import org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.allowInvokingTransformsInAnalyzer
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, LookupCatalog}
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+
+/**
+ * When a transaction is active, converts resolved [[DataSourceV2Relation]] nodes back to
+ * [[V2TableReference]] placeholders for all relations loaded by a catalog with the same
+ * name as the transaction catalog.
+ *
+ * This forces re-resolution of those relations against the transaction's catalog, which
+ * intercepts [[TableCatalog#loadTable]] calls to track which tables are read as part of
+ * the transaction.
+ */
+class UnresolveTransactionRelations(val catalogManager: CatalogManager)
+ extends Rule[LogicalPlan] with LookupCatalog {
+
+ override def apply(plan: LogicalPlan): LogicalPlan =
+ catalogManager.transaction match {
+ case Some(transaction) =>
+ allowInvokingTransformsInAnalyzer {
+ plan.transform {
+ case tw: TransactionalWrite =>
+ unresolveRelations(tw, transaction.catalog)
+ }
+ }
+ case _ => plan
+ }
+
+ private def unresolveRelations(
+ plan: LogicalPlan,
+ catalog: CatalogPlugin): LogicalPlan = {
+ plan transform {
+ case r: DataSourceV2Relation if isLoadedFromCatalog(r, catalog) =>
+ V2TableReference.createForTransaction(r)
+ }
+ }
+
+ private def isLoadedFromCatalog(
+ relation: DataSourceV2Relation,
+ catalog: CatalogPlugin): Boolean = {
+ relation.catalog.exists(_.name == catalog.name) && relation.identifier.isDefined
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala
index 85c36d452b309..f459706a690bb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.analysis.V2TableReference.Context
import org.apache.spark.sql.catalyst.analysis.V2TableReference.TableInfo
import org.apache.spark.sql.catalyst.analysis.V2TableReference.TemporaryViewContext
+import org.apache.spark.sql.catalyst.analysis.V2TableReference.TransactionContext
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.LeafNode
import org.apache.spark.sql.catalyst.plans.logical.Statistics
@@ -37,7 +38,7 @@ import org.apache.spark.sql.connector.catalog.V2TableUtil
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.util.CaseInsensitiveStringMap
-import org.apache.spark.sql.util.SchemaValidationMode.ALLOW_NEW_TOP_LEVEL_FIELDS
+import org.apache.spark.sql.util.SchemaValidationMode.{ALLOW_NEW_TOP_LEVEL_FIELDS, PROHIBIT_CHANGES}
import org.apache.spark.util.ArrayImplicits._
/**
@@ -84,11 +85,19 @@ private[sql] object V2TableReference {
sealed trait Context
case class TemporaryViewContext(viewName: Seq[String]) extends Context
+ /** Context for relations that are re-resolved through a transaction catalog. */
+ case object TransactionContext extends Context
def createForTempView(relation: DataSourceV2Relation, viewName: Seq[String]): V2TableReference = {
create(relation, TemporaryViewContext(viewName))
}
+ // V2TableReference nodes in the transaction context are produced by
+ // UnresolveTransactionRelations which unresolves already resolved relations.
+ def createForTransaction(relation: DataSourceV2Relation): V2TableReference = {
+ create(relation, TransactionContext)
+ }
+
private def create(relation: DataSourceV2Relation, context: Context): V2TableReference = {
val ref = V2TableReference(
relation.catalog.get.asTableCatalog,
@@ -110,11 +119,32 @@ private[sql] object V2TableReferenceUtils extends SQLConfHelper {
ref.context match {
case ctx: TemporaryViewContext =>
validateLoadedTableInTempView(table, ref, ctx)
+ case TransactionContext =>
+ validateLoadedTableInTransaction(table, ref)
case ctx =>
throw SparkException.internalError(s"Unknown table ref context: ${ctx.getClass.getName}")
}
}
+ private def validateLoadedTableInTransaction(table: Table, ref: V2TableReference): Unit = {
+ // Do not allow schema evolution for pre-analyzed dataframes that are later used in
+ // transactional writes. This is because the entire plan was built based on the original schema
+ // and any schema change would make the plan structurally invalid. This is in line with the
+ // semantics of SPARK-54444.
+ val dataErrors = V2TableUtil.validateCapturedColumns(
+ table = table,
+ originCols = ref.info.columns,
+ mode = PROHIBIT_CHANGES)
+ if (dataErrors.nonEmpty) {
+ throw QueryCompilationErrors.columnsChangedAfterAnalysis(ref.name, dataErrors)
+ }
+
+ val metaErrors = V2TableUtil.validateCapturedMetadataColumns(table, ref.info.metadataColumns)
+ if (metaErrors.nonEmpty) {
+ throw QueryCompilationErrors.metadataColumnsChangedAfterAnalysis(ref.name, metaErrors)
+ }
+ }
+
private def validateLoadedTableInTempView(
table: Table,
ref: V2TableReference,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala
index c38377582c156..774c783ecf8a2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala
@@ -188,7 +188,11 @@ case class InsertIntoStatement(
byName: Boolean = false,
replaceCriteriaOpt: Option[InsertReplaceCriteria] = None,
withSchemaEvolution: Boolean = false)
- extends UnaryParsedStatement {
+ // Extends TransactionalWrite so that QueryExecution can detect a potential transaction on the
+ // unresolved logical plan before analysis runs. InsertIntoStatement is shared between V1 and V2
+ // inserts, but the LookupCatalog.TransactionalWrite extractor only matches when the target
+ // catalog implements TransactionalCatalogPlugin, so V1 inserts are never assigned a transaction.
+ extends UnaryParsedStatement with TransactionalWrite {
require(overwrite || !ifPartitionNotExists,
"IF NOT EXISTS is only valid in INSERT OVERWRITE")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
index 500d648d23acb..f9e7774e1448d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
@@ -156,7 +156,7 @@ case class AppendData(
isByName: Boolean,
withSchemaEvolution: Boolean,
write: Option[Write] = None,
- analyzedQuery: Option[LogicalPlan] = None) extends V2WriteCommand {
+ analyzedQuery: Option[LogicalPlan] = None) extends V2WriteCommand with TransactionalWrite {
override val writePrivileges: Set[TableWritePrivilege] = Set(TableWritePrivilege.INSERT)
override def withNewQuery(newQuery: LogicalPlan): AppendData = copy(query = newQuery)
override def withNewTable(newTable: NamedRelation): AppendData = copy(table = newTable)
@@ -204,7 +204,7 @@ case class OverwriteByExpression(
isByName: Boolean,
withSchemaEvolution: Boolean,
write: Option[Write] = None,
- analyzedQuery: Option[LogicalPlan] = None) extends V2WriteCommand {
+ analyzedQuery: Option[LogicalPlan] = None) extends V2WriteCommand with TransactionalWrite {
override val writePrivileges: Set[TableWritePrivilege] =
Set(TableWritePrivilege.INSERT, TableWritePrivilege.DELETE)
override lazy val resolved: Boolean = {
@@ -264,7 +264,7 @@ case class OverwritePartitionsDynamic(
writeOptions: Map[String, String],
isByName: Boolean,
withSchemaEvolution: Boolean,
- write: Option[Write] = None) extends V2WriteCommand {
+ write: Option[Write] = None) extends V2WriteCommand with TransactionalWrite {
override val writePrivileges: Set[TableWritePrivilege] =
Set(TableWritePrivilege.INSERT, TableWritePrivilege.DELETE)
override def withNewQuery(newQuery: LogicalPlan): OverwritePartitionsDynamic = {
@@ -520,8 +520,10 @@ case class WriteDelta(
trait V2CreateTableAsSelectPlan
extends V2CreateTablePlan
with AnalysisOnlyCommand
- with CTEInChildren {
+ with CTEInChildren
+ with TransactionalWrite {
def query: LogicalPlan
+ override def table: LogicalPlan = name
override def withCTEDefs(cteDefs: Seq[CTERelationDef]): LogicalPlan = {
withNameAndQuery(newName = name, newQuery = WithCTE(query, cteDefs))
@@ -943,7 +945,8 @@ object DescribeColumn {
*/
case class DeleteFromTable(
table: LogicalPlan,
- condition: Expression) extends UnaryCommand with SupportsSubquery {
+ condition: Expression)
+ extends UnaryCommand with TransactionalWrite with SupportsSubquery {
override def child: LogicalPlan = table
override protected def withNewChildInternal(newChild: LogicalPlan): DeleteFromTable =
copy(table = newChild)
@@ -965,7 +968,8 @@ case class DeleteFromTableWithFilters(
case class UpdateTable(
table: LogicalPlan,
assignments: Seq[Assignment],
- condition: Option[Expression]) extends UnaryCommand with SupportsSubquery {
+ condition: Option[Expression])
+ extends UnaryCommand with TransactionalWrite with SupportsSubquery {
lazy val aligned: Boolean = AssignmentUtils.aligned(table.output, assignments)
@@ -998,8 +1002,13 @@ case class MergeIntoTable(
notMatchedActions: Seq[MergeAction],
notMatchedBySourceActions: Seq[MergeAction],
withSchemaEvolution: Boolean)
- extends BinaryCommand with WriteWithSchemaEvolution with SupportsSubquery {
+ extends BinaryCommand
+ with WriteWithSchemaEvolution
+ with SupportsSubquery
+ with TransactionalWrite {
+ // Implements WriteWithSchemaEvolution.table.
+ // Implements TransactionalWrite.table.
override val table: LogicalPlan = EliminateSubqueryAliases(targetTable)
override def withNewTable(newTable: NamedRelation): MergeIntoTable = {
@@ -1270,6 +1279,16 @@ case class Assignment(key: Expression, value: Expression) extends Expression
newLeft: Expression, newRight: Expression): Assignment = copy(key = newLeft, value = newRight)
}
+/**
+ * Marker trait for write operations that participate in a DSv2 transaction.
+ *
+ * Implementations are expected to target a DSv2 catalog backed by a
+ * [[org.apache.spark.sql.connector.catalog.TransactionalCatalogPlugin]].
+ */
+trait TransactionalWrite extends LogicalPlan {
+ def table: LogicalPlan
+}
+
/**
* The logical plan of the DROP TABLE command.
*
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtils.scala
new file mode 100644
index 0000000000000..d160aafdea34e
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtils.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.transactions
+
+import java.util.UUID
+
+import org.apache.spark.sql.connector.catalog.TransactionalCatalogPlugin
+import org.apache.spark.sql.connector.catalog.transactions.{Transaction, TransactionInfoImpl}
+import org.apache.spark.util.Utils
+
+object TransactionUtils {
+ def commit(transaction: Transaction): Unit = {
+ Utils.tryWithSafeFinally {
+ transaction.commit()
+ } {
+ transaction.close()
+ }
+ }
+
+ def abort(transaction: Transaction): Unit = {
+ Utils.tryWithSafeFinally {
+ transaction.abort()
+ } {
+ transaction.close()
+ }
+ }
+
+ def beginTransaction(catalog: TransactionalCatalogPlugin): Transaction = {
+ val info = TransactionInfoImpl(id = UUID.randomUUID.toString)
+ val transaction = catalog.beginTransaction(info)
+ if (transaction.catalog.name != catalog.name) {
+ abort(transaction)
+ throw new IllegalStateException(
+ s"""Transaction catalog name (${transaction.catalog.name})
+ |must match original catalog name (${catalog.name}).
+ |""".stripMargin)
+ }
+ transaction
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala
index 8663bb65b6a88..152334fdfc600 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala
@@ -23,6 +23,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.catalog.{SessionCatalog, TempVariableManager}
import org.apache.spark.sql.catalyst.util.StringUtils
+import org.apache.spark.sql.connector.catalog.transactions.Transaction
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
@@ -39,7 +40,7 @@ import org.apache.spark.sql.internal.SQLConf
// need to track current database at all.
private[sql]
class CatalogManager(
- defaultSessionCatalog: CatalogPlugin,
+ val defaultSessionCatalog: CatalogPlugin,
val v1SessionCatalog: SessionCatalog) extends SQLConfHelper with Logging {
import CatalogManager.SESSION_CATALOG_NAME
import CatalogV2Util._
@@ -57,6 +58,11 @@ class CatalogManager(
}
}
+ def transaction: Option[Transaction] = None
+
+ def withTransaction(transaction: Transaction): CatalogManager =
+ new TransactionAwareCatalogManager(this, transaction)
+
def isCatalogRegistered(name: String): Boolean = {
try {
catalog(name)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala
index 203cfc23452a8..dd5be45bfc5f5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.connector.catalog
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedIdentifier, UnresolvedRelation}
+import org.apache.spark.sql.catalyst.plans.logical.TransactionalWrite
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
@@ -163,4 +165,17 @@ private[sql] trait LookupCatalog extends Logging {
}
}
}
+
+ object TransactionalWrite {
+ def unapply(write: TransactionalWrite): Option[TransactionalCatalogPlugin] = {
+ EliminateSubqueryAliases(write.table) match {
+ case UnresolvedRelation(CatalogAndIdentifier(c: TransactionalCatalogPlugin, _), _, _) =>
+ Some(c)
+ case UnresolvedIdentifier(CatalogAndIdentifier(c: TransactionalCatalogPlugin, _), _) =>
+ Some(c)
+ case _ =>
+ None
+ }
+ }
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/TransactionAwareCatalogManager.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/TransactionAwareCatalogManager.scala
new file mode 100644
index 0000000000000..aaeef4c2dea76
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/TransactionAwareCatalogManager.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog
+
+import org.apache.spark.sql.catalyst.catalog.TempVariableManager
+import org.apache.spark.sql.connector.catalog.transactions.Transaction
+
+/**
+ * A [[CatalogManager]] decorator that redirects catalog lookups to the transaction's catalog
+ * instance when names match, ensuring table loads during analysis are scoped to the transaction.
+ * All mutable state (current catalog, current namespace, loaded catalogs) is delegated to the
+ * wrapped [[CatalogManager]].
+ */
+// TODO: Extracting a CatalogManager trait (so this class can implement it instead of extending
+// CatalogManager) would eliminate the inherited mutable state that this decorator doesn't use.
+private[sql] class TransactionAwareCatalogManager(
+ delegate: CatalogManager,
+ txn: Transaction)
+ extends CatalogManager(delegate.defaultSessionCatalog, delegate.v1SessionCatalog) {
+
+ override val tempVariableManager: TempVariableManager = delegate.tempVariableManager
+
+ override def transaction: Option[Transaction] = Some(txn)
+
+ override def catalog(name: String): CatalogPlugin = {
+ val resolved = delegate.catalog(name)
+ if (txn.catalog.name() == resolved.name()) txn.catalog else resolved
+ }
+
+ override def currentCatalog: CatalogPlugin = {
+ val c = delegate.currentCatalog
+ if (txn.catalog.name() == c.name()) txn.catalog else c
+ }
+
+ override def currentNamespace: Array[String] = delegate.currentNamespace
+
+ override def setCurrentNamespace(namespace: Array[String]): Unit =
+ delegate.setCurrentNamespace(namespace)
+
+ override def setCurrentCatalog(catalogName: String): Unit =
+ delegate.setCurrentCatalog(catalogName)
+
+ override def listCatalogs(pattern: Option[String]): Seq[String] =
+ delegate.listCatalogs(pattern)
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/transactions/TransactionInfoImpl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/transactions/TransactionInfoImpl.scala
new file mode 100644
index 0000000000000..4cb53da0a59e2
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/transactions/TransactionInfoImpl.scala
@@ -0,0 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog.transactions
+
+case class TransactionInfoImpl(id: String) extends TransactionInfo
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtilsSuite.scala
new file mode 100644
index 0000000000000..d409316e667b1
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtilsSuite.scala
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.transactions
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.connector.catalog.{CatalogPlugin, TransactionalCatalogPlugin}
+import org.apache.spark.sql.connector.catalog.transactions.{Transaction, TransactionInfo}
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+class TransactionUtilsSuite extends SparkFunSuite {
+ val testCatalogName = "test_catalog"
+
+ // --- Helpers ---------------------------------------------------------------
+ private def mockCatalog(catalogName: String): CatalogPlugin = new CatalogPlugin {
+ override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = ()
+ override def name(): String = catalogName
+ }
+
+ private val emptyFunction = () => ()
+ private class TestTransaction(
+ catalogName: String,
+ onCommit: () => Unit = emptyFunction,
+ onAbort: () => Unit = emptyFunction,
+ onClose: () => Unit = emptyFunction) extends Transaction {
+ var committed = false
+ var aborted = false
+ var closed = false
+
+ override def catalog(): CatalogPlugin = mockCatalog(catalogName)
+ override def commit(): Unit = { committed = true; onCommit() }
+ override def abort(): Unit = { aborted = true; onAbort() }
+ override def close(): Unit = { closed = true; onClose() }
+ }
+
+ private def mockTransactionalCatalog(
+ catalogName: String,
+ txnCatalogName: String = null): TransactionalCatalogPlugin = {
+ val resolvedTxnCatalogName = Option(txnCatalogName).getOrElse(catalogName)
+ new TransactionalCatalogPlugin {
+ override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = ()
+ override def name(): String = catalogName
+ override def beginTransaction(info: TransactionInfo): Transaction =
+ new TestTransaction(resolvedTxnCatalogName)
+ }
+ }
+
+ // --- Commit ----------------------------------------------------------------
+ test("commit: calls commit then close") {
+ val txn = new TestTransaction(testCatalogName)
+ TransactionUtils.commit(txn)
+ assert(txn.committed)
+ assert(txn.closed)
+ }
+
+ test("commit: close is called even if commit fails") {
+ val txn = new TestTransaction(
+ testCatalogName, onCommit = () => throw new RuntimeException("commit failed"))
+ intercept[RuntimeException] { TransactionUtils.commit(txn) }
+ assert(txn.closed)
+ }
+
+ // --- Abort -----------------------------------------------------------------
+ test("abort: calls abort then close") {
+ val txn = new TestTransaction(testCatalogName)
+ TransactionUtils.abort(txn)
+ assert(txn.aborted)
+ assert(txn.closed)
+ }
+
+ test("abort: close is called even if abort fails") {
+ val txn = new TestTransaction(testCatalogName,
+ onAbort = () => throw new RuntimeException("abort failed"))
+ intercept[RuntimeException] { TransactionUtils.abort(txn) }
+ assert(txn.closed)
+ }
+
+ // --- Begin Transaction -----------------------------------------------------
+ test("beginTransaction: returns transaction when catalog names match") {
+ val catalog = mockTransactionalCatalog(testCatalogName)
+ val txn = TransactionUtils.beginTransaction(catalog)
+ assert(txn.catalog().name() == testCatalogName)
+ }
+
+ test("beginTransaction: fails when transaction catalog name does not match") {
+ val catalog = mockTransactionalCatalog(catalogName = testCatalogName, txnCatalogName = "other")
+ val e = intercept[IllegalStateException] {
+ TransactionUtils.beginTransaction(catalog)
+ }
+ assert(e.getMessage.contains("other"))
+ assert(e.getMessage.contains(testCatalogName))
+ }
+
+ test("beginTransaction: aborts and closes transaction on catalog name mismatch") {
+ var aborted = false
+ var closed = false
+ val catalog = new TransactionalCatalogPlugin {
+ override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = ()
+ override def name(): String = testCatalogName
+ override def beginTransaction(info: TransactionInfo): Transaction =
+ new TestTransaction(
+ "other",
+ onAbort = () => { aborted = true },
+ onClose = () => { closed = true })
+ }
+ intercept[IllegalStateException] { TransactionUtils.beginTransaction(catalog) }
+ assert(aborted)
+ assert(closed)
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
index 8791677999810..af46ac006f3c9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
@@ -76,6 +76,10 @@ abstract class InMemoryBaseTable(
override def columns(): Array[Column] = tableColumns
+ private[catalog] def updateColumns(newColumns: Array[Column]): Unit = {
+ tableColumns = newColumns
+ }
+
override def version(): String = tableVersion.toString
def setVersion(version: String): Unit = {
@@ -94,6 +98,8 @@ abstract class InMemoryBaseTable(
validatedTableVersion = version
}
+ protected def recordScanEvent(filters: Array[Filter]): Unit = {}
+
protected object PartitionKeyColumn extends MetadataColumn {
override def name: String = "_partition"
override def dataType: DataType = StringType
@@ -455,6 +461,7 @@ abstract class InMemoryBaseTable(
if (evaluableFilters.nonEmpty) {
scan.filter(evaluableFilters)
}
+ recordScanEvent(_pushedFilters)
scan
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala
index bbb9041bab37c..4e5e1e7c8c6e9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala
@@ -17,11 +17,31 @@
package org.apache.spark.sql.connector.catalog
+import scala.collection.mutable.ArrayBuffer
+
import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
+import org.apache.spark.sql.connector.catalog.transactions.{Transaction, TransactionInfo}
+import org.apache.spark.sql.types.StructType
-class InMemoryRowLevelOperationTableCatalog extends InMemoryTableCatalog {
+class InMemoryRowLevelOperationTableCatalog
+ extends InMemoryTableCatalog
+ with TransactionalCatalogPlugin {
import CatalogV2Implicits._
+ // The current active transaction.
+ var transaction: Txn = _
+ // The last completed transaction.
+ var lastTransaction: Txn = _
+ // All transactions in order (committed and aborted), allowing per-statement
+ // validation in SQL scripting tests.
+ val seenTransactions: ArrayBuffer[Txn] = new ArrayBuffer[Txn]()
+
+ override def beginTransaction(info: TransactionInfo): Transaction = {
+ assert(transaction == null || transaction.currentState != Active)
+ this.transaction = new Txn(new TxnTableCatalog(this))
+ transaction
+ }
+
override def createTable(ident: Identifier, tableInfo: TableInfo): Table = {
if (tables.containsKey(ident)) {
throw new TableAlreadyExistsException(ident.asMultipartIdentifier)
@@ -41,11 +61,7 @@ class InMemoryRowLevelOperationTableCatalog extends InMemoryTableCatalog {
override def alterTable(ident: Identifier, changes: TableChange*): Table = {
val table = loadTable(ident).asInstanceOf[InMemoryRowLevelOperationTable]
val properties = CatalogV2Util.applyPropertiesChanges(table.properties, changes)
- val schema = CatalogV2Util.applySchemaChanges(
- table.schema,
- changes,
- tableProvider = Some("in-memory"),
- statementType = "ALTER TABLE")
+ val schema = computeAlterTableSchema(table.schema, changes.toSeq)
val partitioning = CatalogV2Util.applyClusterByChanges(table.partitioning, schema, changes)
val constraints = CatalogV2Util.collectConstraintChanges(table, changes)
@@ -66,6 +82,16 @@ class InMemoryRowLevelOperationTableCatalog extends InMemoryTableCatalog {
newTable
}
+
+ /**
+ * Computes the schema that would result from applying `changes` to `currentSchema`.
+ * Can be overridden by subclasses to simulate catalogs that selectively ignore changes
+ * (e.g. [[PartialSchemaEvolutionCatalog]]).
+ */
+ def computeAlterTableSchema(currentSchema: StructType, changes: Seq[TableChange]): StructType = {
+ CatalogV2Util.applySchemaChanges(
+ currentSchema, changes, tableProvider = Some("in-memory"), statementType = "ALTER TABLE")
+ }
}
/**
@@ -84,9 +110,10 @@ class PartialSchemaEvolutionCatalog extends InMemoryRowLevelOperationTableCatalo
case _ => false
}
val properties = CatalogV2Util.applyPropertiesChanges(table.properties, propertyChanges)
+ val schema = computeAlterTableSchema(table.schema, changes.toSeq)
val newTable = new InMemoryRowLevelOperationTable(
name = table.name,
- schema = table.schema,
+ schema = schema,
partitioning = table.partitioning,
properties = properties,
constraints = table.constraints)
@@ -94,4 +121,9 @@ class PartialSchemaEvolutionCatalog extends InMemoryRowLevelOperationTableCatalo
tables.put(ident, newTable)
newTable
}
+
+ // Ignores all schema changes and returns the current schema unchanged.
+ override def computeAlterTableSchema(
+ currentSchema: StructType,
+ changes: Seq[TableChange]): StructType = currentSchema
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
index d5738475031dc..15ed4136dbda8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.connector.write.{LogicalWriteInfo, SupportsOverwrite
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{LongType, StructField, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
+import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.ArrayImplicits._
/**
@@ -215,6 +216,15 @@ class InMemoryTable(
object InMemoryTable {
+ // Convert UTF8String to string to make sure equality checks between filters and partitions
+ // work correctly.
+ private def valuesEqual(filterValue: Any, partitionValue: Any): Boolean =
+ (filterValue, partitionValue) match {
+ case (s: String, u: UTF8String) => u.toString == s
+ case (u: UTF8String, s: String) => u.toString == s
+ case _ => filterValue == partitionValue
+ }
+
def filtersToKeys(
keys: Iterable[Seq[Any]],
partitionNames: Seq[String],
@@ -222,7 +232,7 @@ object InMemoryTable {
keys.filter { partValues =>
filters.flatMap(splitAnd).forall {
case EqualTo(attr, value) =>
- value == InMemoryBaseTable.extractValue(attr, partitionNames, partValues)
+ valuesEqual(value, InMemoryBaseTable.extractValue(attr, partitionNames, partValues))
case EqualNullSafe(attr, value) =>
val attrVal = InMemoryBaseTable.extractValue(attr, partitionNames, partValues)
if (attrVal == null && value == null) {
@@ -230,7 +240,7 @@ object InMemoryTable {
} else if (attrVal == null || value == null) {
false
} else {
- value == attrVal
+ valuesEqual(value, attrVal)
}
case IsNull(attr) =>
null == InMemoryBaseTable.extractValue(attr, partitionNames, partValues)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala
new file mode 100644
index 0000000000000..232b623174996
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala
@@ -0,0 +1,195 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog
+
+import java.util
+import java.util.concurrent.ConcurrentHashMap
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.sql.connector.catalog.transactions.Transaction
+import org.apache.spark.sql.sources.Filter
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+sealed trait TransactionState
+case object Active extends TransactionState
+case object Committed extends TransactionState
+case object Aborted extends TransactionState
+
+class Txn(override val catalog: TxnTableCatalog) extends Transaction {
+
+ private[this] var state: TransactionState = Active
+ private[this] var closed: Boolean = false
+
+ def currentState: TransactionState = state
+
+ def isClosed: Boolean = closed
+
+ override def commit(): Unit = {
+ if (closed) throw new IllegalStateException("Can't commit, already closed")
+ if (state == Aborted) throw new IllegalStateException("Can't commit, already aborted")
+ catalog.commit()
+ this.state = Committed
+ }
+
+ // This is idempotent since nested QEs can cause multiple aborts.
+ override def abort(): Unit = {
+ if (state == Committed || state == Aborted) return
+ this.state = Aborted
+ }
+
+  // This is idempotent since nested QEs can cause multiple close() calls.
+ override def close(): Unit = {
+ if (!closed) {
+ catalog.clearActiveTransaction()
+ this.closed = true
+ }
+ }
+}
+
+// A special table used in row-level operation transactions. It inherits data
+// from the base table upon construction and propagates staged transaction state
+// back after an explicit commit.
+// Note, the in-memory data store does not handle concurrency at the moment. It assumes that the
+// underlying delegate table cannot change from concurrent transactions. Data sources need to
+// implement isolation semantics and make sure they are enforced.
+class TxnTable(val delegate: InMemoryRowLevelOperationTable, schema: StructType)
+ extends InMemoryRowLevelOperationTable(
+ delegate.name,
+ schema,
+ delegate.partitioning,
+ delegate.properties,
+ delegate.constraints) {
+
+ alterTableWithData(delegate.data, schema)
+
+ // Keep initial version to detect any changes during the transaction.
+ private val initialVersion: String = version()
+
+ // A tracker of filters used in each scan.
+ val scanEvents = new ArrayBuffer[Array[Filter]]()
+
+ // Record scan events. This is invoked when building a scan for the particular table.
+ override protected def recordScanEvent(filters: Array[Filter]): Unit = {
+ scanEvents += filters
+ }
+
+  // Perform commit if there are any changes. This pushes metadata and data changes to the
+ // delegate table.
+ def commit(): Unit = {
+ if (version() != initialVersion) {
+ delegate.dataMap.clear()
+ delegate.updateColumns(columns()) // Evolve schema if needed.
+ delegate.alterTableWithData(data, schema)
+ delegate.replacedPartitions = replacedPartitions
+ delegate.lastWriteInfo = lastWriteInfo
+ delegate.lastWriteLog = lastWriteLog
+ delegate.commits ++= commits
+ delegate.increaseVersion()
+ }
+ }
+}
+
+// A special table catalog used in row-level operation transactions. The lifecycle of this catalog
+// is tied to the transaction. A new catalog instance is created at the beginning of a transaction
+// and discarded at the end. The catalog is responsible for pinning all tables involved in the
+// transaction. Table changes are initially staged in memory and propagated only after an explicit
+// commit.
+class TxnTableCatalog(delegate: InMemoryRowLevelOperationTableCatalog) extends TableCatalog {
+
+ private val tables: util.Map[Identifier, TxnTable] = new ConcurrentHashMap[Identifier, TxnTable]()
+
+ override def name: String = delegate.name
+
+ override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = {}
+
+ override def listTables(namespace: Array[String]): Array[Identifier] = {
+ throw new UnsupportedOperationException()
+ }
+
+  // This is where the table pinning logic should occur. In this implementation, a table is loaded
+  // (pinned) the first time it is accessed. All subsequent accesses should return the same pinned
+ // table.
+ override def loadTable(ident: Identifier): Table = {
+ tables.computeIfAbsent(ident, _ => {
+ val table = delegate.loadTable(ident).asInstanceOf[InMemoryRowLevelOperationTable]
+ new TxnTable(table, table.schema())
+ })
+ }
+
+ override def alterTable(ident: Identifier, changes: TableChange*): Table = {
+ // AlterTable may be called by ResolveSchemaEvolution when schema evolution is enabled. Thus,
+ // it needs to be transactional. The schema changes are only propagated to the delegate at
+ // commit time.
+ //
+ // We delegate schema computation to the underlying catalog so that catalogs with special
+ // handling (e.g. PartialSchemaEvolutionCatalog) have the same behaviour inside a
+ // transaction.
+ val txnTable = tables.get(ident)
+ val schema = delegate.computeAlterTableSchema(txnTable.schema, changes.toSeq)
+
+ if (schema.fields.isEmpty) {
+ throw new IllegalArgumentException(s"Cannot drop all fields")
+ }
+
+ // TODO: We need to pass all tracked predicates to the new TXN table.
+ val newTxnTable = new TxnTable(txnTable.delegate, schema)
+ tables.put(ident, newTxnTable)
+ newTxnTable
+ }
+
+ // TODO: Currently not transactional. Should be revised when Atomic CTAS/RTAS is implemented.
+ override def createTable(ident: Identifier, tableInfo: TableInfo): Table = {
+ delegate.createTable(ident, tableInfo)
+ loadTable(ident)
+ }
+
+ // TODO: Currently not transactional. Should be revised when Atomic CTAS/RTAS is implemented.
+ override def dropTable(ident: Identifier): Boolean = {
+ tables.remove(ident)
+ delegate.dropTable(ident)
+ }
+
+ override def renameTable(oldIdent: Identifier, newIdent: Identifier): Unit = {
+ throw new UnsupportedOperationException()
+ }
+
+  // Invoke commit for all tables that participated in the transaction. If a table is read-only
+ // this is a no-op.
+ def commit(): Unit = {
+ tables.values.forEach(table => table.commit())
+ }
+
+ // Clear transaction context.
+ def clearActiveTransaction(): Unit = {
+ val txn = delegate.transaction
+ delegate.lastTransaction = txn
+ delegate.seenTransactions += txn
+ delegate.transaction = null
+ }
+
+ override def equals(obj: Any): Boolean = {
+ obj match {
+ case that: CatalogPlugin => this.name == that.name
+ case _ => false
+ }
+ }
+
+ override def hashCode(): Int = name.hashCode()
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index f08b561d6ef9a..e07d0147205a6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -27,23 +27,25 @@ import scala.util.control.NonFatal
import org.apache.hadoop.fs.Path
import org.apache.spark.SparkException
-import org.apache.spark.internal.Logging
import org.apache.spark.internal.LogKeys.EXTENDED_EXPLAIN_GENERATOR
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{AnalysisException, ExtendedExplainGenerator, Row}
import org.apache.spark.sql.catalyst.{InternalRow, QueryPlanningTracker}
-import org.apache.spark.sql.catalyst.analysis.{LazyExpression, NameParameterizedQuery, UnsupportedOperationChecker}
+import org.apache.spark.sql.catalyst.analysis.{Analyzer, LazyExpression, NameParameterizedQuery, UnsupportedOperationChecker}
import org.apache.spark.sql.catalyst.expressions.codegen.ByteCodeStats
import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.catalyst.plans.logical.{AppendData, Command, CommandResult, CompoundBody, CreateTableAsSelect, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect, ReturnAnswer, Union, WithCTE}
+import org.apache.spark.sql.catalyst.plans.logical.{AppendData, Command, CommandResult, CompoundBody, CreateTableAsSelect, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect, ReturnAnswer, Union, UnresolvedWith, WithCTE}
import org.apache.spark.sql.catalyst.rules.{PlanChangeLogger, Rule}
+import org.apache.spark.sql.catalyst.transactions.TransactionUtils
import org.apache.spark.sql.catalyst.util.StringUtils.PlanStringConcat
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.classic.SparkSession
+import org.apache.spark.sql.connector.catalog.LookupCatalog
+import org.apache.spark.sql.connector.catalog.transactions.Transaction
import org.apache.spark.sql.execution.SQLExecution.EXECUTION_ROOT_ID_KEY
import org.apache.spark.sql.execution.adaptive.{AdaptiveExecutionContext, InsertAdaptiveSparkPlan}
import org.apache.spark.sql.execution.bucketing.{CoalesceBucketsInJoin, DisableUnnecessaryBucketedScan}
-import org.apache.spark.sql.execution.datasources.v2.V2TableRefreshUtil
+import org.apache.spark.sql.execution.datasources.v2.{TransactionalExec, V2TableRefreshUtil}
import org.apache.spark.sql.execution.dynamicpruning.PlanDynamicPruningFilters
import org.apache.spark.sql.execution.exchange.EnsureRequirements
import org.apache.spark.sql.execution.reuse.ReuseExchangeAndSubquery
@@ -69,7 +71,8 @@ class QueryExecution(
val mode: CommandExecutionMode.Value = CommandExecutionMode.ALL,
val shuffleCleanupModeOpt: Option[ShuffleCleanupMode] = None,
val refreshPhaseEnabled: Boolean = true,
- val queryId: UUID = UUIDv7Generator.generate()) extends Logging {
+ val queryId: UUID = UUIDv7Generator.generate(),
+ val analyzerOpt: Option[Analyzer] = None) extends LookupCatalog {
val id: Long = QueryExecution.nextExecutionId
@@ -79,6 +82,8 @@ class QueryExecution(
// TODO: Move the planner an optimizer into here from SessionState.
protected def planner = sparkSession.sessionState.planner
+ protected val catalogManager = sparkSession.sessionState.catalogManager
+
/**
* Check whether the query represented by this QueryExecution is a SQL script.
* @return True if the query is a SQL script, False otherwise.
@@ -90,6 +95,46 @@ class QueryExecution(
logical.exists(_.expressions.exists(_.exists(_.isInstanceOf[LazyExpression])))
}
+
+ // 1. At the pre-Analyzed plan we look for nodes that implement the TransactionalWrite trait.
+ // When a plan contains such a node we initiate a transaction. Note, we should never start
+ // a transaction for operations that are not executed, e.g. EXPLAIN.
+ // 2. Create an analyzer clone with a transaction aware Catalog Manager. The latter is the
+ // narrow waist of all catalog accesses, and it is also the transaction context carrier.
+ // This is then passed to all rules during analysis that need to check the catalog. Rules
+ // that are specifically interested in transactionality can access the transaction directly
+  // from the Catalog Manager. The transaction catalog is potentially the place where connectors
+ // should keep state about the reads (tables+predicates) that occurred during the transaction.
+ // 3. The analyzer instance is passed to nested Query Execution instances. These need to respect
+ // the open transaction instead of creating their own.
+ private lazy val transactionOpt: Option[Transaction] =
+ // Always inherit an active transaction from the outer analyzer, regardless of mode.
+ analyzerOpt.flatMap(_.catalogManager.transaction).orElse {
+ // Only begin a new transaction for outer QEs that lead to execution.
+ if (mode != CommandExecutionMode.SKIP) {
+ val catalog = logical match {
+ case UnresolvedWith(TransactionalWrite(c), _, _) => Some(c)
+ case TransactionalWrite(c) => Some(c)
+ case _ => None
+ }
+ catalog.map(TransactionUtils.beginTransaction)
+ } else {
+ None
+ }
+ }
+
+ // Per-query analyzer: uses a transaction-aware CatalogManager when a transaction is active,
+ // so that all catalog lookups and rule applications during analysis see the correct state
+ // without relying on thread-local context.
+ private lazy val analyzer: Analyzer = analyzerOpt.getOrElse {
+ transactionOpt match {
+ case Some(txn) =>
+ sparkSession.sessionState.analyzer.withCatalogManager(catalogManager.withTransaction(txn))
+ case None =>
+ sparkSession.sessionState.analyzer
+ }
+ }
+
def assertAnalyzed(): Unit = {
try {
analyzed
@@ -102,7 +147,7 @@ class QueryExecution(
}
}
- def assertSupported(): Unit = {
+ def assertSupported(): Unit = withAbortTransactionOnFailure {
if (sparkSession.sessionState.conf.isUnsupportedOperationCheckEnabled) {
UnsupportedOperationChecker.checkForBatch(analyzed)
}
@@ -141,7 +186,7 @@ class QueryExecution(
try {
val plan = executePhase(QueryPlanningTracker.ANALYSIS) {
// We can't clone `logical` here, which will reset the `_analyzed` flag.
- sparkSession.sessionState.analyzer.executeAndCheck(sqlScriptExecuted, tracker)
+ analyzer.executeAndCheck(sqlScriptExecuted, tracker)
}
tracker.setAnalyzed(plan)
plan
@@ -152,7 +197,9 @@ class QueryExecution(
}
}
- def analyzed: LogicalPlan = lazyAnalyzed.get
+ def analyzed: LogicalPlan = withAbortTransactionOnFailure {
+ lazyAnalyzed.get
+ }
private val lazyCommandExecuted = LazyTry {
mode match {
@@ -162,7 +209,9 @@ class QueryExecution(
}
}
- def commandExecuted: LogicalPlan = lazyCommandExecuted.get
+ def commandExecuted: LogicalPlan = withAbortTransactionOnFailure {
+ lazyCommandExecuted.get
+ }
private def commandExecutionName(command: Command): String = command match {
case _: CreateTableAsSelect => "create"
@@ -184,7 +233,7 @@ class QueryExecution(
// for eagerly executed commands we mark this place as beginning of execution.
tracker.setReadyForExecution()
val (qe, result) = QueryExecution.runCommand(
- sparkSession, p, name, refreshPhaseEnabled, mode, Some(shuffleCleanupMode))
+ sparkSession, p, name, refreshPhaseEnabled, mode, Some(shuffleCleanupMode), Some(analyzer))
CommandResult(
qe.analyzed.output,
qe.commandExecuted,
@@ -222,19 +271,31 @@ class QueryExecution(
}
// The plan that has been normalized by custom rules, so that it's more likely to hit cache.
- def normalized: LogicalPlan = lazyNormalized.get
+ def normalized: LogicalPlan = withAbortTransactionOnFailure {
+ lazyNormalized.get
+ }
private val lazyWithCachedData = LazyTry {
sparkSession.withActive {
assertAnalyzed()
assertSupported()
- // clone the plan to avoid sharing the plan instance between different stages like analyzing,
- // optimizing and planning.
- sparkSession.sharedState.cacheManager.useCachedData(normalized.clone())
+
+ // During a transaction, skip cache substitution. This is to avoid replacing relations
+ // loaded by the transactional catalog with potentially stale relations cached before
+ // the transaction was active.
+ if (transactionOpt.isDefined) {
+ normalized
+ } else {
+ // Clone the plan to avoid sharing the plan instance between different stages like
+ // analyzing, optimizing and planning.
+ sparkSession.sharedState.cacheManager.useCachedData(normalized.clone())
+ }
}
}
- def withCachedData: LogicalPlan = lazyWithCachedData.get
+ def withCachedData: LogicalPlan = withAbortTransactionOnFailure {
+ lazyWithCachedData.get
+ }
def assertCommandExecuted(): Unit = commandExecuted
@@ -256,7 +317,9 @@ class QueryExecution(
}
}
- def optimizedPlan: LogicalPlan = lazyOptimizedPlan.get
+ def optimizedPlan: LogicalPlan = withAbortTransactionOnFailure {
+ lazyOptimizedPlan.get
+ }
def assertOptimized(): Unit = optimizedPlan
@@ -264,14 +327,17 @@ class QueryExecution(
// We need to materialize the optimizedPlan here because sparkPlan is also tracked under
// the planning phase
assertOptimized()
- executePhase(QueryPlanningTracker.PLANNING) {
+ val plan = executePhase(QueryPlanningTracker.PLANNING) {
// Clone the logical plan here, in case the planner rules change the states of the logical
// plan.
QueryExecution.createSparkPlan(planner, optimizedPlan.clone())
}
+ attachTransaction(plan)
}
- def sparkPlan: SparkPlan = lazySparkPlan.get
+ def sparkPlan: SparkPlan = withAbortTransactionOnFailure {
+ lazySparkPlan.get
+ }
def assertSparkPlanPrepared(): Unit = sparkPlan
@@ -292,7 +358,9 @@ class QueryExecution(
// executedPlan should not be used to initialize any SparkPlan. It should be
// only used for execution.
- def executedPlan: SparkPlan = lazyExecutedPlan.get
+ def executedPlan: SparkPlan = withAbortTransactionOnFailure {
+ lazyExecutedPlan.get
+ }
def assertExecutedPlanPrepared(): Unit = executedPlan
@@ -310,7 +378,9 @@ class QueryExecution(
* Given QueryExecution is not a public class, end users are discouraged to use this: please
* use `Dataset.rdd` instead where conversion will be applied.
*/
- def toRdd: RDD[InternalRow] = lazyToRdd.get
+ def toRdd: RDD[InternalRow] = withAbortTransactionOnFailure {
+ lazyToRdd.get
+ }
private val observedMetricsLock = new Object
@@ -512,6 +582,26 @@ class QueryExecution(
}
}
+ /**
+ * Runs the given block, aborting the active transaction if an exception is thrown.
+ * If no transaction is active, the block is executed as-is.
+ */
+ private def withAbortTransactionOnFailure[T](block: => T): T = transactionOpt match {
+ case Some(transaction) =>
+ try block
+ catch { case e: Throwable => TransactionUtils.abort(transaction); throw e }
+ case None => block
+ }
+
+
+  /** Attaches the active transaction to the transactional execution nodes in the given SparkPlan. */
+ private def attachTransaction(plan: SparkPlan): SparkPlan = transactionOpt match {
+ case Some(txn) => plan.transformDown {
+ case w: TransactionalExec => w.withTransaction(Some(txn))
+ }
+ case None => plan
+ }
+
/** A special namespace for commands that can be used to debug query execution. */
// scalastyle:off
object debug {
@@ -796,14 +886,16 @@ object QueryExecution {
name: String,
refreshPhaseEnabled: Boolean = true,
mode: CommandExecutionMode.Value = CommandExecutionMode.SKIP,
- shuffleCleanupModeOpt: Option[ShuffleCleanupMode] = None)
+ shuffleCleanupModeOpt: Option[ShuffleCleanupMode] = None,
+ analyzerOpt: Option[Analyzer] = None)
: (QueryExecution, Array[InternalRow]) = {
val qe = new QueryExecution(
sparkSession,
command,
mode = mode,
shuffleCleanupModeOpt = shuffleCleanupModeOpt,
- refreshPhaseEnabled = refreshPhaseEnabled)
+ refreshPhaseEnabled = refreshPhaseEnabled,
+ analyzerOpt = analyzerOpt)
val result = QueryExecution.withInternalError(s"Executed $name failed.") {
SQLExecution.withNewExecutionId(qe, Some(name)) {
qe.executedPlan.executeCollect()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DeleteFromTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DeleteFromTableExec.scala
index 8d5ee6038e80f..c6b1bae89b156 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DeleteFromTableExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DeleteFromTableExec.scala
@@ -19,16 +19,23 @@ package org.apache.spark.sql.execution.datasources.v2
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.transactions.TransactionUtils
import org.apache.spark.sql.connector.catalog.SupportsDeleteV2
+import org.apache.spark.sql.connector.catalog.transactions.Transaction
import org.apache.spark.sql.connector.expressions.filter.Predicate
case class DeleteFromTableExec(
table: SupportsDeleteV2,
condition: Array[Predicate],
- refreshCache: () => Unit) extends LeafV2CommandExec {
+ refreshCache: () => Unit,
+ transaction: Option[Transaction] = None) extends LeafV2CommandExec with TransactionalExec {
+
+ override def withTransaction(txn: Option[Transaction]): DeleteFromTableExec =
+ copy(transaction = txn)
override protected def run(): Seq[InternalRow] = {
table.deleteWhere(condition)
+ transaction.foreach(TransactionUtils.commit)
refreshCache()
Seq.empty
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
index 804e694f92e9c..07876887788be 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
@@ -26,9 +26,11 @@ import org.apache.spark.sql.catalyst.{InternalRow, ProjectingInternalRow}
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal}
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, TableSpec, UnaryNode}
+import org.apache.spark.sql.catalyst.transactions.TransactionUtils
import org.apache.spark.sql.catalyst.util.{removeInternalMetadata, CharVarcharUtils, ReplaceDataProjections, WriteDeltaProjections}
import org.apache.spark.sql.catalyst.util.RowDeltaUtils.{DELETE_OPERATION, INSERT_OPERATION, REINSERT_OPERATION, UPDATE_OPERATION, WRITE_OPERATION, WRITE_WITH_METADATA_OPERATION}
import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, StagedTable, StagingTableCatalog, Table, TableCatalog, TableInfo, TableWritePrivilege}
+import org.apache.spark.sql.connector.catalog.transactions.Transaction
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.connector.metric.CustomMetric
import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, DeltaWrite, DeltaWriter, MergeSummaryImpl, PhysicalWriteInfoImpl, Write, WriterCommitMessage, WriteSummary}
@@ -73,7 +75,12 @@ case class CreateTableAsSelectExec(
query: LogicalPlan,
tableSpec: TableSpec,
writeOptions: Map[String, String],
- ifNotExists: Boolean) extends V2CreateTableAsSelectBaseExec {
+ ifNotExists: Boolean,
+ transaction: Option[Transaction] = None)
+ extends V2CreateTableAsSelectBaseExec with TransactionalExec {
+
+ override def withTransaction(txn: Option[Transaction]): CreateTableAsSelectExec =
+ copy(transaction = txn)
val properties = CatalogV2Util.convertTableProperties(tableSpec)
@@ -91,7 +98,9 @@ case class CreateTableAsSelectExec(
.build()
val table = Option(catalog.createTable(ident, tableInfo))
.getOrElse(catalog.loadTable(ident, Set(TableWritePrivilege.INSERT).asJava))
- writeToTable(catalog, table, writeOptions, ident, query, overwrite = false)
+ val result = writeToTable(catalog, table, writeOptions, ident, query, overwrite = false)
+ transaction.foreach(TransactionUtils.commit)
+ result
}
}
@@ -111,7 +120,8 @@ case class AtomicCreateTableAsSelectExec(
query: LogicalPlan,
tableSpec: TableSpec,
writeOptions: Map[String, String],
- ifNotExists: Boolean) extends V2CreateTableAsSelectBaseExec {
+ ifNotExists: Boolean)
+ extends V2CreateTableAsSelectBaseExec {
val properties = CatalogV2Util.convertTableProperties(tableSpec)
@@ -154,8 +164,12 @@ case class ReplaceTableAsSelectExec(
tableSpec: TableSpec,
writeOptions: Map[String, String],
orCreate: Boolean,
- invalidateCache: (TableCatalog, Identifier) => Unit)
- extends V2CreateTableAsSelectBaseExec {
+ invalidateCache: (TableCatalog, Identifier) => Unit,
+ transaction: Option[Transaction] = None)
+ extends V2CreateTableAsSelectBaseExec with TransactionalExec {
+
+ override def withTransaction(txn: Option[Transaction]): ReplaceTableAsSelectExec =
+ copy(transaction = txn)
val properties = CatalogV2Util.convertTableProperties(tableSpec)
@@ -191,9 +205,11 @@ case class ReplaceTableAsSelectExec(
.build()
val table = Option(catalog.createTable(ident, tableInfo))
.getOrElse(catalog.loadTable(ident, Set(TableWritePrivilege.INSERT).asJava))
- writeToTable(
+ val result = writeToTable(
catalog, table, writeOptions, ident, refreshedQuery,
overwrite = true, refreshPhaseEnabled = false)
+ transaction.foreach(TransactionUtils.commit)
+ result
}
}
@@ -271,7 +287,9 @@ case class AtomicReplaceTableAsSelectExec(
case class AppendDataExec(
query: SparkPlan,
refreshCache: () => Unit,
- write: Write) extends V2ExistingTableWriteExec {
+ write: Write,
+ transaction: Option[Transaction] = None) extends V2ExistingTableWriteExec {
+ override def withTransaction(txn: Option[Transaction]): AppendDataExec = copy(transaction = txn)
override protected def withNewChildInternal(newChild: SparkPlan): AppendDataExec =
copy(query = newChild)
}
@@ -289,7 +307,10 @@ case class AppendDataExec(
case class OverwriteByExpressionExec(
query: SparkPlan,
refreshCache: () => Unit,
- write: Write) extends V2ExistingTableWriteExec {
+ write: Write,
+ transaction: Option[Transaction] = None) extends V2ExistingTableWriteExec {
+ override def withTransaction(txn: Option[Transaction]): OverwriteByExpressionExec =
+ copy(transaction = txn)
override protected def withNewChildInternal(newChild: SparkPlan): OverwriteByExpressionExec =
copy(query = newChild)
}
@@ -306,7 +327,10 @@ case class OverwriteByExpressionExec(
case class OverwritePartitionsDynamicExec(
query: SparkPlan,
refreshCache: () => Unit,
- write: Write) extends V2ExistingTableWriteExec {
+ write: Write,
+ transaction: Option[Transaction] = None) extends V2ExistingTableWriteExec {
+ override def withTransaction(txn: Option[Transaction]): OverwritePartitionsDynamicExec =
+ copy(transaction = txn)
override protected def withNewChildInternal(newChild: SparkPlan): OverwritePartitionsDynamicExec =
copy(query = newChild)
}
@@ -318,7 +342,8 @@ case class ReplaceDataExec(
query: SparkPlan,
refreshCache: () => Unit,
projections: ReplaceDataProjections,
- write: Write) extends V2ExistingTableWriteExec {
+ write: Write,
+ transaction: Option[Transaction] = None) extends V2ExistingTableWriteExec {
override def writingTask: WritingSparkTask[_] = {
projections match {
@@ -329,6 +354,7 @@ case class ReplaceDataExec(
}
}
+ override def withTransaction(txn: Option[Transaction]): ReplaceDataExec = copy(transaction = txn)
override protected def withNewChildInternal(newChild: SparkPlan): ReplaceDataExec = {
copy(query = newChild)
}
@@ -341,7 +367,8 @@ case class WriteDeltaExec(
query: SparkPlan,
refreshCache: () => Unit,
projections: WriteDeltaProjections,
- write: DeltaWrite) extends V2ExistingTableWriteExec {
+ write: DeltaWrite,
+ transaction: Option[Transaction] = None) extends V2ExistingTableWriteExec {
override lazy val writingTask: WritingSparkTask[_] = {
if (projections.metadataProjection.isDefined) {
@@ -351,6 +378,7 @@ case class WriteDeltaExec(
}
}
+ override def withTransaction(txn: Option[Transaction]): WriteDeltaExec = copy(transaction = txn)
override protected def withNewChildInternal(newChild: SparkPlan): WriteDeltaExec = {
copy(query = newChild)
}
@@ -378,7 +406,16 @@ case class WriteToDataSourceV2Exec(
copy(query = newChild)
}
-trait V2ExistingTableWriteExec extends V2TableWriteExec {
+/**
+ * Trait for physical plan nodes that execute as part of a transaction (e.g. V2 writes and deletes).
+ * The [[transaction]] is injected post-planning by [[QueryExecution]].
+ */
+trait TransactionalExec extends SparkPlan {
+ def transaction: Option[Transaction]
+ def withTransaction(txn: Option[Transaction]): SparkPlan
+}
+
+trait V2ExistingTableWriteExec extends V2TableWriteExec with TransactionalExec {
def refreshCache: () => Unit
def write: Write
@@ -395,6 +432,7 @@ trait V2ExistingTableWriteExec extends V2TableWriteExec {
} finally {
postDriverMetrics()
}
+ transaction.foreach(TransactionUtils.commit)
refreshCache()
writtenRows
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala
new file mode 100644
index 0000000000000..aef9c65550fc4
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala
@@ -0,0 +1,500 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.connector.catalog.{Aborted, Committed}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode
+import org.apache.spark.sql.sources
+
+class AppendDataTransactionSuite extends RowLevelOperationSuiteBase {
+
+ test("writeTo append with transactional checks") {
+ // create table with initial data
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |""".stripMargin)
+
+ // create a source on top of itself that will be fully resolved and analyzed
+ val sourceDF = spark.table(tableNameAsString)
+ .where("pk == 1")
+ .select(col("pk") + 10 as "pk", col("salary"), col("dep"))
+ sourceDF.queryExecution.assertAnalyzed()
+
+ // append data using the DataFrame API
+ val (txn, txnTables) = executeTransaction {
+ sourceDF.writeTo(tableNameAsString).append()
+ }
+
+ // check txn was properly committed and closed
+ assert(txn.currentState === Committed)
+ assert(txn.isClosed)
+ assert(txnTables.size === 1)
+ assert(table.version() === "2")
+
+ // check the source scan was tracked via the transaction catalog
+ val targetTxnTable = txnTables(tableNameAsString)
+ assert(targetTxnTable.scanEvents.size === 1)
+ assert(targetTxnTable.scanEvents.flatten.exists {
+ case sources.EqualTo("pk", 1) => true
+ case _ => false
+ })
+
+ // check data was appended correctly
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, 100, "hr"),
+ Row(2, 200, "software"),
+ Row(11, 100, "hr"))) // appended
+ }
+
+ test("SQL INSERT INTO with transactional checks") {
+ // create table with initial data
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |""".stripMargin)
+
+ // SQL INSERT INTO using VALUES
+ val (txn, txnTables) = executeTransaction {
+ sql(s"INSERT INTO $tableNameAsString VALUES (3, 300, 'hr'), (4, 400, 'finance')")
+ }
+
+ // check txn was properly committed and closed
+ assert(txn.currentState === Committed)
+ assert(txn.isClosed)
+ assert(table.version() === "2")
+
+ // VALUES literal - No catalog tables were scanned
+ assert(txnTables.isEmpty)
+
+ // check data was inserted correctly
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, 100, "hr"),
+ Row(2, 200, "software"),
+ Row(3, 300, "hr"),
+ Row(4, 400, "finance")))
+ }
+
+ for (isDynamic <- Seq(false, true))
+ test(s"SQL INSERT OVERWRITE with transactional checks - isDynamic: $isDynamic") {
+ // create table with initial data; table is partitioned by dep
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |""".stripMargin)
+
+ val insertOverwrite = if (isDynamic) {
+ // OverwritePartitionsDynamic
+ s"""INSERT OVERWRITE $tableNameAsString
+ |SELECT pk + 10, salary, dep FROM $tableNameAsString WHERE dep = 'hr'
+ |""".stripMargin
+ } else {
+ // OverwriteByExpression
+ s"""INSERT OVERWRITE $tableNameAsString
+ |PARTITION (dep = 'hr')
+ |SELECT pk + 10, salary FROM $tableNameAsString WHERE dep = 'hr'
+ |""".stripMargin
+ }
+
+ val confValue = if (isDynamic) PartitionOverwriteMode.DYNAMIC else PartitionOverwriteMode.STATIC
+ val (txn, txnTables) = withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> confValue.toString) {
+ executeTransaction { sql(insertOverwrite) }
+ }
+
+ assert(txn.currentState === Committed)
+ assert(txn.isClosed)
+ assert(table.version() === "2")
+
+ // the SELECT reads from the target table once with a dep='hr' filter
+ assert(txnTables.size == 1)
+ val targetTxnTable = txnTables(tableNameAsString)
+ assert(targetTxnTable.scanEvents.size == 1)
+ assert(targetTxnTable.scanEvents.flatten.exists {
+ case sources.EqualTo("dep", "hr") => true
+ case _ => false
+ })
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(2, 200, "software"), // unchanged
+ Row(11, 100, "hr"), // overwritten
+ Row(13, 300, "hr"))) // overwritten
+ }
+
+ test("writeTo overwrite with transactional checks") {
+ // create table with initial data; table is partitioned by dep
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |""".stripMargin)
+
+ // overwrite using a condition that covers the hr partition -> OverwriteByExpression
+ val sourceDF = spark.createDataFrame(Seq((11, 999, "hr"), (12, 888, "hr"))).
+ toDF("pk", "salary", "dep")
+
+ val (txn, txnTables) = executeTransaction {
+ sourceDF.writeTo(tableNameAsString).overwrite(col("dep") === "hr")
+ }
+
+ assert(txn.currentState === Committed)
+ assert(txn.isClosed)
+ assert(table.version() === "2")
+
+ // literal DataFrame source - no catalog tables were scanned
+ assert(txnTables.isEmpty)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(2, 200, "software"), // unchanged
+ Row(11, 999, "hr"), // overwrote hr partition
+ Row(12, 888, "hr"))) // overwrote hr partition
+ }
+
+ test("writeTo overwritePartitions with transactional checks") {
+ // create table with initial data; table is partitioned by dep
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |""".stripMargin)
+
+ // overwrite partitions dynamically -> OverwritePartitionsDynamic
+ val sourceDF = spark.createDataFrame(Seq((11, 999, "hr"), (12, 888, "hr"))).
+ toDF("pk", "salary", "dep")
+
+ val (txn, txnTables) = executeTransaction {
+ sourceDF.writeTo(tableNameAsString).overwritePartitions()
+ }
+
+ assert(txn.currentState === Committed)
+ assert(txn.isClosed)
+ assert(table.version() === "2")
+
+ // literal DataFrame source - no catalog tables were scanned
+ assert(txnTables.isEmpty)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(2, 200, "software"), // unchanged
+ Row(11, 999, "hr"), // overwrote hr partition
+ Row(12, 888, "hr"))) // overwrote hr partition
+ }
+
+ test("SQL INSERT INTO SELECT with transactional checks") {
+ // create table with initial data
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |""".stripMargin)
+
+ // SQL INSERT INTO using SELECT from the same table (self-insert)
+ val (txn, txnTables) = executeTransaction {
+ sql(s"""INSERT INTO $tableNameAsString
+ |SELECT pk + 10, salary, dep FROM $tableNameAsString WHERE dep = 'hr'
+ |""".stripMargin)
+ }
+
+ // check txn was properly committed and closed
+ assert(txn.currentState === Committed)
+ assert(txn.isClosed)
+ assert(table.version() === "2")
+
+ // the SELECT reads from the target table once with a dep='hr' filter
+ assert(txnTables.size === 1)
+ val targetTxnTable = txnTables(tableNameAsString)
+ assert(targetTxnTable.scanEvents.size === 1)
+ assert(targetTxnTable.scanEvents.flatten.exists {
+ case sources.EqualTo("dep", "hr") => true
+ case _ => false
+ })
+
+ // check data was inserted correctly
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, 100, "hr"),
+ Row(2, 200, "software"),
+ Row(3, 300, "hr"),
+ Row(11, 100, "hr"), // inserted from pk=1
+ Row(13, 300, "hr"))) // inserted from pk=3
+ }
+
+ test("SQL INSERT INTO SELECT with subquery on source table and transactional checks") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |""".stripMargin)
+
+ sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)")
+ sql(s"INSERT INTO $sourceNameAsString VALUES (1, 500, 'hr'), (3, 600, 'software')")
+
+ // INSERT using a subquery that reads from the target to filter source rows
+ // both tables are scanned through the transaction catalog
+ val (txn, txnTables) = executeTransaction {
+ sql(
+ s"""INSERT INTO $tableNameAsString
+ |SELECT pk + 10, salary, dep FROM $sourceNameAsString
+ |WHERE pk IN (SELECT pk FROM $tableNameAsString WHERE dep = 'hr')
+ |""".stripMargin)
+ }
+
+ assert(txn.currentState === Committed)
+ assert(txn.isClosed)
+ assert(txnTables.size === 2)
+ assert(table.version() === "2")
+
+ // target was scanned via the transaction catalog (IN subquery) once with dep='hr' filter
+ val targetTxnTable = txnTables(tableNameAsString)
+ assert(targetTxnTable.scanEvents.size === 1)
+ assert(targetTxnTable.scanEvents.flatten.exists {
+ case sources.EqualTo("dep", "hr") => true
+ case _ => false
+ })
+
+ // source was scanned via the transaction catalog exactly once (no filter)
+ val sourceTxnTable = txnTables(sourceNameAsString)
+ assert(sourceTxnTable.scanEvents.size === 1)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, 100, "hr"),
+ Row(2, 200, "software"),
+ Row(11, 500, "hr"))) // inserted: source pk=1 matched target hr row
+ }
+
+ test("SQL INSERT INTO SELECT with CTE and transactional checks") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |""".stripMargin)
+
+ sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)")
+ sql(s"INSERT INTO $sourceNameAsString VALUES (1, 500, 'hr'), (3, 600, 'software')")
+
+ // CTE reads from target; INSERT selects from source filtered by the CTE result
+ // both tables are scanned through the transaction catalog
+ val (txn, txnTables) = executeTransaction {
+ sql(
+ s"""WITH hr_pks AS (SELECT pk FROM $tableNameAsString WHERE dep = 'hr')
+ |INSERT INTO $tableNameAsString
+ |SELECT pk + 10, salary, dep FROM $sourceNameAsString
+ |WHERE pk IN (SELECT pk FROM hr_pks)
+ |""".stripMargin)
+ }
+
+ assert(txn.currentState === Committed)
+ assert(txn.isClosed)
+ assert(txnTables.size === 2)
+ assert(table.version() === "2")
+
+ // target was scanned via the transaction catalog (CTE) once with dep='hr' filter
+ val targetTxnTable = txnTables(tableNameAsString)
+ assert(targetTxnTable.scanEvents.size === 1)
+ assert(targetTxnTable.scanEvents.flatten.exists {
+ case sources.EqualTo("dep", "hr") => true
+ case _ => false
+ })
+
+ // source was scanned via the transaction catalog exactly once (no filter)
+ val sourceTxnTable = txnTables(sourceNameAsString)
+ assert(sourceTxnTable.scanEvents.size === 1)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, 100, "hr"),
+ Row(2, 200, "software"),
+ Row(11, 500, "hr"))) // inserted: source pk=1 matched target hr row via CTE
+ }
+
+ test("SQL INSERT with analysis failure and transactional checks") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |""".stripMargin)
+
+ val e = intercept[AnalysisException] {
+ sql(s"INSERT INTO $tableNameAsString SELECT nonexistent_col FROM $tableNameAsString")
+ }
+
+ assert(e.getMessage.contains("nonexistent_col"))
+ assert(catalog.lastTransaction.currentState === Aborted)
+ assert(catalog.lastTransaction.isClosed)
+ }
+
+ for (isDynamic <- Seq(false, true))
+  test(s"SQL INSERT OVERWRITE with analysis failure and transactional checks - " +
+    s"isDynamic: $isDynamic") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |""".stripMargin)
+
+ val insertOverwrite = if (isDynamic) {
+ s"""INSERT OVERWRITE $tableNameAsString
+ |SELECT nonexistent_col, salary, dep FROM $tableNameAsString WHERE dep = 'hr'
+ |""".stripMargin
+ } else {
+ s"""INSERT OVERWRITE $tableNameAsString
+ |PARTITION (dep = 'hr')
+ |SELECT nonexistent_col FROM $tableNameAsString WHERE dep = 'hr'
+ |""".stripMargin
+ }
+
+ val confValue = if (isDynamic) PartitionOverwriteMode.DYNAMIC else PartitionOverwriteMode.STATIC
+ val e = withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> confValue.toString) {
+ intercept[AnalysisException] { sql(insertOverwrite) }
+ }
+
+ assert(e.getMessage.contains("nonexistent_col"))
+ assert(catalog.lastTransaction.currentState === Aborted)
+ assert(catalog.lastTransaction.isClosed)
+ }
+
+ test("EXPLAIN INSERT SQL with transactional checks") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |""".stripMargin)
+
+ sql(s"EXPLAIN INSERT INTO $tableNameAsString VALUES (3, 300, 'hr')")
+
+ // EXPLAIN should not start a transaction
+ assert(catalog.transaction === null)
+
+ // INSERT was not executed; data is unchanged
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, 100, "hr"),
+ Row(2, 200, "software")))
+ }
+
+ test("SQL INSERT WITH SCHEMA EVOLUTION adds new column with transactional checks") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |""".stripMargin)
+
+ sql(
+ s"""CREATE TABLE $sourceNameAsString
+ |(pk INT NOT NULL, salary INT, dep STRING, active BOOLEAN)""".stripMargin)
+ sql(s"INSERT INTO $sourceNameAsString VALUES (3, 300, 'hr', true), (4, 400, 'software', false)")
+
+ val (txn, txnTables) = executeTransaction {
+ sql(s"INSERT WITH SCHEMA EVOLUTION INTO $tableNameAsString SELECT * FROM $sourceNameAsString")
+ }
+
+ assert(txn.currentState === Committed)
+ assert(txn.isClosed)
+
+ // the new column must be visible in the committed delegate's schema
+ assert(table.schema.fieldNames.toSeq === Seq("pk", "salary", "dep", "active"))
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, 100, "hr", null), // pre-existing rows: active is null
+ Row(2, 200, "software", null),
+ Row(3, 300, "hr", true), // inserted with active
+ Row(4, 400, "software", false)))
+ }
+
+ for (isDynamic <- Seq(false, true))
+ test(s"SQL INSERT OVERWRITE WITH SCHEMA EVOLUTION adds new column with transactional checks " +
+ s"isDynamic: $isDynamic") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |""".stripMargin)
+
+ sql(
+ s"""CREATE TABLE $sourceNameAsString
+ |(pk INT NOT NULL, salary INT, dep STRING, active BOOLEAN)""".stripMargin)
+ sql(s"INSERT INTO $sourceNameAsString VALUES (11, 999, 'hr', true), (12, 888, 'hr', false)")
+
+ val insertOverwrite = if (isDynamic) {
+ s"""INSERT WITH SCHEMA EVOLUTION OVERWRITE TABLE $tableNameAsString
+ |SELECT * FROM $sourceNameAsString
+ |""".stripMargin
+ } else {
+ s"""INSERT WITH SCHEMA EVOLUTION OVERWRITE TABLE $tableNameAsString
+ |PARTITION (dep = 'hr')
+ |SELECT pk, salary, active FROM $sourceNameAsString
+ |""".stripMargin
+ }
+
+ val confValue = if (isDynamic) PartitionOverwriteMode.DYNAMIC else PartitionOverwriteMode.STATIC
+ val (txn, _) = withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> confValue.toString) {
+ executeTransaction { sql(insertOverwrite) }
+ }
+
+ assert(txn.currentState === Committed)
+ assert(txn.isClosed)
+
+ // the new column must be visible in the committed delegate's schema
+ assert(table.schema.fieldNames.contains("active"))
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(2, 200, "software", null), // unchanged (different partition)
+ Row(11, 999, "hr", true), // overwrote hr partition
+ Row(12, 888, "hr", false)))
+ }
+
+ test("SQL INSERT WITH SCHEMA EVOLUTION analysis failure aborts transaction") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |""".stripMargin)
+
+ sql(
+ s"""CREATE TABLE $sourceNameAsString
+ |(pk INT NOT NULL, salary INT, dep STRING, active BOOLEAN)""".stripMargin)
+
+ val e = intercept[AnalysisException] {
+ sql(
+ s"""INSERT WITH SCHEMA EVOLUTION INTO $tableNameAsString
+ |SELECT nonexistent_col FROM $sourceNameAsString
+ |""".stripMargin)
+ }
+
+ assert(e.getMessage.contains("nonexistent_col"))
+ assert(catalog.lastTransaction.currentState === Aborted)
+ assert(catalog.lastTransaction.isClosed)
+ // schema must be unchanged after the aborted transaction
+ assert(table.schema.fieldNames.toSeq === Seq("pk", "salary", "dep"))
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/CTASRTASTransactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/CTASRTASTransactionSuite.scala
new file mode 100644
index 0000000000000..8acdd8242ef1f
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/CTASRTASTransactionSuite.scala
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.connector.catalog.{Aborted, Committed, Identifier, InMemoryRowLevelOperationTable}
+import org.apache.spark.sql.sources
+
+class CTASRTASTransactionSuite extends RowLevelOperationSuiteBase {
+
+ private val newTableNameAsString = "cat.ns1.new_table"
+
+ private def newTable: InMemoryRowLevelOperationTable =
+ catalog.loadTable(Identifier.of(Array("ns1"), "new_table"))
+ .asInstanceOf[InMemoryRowLevelOperationTable]
+
+ test("CTAS with transactional checks") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |""".stripMargin)
+
+ val (txn, txnTables) = executeTransaction {
+ sql(s"""CREATE TABLE $newTableNameAsString
+ |AS SELECT * FROM $tableNameAsString WHERE dep = 'hr'
+ |""".stripMargin)
+ }
+
+ assert(txn.currentState === Committed)
+ assert(txn.isClosed)
+ assert(txnTables.size === 1)
+ assert(table.version() === "1") // source table: read-only, version unchanged
+ assert(newTable.version() === "1") // target table: newly created and written
+
+ // the source table was scanned once through the transaction catalog with a dep='hr' filter
+ val sourceTxnTable = txnTables(tableNameAsString)
+ assert(sourceTxnTable.scanEvents.size === 1)
+ assert(sourceTxnTable.scanEvents.flatten.exists {
+ case sources.EqualTo("dep", "hr") => true
+ case _ => false
+ })
+
+ checkAnswer(
+ sql(s"SELECT * FROM $newTableNameAsString"),
+ Seq(Row(1, 100, "hr")))
+ }
+
+ test("CTAS from literal source with transactional checks") {
+ // no source catalog table involved - the query is a pure literal SELECT
+ val (txn, txnTables) = executeTransaction {
+ sql(s"CREATE TABLE $newTableNameAsString AS SELECT 1 AS pk, 100 AS salary, 'hr' AS dep")
+ }
+
+ assert(txn.currentState === Committed)
+ assert(txn.isClosed)
+
+ // literal SELECT - no catalog tables were scanned
+ assert(txnTables.isEmpty)
+ assert(newTable.version() === "1") // target table: newly created and written
+
+ checkAnswer(
+ sql(s"SELECT * FROM $newTableNameAsString"),
+ Seq(Row(1, 100, "hr")))
+ }
+
+ test("CTAS with analysis failure and transactional checks") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |""".stripMargin)
+
+ val e = intercept[AnalysisException] {
+ sql(s"CREATE TABLE $newTableNameAsString AS SELECT nonexistent_col FROM $tableNameAsString")
+ }
+
+ assert(e.getMessage.contains("nonexistent_col"))
+ assert(catalog.lastTransaction.currentState === Aborted)
+ assert(catalog.lastTransaction.isClosed)
+ }
+
+ test("RTAS with transactional checks") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |""".stripMargin)
+
+ // pre-create the target so REPLACE TABLE (not CREATE OR REPLACE) is valid
+ sql(s"CREATE TABLE $newTableNameAsString (pk INT NOT NULL, salary INT, dep STRING)")
+
+ val (txn, txnTables) = executeTransaction {
+ sql(s"""REPLACE TABLE $newTableNameAsString
+ |AS SELECT * FROM $tableNameAsString WHERE dep = 'hr'
+ |""".stripMargin)
+ }
+
+ assert(txn.currentState === Committed)
+ assert(txn.isClosed)
+ assert(txnTables.size === 1)
+ assert(table.version() === "1") // source table: read-only, version unchanged
+ assert(newTable.version() === "1") // target table: replaced and written
+
+ // the source table was scanned once through the transaction catalog with a dep='hr' filter
+ val sourceTxnTable = txnTables(tableNameAsString)
+ assert(sourceTxnTable.scanEvents.size === 1)
+ assert(sourceTxnTable.scanEvents.flatten.exists {
+ case sources.EqualTo("dep", "hr") => true
+ case _ => false
+ })
+
+ checkAnswer(
+ sql(s"SELECT * FROM $newTableNameAsString"),
+ Seq(
+ Row(1, 100, "hr"),
+ Row(3, 300, "hr")))
+ }
+
+ test("RTAS self-reference with transactional checks") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |""".stripMargin)
+
+ // source and target are the same table: reads the old snapshot via TxnTable,
+ // replaces the table with a filtered version
+ val (txn, txnTables) = executeTransaction {
+ sql(s"""CREATE OR REPLACE TABLE $tableNameAsString
+ |AS SELECT * FROM $tableNameAsString WHERE dep = 'hr'
+ |""".stripMargin)
+ }
+
+ assert(txn.currentState === Committed)
+ assert(txn.isClosed)
+ assert(txnTables.size === 1)
+ assert(table.version() === "1") // source/target table: replaced in place, version reset to 1
+
+ // the source/target table was scanned once with a dep='hr' filter
+ val sourceTxnTable = txnTables(tableNameAsString)
+ assert(sourceTxnTable.scanEvents.size === 1)
+ assert(sourceTxnTable.scanEvents.flatten.exists {
+ case sources.EqualTo("dep", "hr") => true
+ case _ => false
+ })
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, 100, "hr"),
+ Row(3, 300, "hr")))
+ }
+
+ test("RTAS with analysis failure and transactional checks") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |""".stripMargin)
+
+ val e = intercept[AnalysisException] {
+ sql(s"""CREATE OR REPLACE TABLE $tableNameAsString
+ |AS SELECT nonexistent_col FROM $tableNameAsString
+ |""".stripMargin)
+ }
+
+ assert(e.getMessage.contains("nonexistent_col"))
+ assert(catalog.lastTransaction.currentState === Aborted)
+ assert(catalog.lastTransaction.isClosed)
+ }
+
+ test("simple CREATE TABLE and DROP TABLE do not create transactions") {
+ sql(s"CREATE TABLE $newTableNameAsString (pk INT NOT NULL, salary INT, dep STRING)")
+ assert(catalog.transaction === null)
+ assert(catalog.lastTransaction === null)
+
+ sql(s"DROP TABLE $newTableNameAsString")
+ assert(catalog.transaction === null)
+ assert(catalog.lastTransaction === null)
+ }
+
+ test("EXPLAIN CTAS with transactional checks") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |""".stripMargin)
+
+ sql(s"""EXPLAIN CREATE TABLE $newTableNameAsString
+ |AS SELECT * FROM $tableNameAsString WHERE dep = 'hr'
+ |""".stripMargin)
+
+ // EXPLAIN should not start a transaction
+ assert(catalog.transaction === null)
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2OptionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2OptionSuite.scala
index 30890200df79d..fbcfdfb20c6ec 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2OptionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2OptionSuite.scala
@@ -109,7 +109,7 @@ class DataSourceV2OptionSuite extends DatasourceV2SQLBase {
collected = df.queryExecution.executedPlan.collect {
case CommandResultExec(
- _, AppendDataExec(_, _, write),
+ _, AppendDataExec(_, _, write, _),
_) =>
val append = write.toBatch.asInstanceOf[InMemoryBaseTable#Append]
assert(append.info.options.get("write.split-size") === "10")
@@ -141,7 +141,7 @@ class DataSourceV2OptionSuite extends DatasourceV2SQLBase {
assert (collected.size == 1)
collected = qe.executedPlan.collect {
- case AppendDataExec(_, _, write) =>
+ case AppendDataExec(_, _, write, _) =>
val append = write.toBatch.asInstanceOf[InMemoryBaseTable#Append]
assert(append.info.options.get("write.split-size") === "10")
}
@@ -168,7 +168,7 @@ class DataSourceV2OptionSuite extends DatasourceV2SQLBase {
assert (collected.size == 1)
collected = qe.executedPlan.collect {
- case AppendDataExec(_, _, write) =>
+ case AppendDataExec(_, _, write, _) =>
val append = write.toBatch.asInstanceOf[InMemoryBaseTable#Append]
assert(append.info.options.get("write.split-size") === "10")
}
@@ -194,7 +194,7 @@ class DataSourceV2OptionSuite extends DatasourceV2SQLBase {
collected = df.queryExecution.executedPlan.collect {
case CommandResultExec(
- _, OverwriteByExpressionExec(_, _, write),
+ _, OverwriteByExpressionExec(_, _, write, _),
_) =>
val append = write.toBatch.asInstanceOf[InMemoryBaseTable#TruncateAndAppend]
assert(append.info.options.get("write.split-size") === "10")
@@ -227,7 +227,7 @@ class DataSourceV2OptionSuite extends DatasourceV2SQLBase {
assert (collected.size == 1)
collected = qe.executedPlan.collect {
- case OverwritePartitionsDynamicExec(_, _, write) =>
+ case OverwritePartitionsDynamicExec(_, _, write, _) =>
val dynOverwrite = write.toBatch.asInstanceOf[InMemoryBaseTable#DynamicOverwrite]
assert(dynOverwrite.info.options.get("write.split-size") === "10")
}
@@ -254,7 +254,7 @@ class DataSourceV2OptionSuite extends DatasourceV2SQLBase {
collected = df.queryExecution.executedPlan.collect {
case CommandResultExec(
- _, OverwriteByExpressionExec(_, _, write),
+ _, OverwriteByExpressionExec(_, _, write, _),
_) =>
val append = write.toBatch.asInstanceOf[InMemoryBaseTable#TruncateAndAppend]
assert(append.info.options.get("write.split-size") === "10")
@@ -287,7 +287,7 @@ class DataSourceV2OptionSuite extends DatasourceV2SQLBase {
assert (collected.size == 1)
collected = qe.executedPlan.collect {
- case OverwriteByExpressionExec(_, _, write) =>
+ case OverwriteByExpressionExec(_, _, write, _) =>
val append = write.toBatch.asInstanceOf[InMemoryBaseTable#TruncateAndAppend]
assert(append.info.options.get("write.split-size") === "10")
}
@@ -317,7 +317,7 @@ class DataSourceV2OptionSuite extends DatasourceV2SQLBase {
assert (collected.size == 1)
collected = qe.executedPlan.collect {
- case OverwriteByExpressionExec(_, _, write) =>
+ case OverwriteByExpressionExec(_, _, write, _) =>
val append = write.toBatch.asInstanceOf[InMemoryBaseTable#TruncateAndAppend]
assert(append.info.options.get("write.split-size") === "10")
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala
index 0f7f4cefe2feb..378a3af2561b3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala
@@ -17,10 +17,12 @@
package org.apache.spark.sql.connector
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{AnalysisException, Row}
import org.apache.spark.sql.catalyst.expressions.CheckInvariant
import org.apache.spark.sql.catalyst.plans.logical.Filter
+import org.apache.spark.sql.connector.catalog.{Aborted, Committed}
import org.apache.spark.sql.execution.datasources.v2.{DeleteFromTableExec, ReplaceDataExec, WriteDeltaExec}
+import org.apache.spark.sql.sources
abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase {
@@ -28,6 +30,10 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase {
protected def enforceCheckConstraintOnDelete: Boolean = true
+ // true when the table engine uses delta-based deletes (WriteDeltaExec), false for group-based
+ // (ReplaceDataExec); controls expected scan counts in transactional tests
+ protected def deltaDelete: Boolean = false
+
test("delete from table containing added column with default value") {
createAndInitTable("pk INT NOT NULL, dep STRING", """{ "pk": 1, "dep": "hr" }""")
@@ -659,6 +665,190 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase {
}
}
+ test("delete with analysis failure and transactional checks") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |""".stripMargin)
+
+ val exception = intercept[AnalysisException] {
+ sql(s"DELETE FROM $tableNameAsString WHERE invalid_column = 1")
+ }
+
+ assert(exception.getMessage.contains("invalid_column"))
+ assert(catalog.lastTransaction.currentState == Aborted)
+ assert(catalog.lastTransaction.isClosed)
+ }
+
+ test("delete with transactional checks") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |""".stripMargin)
+
+ // simple predicate delete: goes through SupportsDelete.deleteWhere (no Spark-side scan)
+ val (txn, _) = executeTransaction {
+ sql(s"DELETE FROM $tableNameAsString WHERE dep = 'hr'")
+ }
+
+ assert(txn.currentState == Committed)
+ assert(txn.isClosed)
+ assert(table.version() == "2")
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(Row(2, 200, "software")))
+ }
+
+ test("delete with subquery on source table and transactional checks") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |""".stripMargin)
+
+ sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)")
+ sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'hr'), (4, 400, 'finance')")
+
+ val (txn, txnTables) = executeTransaction {
+ sql(
+ s"""DELETE FROM $tableNameAsString
+ |WHERE pk IN (SELECT pk FROM $sourceNameAsString WHERE dep = 'hr')
+ |""".stripMargin)
+ }
+
+ assert(txn.currentState == Committed)
+ assert(txn.isClosed)
+ assert(txnTables.size == 2)
+ assert(table.version() == "2")
+
+ val sourceTxnTable = txnTables(sourceNameAsString)
+ val expectedNumSourceScans = if (deltaDelete) 1 else 2
+ assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans)
+
+ val numSubquerySourceScans = sourceTxnTable.scanEvents.flatten.count {
+ case sources.EqualTo("dep", "hr") => true
+ case _ => false
+ }
+ assert(numSubquerySourceScans == expectedNumSourceScans)
+
+ val targetTxnTable = txnTables(tableNameAsString)
+ val expectedNumTargetScans = if (deltaDelete) 1 else 2
+ assert(targetTxnTable.scanEvents.size == expectedNumTargetScans)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(2, 200, "software"), // unchanged
+ Row(3, 300, "hr"))) // unchanged (pk 3 not in subquery result)
+ }
+
+ test("delete with CTE and transactional checks") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |""".stripMargin)
+
+ sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)")
+ sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'hr'), (4, 400, 'finance')")
+
+ val (txn, txnTables) = executeTransaction {
+ sql(
+ s"""WITH cte AS (
+ | SELECT pk FROM $sourceNameAsString WHERE dep = 'hr'
+ |)
+ |DELETE FROM $tableNameAsString
+ |WHERE pk IN (SELECT pk FROM cte)
+ |""".stripMargin)
+ }
+
+ assert(txn.currentState == Committed)
+ assert(txn.isClosed)
+ assert(txnTables.size == 2)
+ assert(table.version() == "2")
+
+ val targetTxnTable = txnTables(tableNameAsString)
+ val expectedNumTargetScans = if (deltaDelete) 1 else 2
+ assert(targetTxnTable.scanEvents.size == expectedNumTargetScans)
+
+ val sourceTxnTable = txnTables(sourceNameAsString)
+ val expectedNumSourceScans = if (deltaDelete) 1 else 2
+ assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans)
+
+ val numCteSourceScans = sourceTxnTable.scanEvents.flatten.count {
+ case sources.EqualTo("dep", "hr") => true
+ case _ => false
+ }
+ assert(numCteSourceScans == expectedNumSourceScans)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(2, 200, "software"), // unchanged
+ Row(3, 300, "hr"))) // unchanged (pk 3 not in source)
+ }
+
+ test("delete using view with transactional checks") {
+ withView("temp_view") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |""".stripMargin)
+
+ sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)")
+ sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'hr'), (4, 400, 'finance')")
+
+ sql(
+ s"""CREATE VIEW temp_view AS
+ |SELECT pk FROM $sourceNameAsString WHERE dep = 'hr'
+ |""".stripMargin)
+
+ val (txn, txnTables) = executeTransaction {
+ sql(s"DELETE FROM $tableNameAsString WHERE pk IN (SELECT pk FROM temp_view)")
+ }
+
+ assert(txn.currentState == Committed)
+ assert(txn.isClosed)
+ assert(txnTables.size == 2)
+ assert(table.version() == "2")
+
+ val targetTxnTable = txnTables(tableNameAsString)
+ val expectedNumTargetScans = if (deltaDelete) 1 else 2
+ assert(targetTxnTable.scanEvents.size == expectedNumTargetScans)
+
+ val sourceTxnTable = txnTables(sourceNameAsString)
+ val expectedNumSourceScans = if (deltaDelete) 1 else 2
+ assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(2, 200, "software"), // unchanged
+ Row(3, 300, "hr"))) // unchanged (pk 3 not in source)
+ }
+ }
+
+ test("EXPLAIN DELETE SQL with transactional checks") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |""".stripMargin)
+
+ sql(s"EXPLAIN DELETE FROM $tableNameAsString WHERE dep = 'hr'")
+
+ // EXPLAIN should not start a new transaction
+ assert(catalog.transaction === null)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, 100, "hr"),
+ Row(2, 200, "software")))
+ }
+
private def executeDeleteWithFilters(query: String): Unit = {
val executedPlan = executeAndKeepPlan {
sql(query)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedDeleteFromTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedDeleteFromTableSuite.scala
index 9046123ddbd3f..56a689364d1c0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedDeleteFromTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedDeleteFromTableSuite.scala
@@ -33,6 +33,8 @@ class DeltaBasedDeleteFromTableSuite extends DeleteFromTableSuiteBase {
override def enforceCheckConstraintOnDelete: Boolean = false
+ override protected def deltaDelete: Boolean = true
+
test("delete handles metadata columns correctly") {
createAndInitTable("pk INT NOT NULL, id INT, dep STRING",
"""{ "pk": 1, "id": 1, "dep": "hr" }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuiteBase.scala
index 89b42b5e6db7b..e821fc3f660da 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuiteBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuiteBase.scala
@@ -21,6 +21,8 @@ import org.apache.spark.sql.{AnalysisException, Row}
abstract class DeltaBasedUpdateTableSuiteBase extends UpdateTableSuiteBase {
+ override protected def deltaUpdate: Boolean = true
+
test("nullable row ID attrs") {
createAndInitTable("pk INT, salary INT, dep STRING",
"""{ "pk": 1, "salary": 300, "dep": 'hr' }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala
index e1c574ec7ba65..d58e22e63d71e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala
@@ -17,10 +17,10 @@
package org.apache.spark.sql.connector
+import org.apache.spark.sql.{sources, Column, Row}
import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.Row
import org.apache.spark.sql.classic.MergeIntoWriter
-import org.apache.spark.sql.connector.catalog.Column
+import org.apache.spark.sql.connector.catalog.{Aborted, Committed}
import org.apache.spark.sql.connector.catalog.Identifier
import org.apache.spark.sql.connector.catalog.TableInfo
import org.apache.spark.sql.functions._
@@ -31,6 +31,103 @@ class MergeIntoDataFrameSuite extends RowLevelOperationSuiteBase {
import testImplicits._
+ private def targetTableCol(colName: String): Column = {
+ col(tableNameAsString + "." + colName)
+ }
+
+ test("self merge with transactional checks") {
+ // create table
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |""".stripMargin)
+
+ // create a source on top of itself that will be fully resolved and analyzed
+ val sourceDF = spark.table(tableNameAsString)
+ .where("salary == 100")
+ .as("source")
+ sourceDF.queryExecution.assertAnalyzed()
+
+ // merge into table using the source on top of itself
+ val (txn, txnTables) = executeTransaction {
+ sourceDF
+ .mergeInto(
+ tableNameAsString,
+ $"source.pk" === targetTableCol("pk") && targetTableCol("dep") === "hr")
+ .whenMatched()
+ .update(Map("salary" -> targetTableCol("salary").plus(1)))
+ .whenNotMatched()
+ .insertAll()
+ .merge()
+ }
+
+ // check txn was properly committed and closed
+ assert(txn.currentState == Committed)
+ assert(txn.isClosed)
+ assert(txnTables.size == 1)
+ assert(table.version() == "2")
+
+ // check all table scans
+ val targetTxnTable = txnTables(tableNameAsString)
+ assert(targetTxnTable.scanEvents.size == 4)
+
+ // check table scans as MERGE target
+ val numTargetScans = targetTxnTable.scanEvents.flatten.count {
+ case sources.EqualTo("dep", "hr") => true
+ case _ => false
+ }
+ assert(numTargetScans == 2)
+
+ // check table scans as MERGE source
+ val numSourceScans = targetTxnTable.scanEvents.flatten.count {
+ case sources.EqualTo("salary", 100) => true
+ case _ => false
+ }
+ assert(numSourceScans == 2)
+
+ // check txn state was propagated correctly
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, 101, "hr"), // update
+ Row(2, 200, "software"), // unchanged
+ Row(3, 300, "hr"))) // unchanged
+ }
+
+ for (alterClause <- Seq(
+ "ADD COLUMN new_col INT",
+ "DROP COLUMN salary",
+ "ALTER COLUMN salary TYPE BIGINT",
+ "ALTER COLUMN pk DROP NOT NULL"))
+ test(s"self merge fails when source schema changes after analysis - DDL: $alterClause") {
+ withTable(tableNameAsString) {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |""".stripMargin)
+
+ val sourceDF = spark.table(tableNameAsString).where("salary == 100").as("source")
+ sourceDF.queryExecution.assertAnalyzed()
+
+ sql(s"ALTER TABLE $tableNameAsString $alterClause")
+
+ val e = intercept[AnalysisException] {
+ sourceDF
+ .mergeInto(tableNameAsString, $"source.pk" === targetTableCol("pk"))
+ .whenMatched()
+ .update(Map("salary" -> targetTableCol("salary").plus(1)))
+ .merge()
+ }
+
+ assert(
+ e.getCondition == "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMNS_MISMATCH",
+ alterClause)
+ assert(catalog.lastTransaction.currentState == Aborted, alterClause)
+ assert(catalog.lastTransaction.isClosed, alterClause)
+ }
+ }
+
test("merge into empty table with NOT MATCHED clause") {
withTempView("source") {
createTable("pk INT NOT NULL, salary INT, dep STRING")
@@ -979,6 +1076,7 @@ class MergeIntoDataFrameSuite extends RowLevelOperationSuiteBase {
}
test("SPARK-54157: version is refreshed when source is V2 table") {
+ import org.apache.spark.sql.connector.catalog.Column
val sourceTable = "cat.ns1.source_table"
withTable(sourceTable) {
createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
@@ -1026,6 +1124,7 @@ class MergeIntoDataFrameSuite extends RowLevelOperationSuiteBase {
}
test("SPARK-54444: any schema changes after analysis are prohibited") {
+ import org.apache.spark.sql.connector.catalog.Column
val sourceTable = "cat.ns1.source_table"
withTable(sourceTable) {
createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
index 7c0e503705c7e..0423fe66a2067 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
@@ -21,7 +21,7 @@ import org.apache.spark.SparkRuntimeException
import org.apache.spark.sql.{AnalysisException, Row}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, In, Not}
import org.apache.spark.sql.catalyst.optimizer.BuildLeft
-import org.apache.spark.sql.connector.catalog.{Column, ColumnDefaultValue, InMemoryTable, TableInfo}
+import org.apache.spark.sql.connector.catalog.{Aborted, Column, ColumnDefaultValue, Committed, InMemoryTable, TableInfo}
import org.apache.spark.sql.connector.expressions.{GeneralScalarExpression, LiteralValue}
import org.apache.spark.sql.connector.write.MergeSummary
import org.apache.spark.sql.execution.SparkPlan
@@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.datasources.v2.MergeRowsExec
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, CartesianProductExec}
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.sources
import org.apache.spark.sql.types.{IntegerType, StringType}
abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
@@ -38,6 +39,309 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
protected def deltaMerge: Boolean = false
+ test("self merge with transactional checks") {
+ // create table
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |""".stripMargin)
+
+ // merge into table using a subquery on top of itself
+ val (txn, txnTables) = executeTransaction {
+ sql(
+ s"""MERGE INTO $tableNameAsString t
+ |USING (SELECT * FROM $tableNameAsString WHERE salary = 100) s
+ |ON t.pk = s.pk AND t.dep = 'hr'
+ |WHEN MATCHED THEN
+ | UPDATE SET salary = t.salary + 1
+ |WHEN NOT MATCHED THEN
+ | INSERT *
+ |""".stripMargin)
+ }
+
+ // check txn was properly committed and closed
+ assert(txn.currentState == Committed)
+ assert(txn.isClosed)
+ assert(txnTables.size == 1)
+ assert(table.version() == "2")
+
+ // check all table scans
+ val targetTxnTable = txnTables(tableNameAsString)
+ val expectedNumScans = if (deltaMerge) 2 else 4
+ assert(targetTxnTable.scanEvents.size == expectedNumScans)
+
+ // check table scans as MERGE target
+ val numTargetScans = targetTxnTable.scanEvents.flatten.count {
+ case sources.EqualTo("dep", "hr") => true
+ case _ => false
+ }
+ val expectedNumTargetScans = if (deltaMerge) 1 else 2
+ assert(numTargetScans == expectedNumTargetScans)
+
+ // check table scans as MERGE source
+ val numSourceScans = targetTxnTable.scanEvents.flatten.count {
+ case sources.EqualTo("salary", 100) => true
+ case _ => false
+ }
+ val expectedNumSourceScans = if (deltaMerge) 1 else 2
+ assert(numSourceScans == expectedNumSourceScans)
+
+ // check txn state was propagated correctly
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, 101, "hr"), // update
+ Row(2, 200, "software"), // unchanged
+ Row(3, 300, "hr"))) // unchanged
+ }
+
+ test("merge into table with analysis failure and transactional checks") {
+ createAndInitTable(
+ "pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |""".stripMargin)
+
+ sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)")
+ sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'support'), (4, 400, 'finance')")
+
+ val exception = intercept[AnalysisException] {
+ sql(
+ s"""MERGE INTO $tableNameAsString t
+ |USING $sourceNameAsString s
+ |ON t.pk = s.pk
+ |WHEN MATCHED THEN
+ | UPDATE SET salary = s.invalid_column
+ |WHEN NOT MATCHED THEN
+ | INSERT (pk, salary, dep) VALUES (s.pk, s.salary, 'pending')
+ |""".stripMargin)
+ }
+
+ assert(exception.getMessage.contains("invalid_column"))
+ assert(catalog.lastTransaction.currentState == Aborted)
+ assert(catalog.lastTransaction.isClosed)
+ }
+
+ test("merge into table using view with transactional checks") {
+ withView("temp_view") {
+ // create target table
+ createAndInitTable(
+ "pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |""".stripMargin)
+
+ // create source table
+ sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT)")
+ sql(s"INSERT INTO $sourceNameAsString (pk, salary) VALUES (1, 150), (4, 400)")
+
+ // create view on top of source and target tables
+ sql(
+ s"""CREATE VIEW temp_view AS
+ |SELECT s.pk, s.salary, t.dep
+ |FROM $sourceNameAsString s
+ |LEFT JOIN (
+ | SELECT * FROM $tableNameAsString WHERE pk < 10
+ |) t ON s.pk = t.pk
+ |""".stripMargin)
+
+ // merge into target table using view
+ val (txn, txnTables) = executeTransaction {
+ sql(s"""MERGE INTO $tableNameAsString t
+ |USING temp_view s
+ |ON t.pk = s.pk AND t.dep = 'hr'
+ |WHEN MATCHED THEN
+ | UPDATE SET salary = s.salary, dep = s.dep
+ |WHEN NOT MATCHED THEN
+ | INSERT (pk, salary, dep) VALUES (s.pk, s.salary, 'pending')
+ |""".stripMargin)
+ }
+
+ // check txn covers both tables and was properly committed and closed
+ assert(txn.currentState == Committed)
+ assert(txn.isClosed)
+ assert(txnTables.size == 2)
+ assert(table.version() == "2")
+
+ // check target table was scanned correctly
+ val targetTxnTable = txnTables(tableNameAsString)
+ val expectedNumTargetScans = if (deltaMerge) 2 else 4
+ assert(targetTxnTable.scanEvents.size == expectedNumTargetScans)
+
+ // check target table scans as MERGE target (dep = 'hr')
+ val numMergeTargetScans = targetTxnTable.scanEvents.flatten.count {
+ case sources.EqualTo("dep", "hr") => true
+ case _ => false
+ }
+ val expectedNumMergeTargetScans = if (deltaMerge) 1 else 2
+ assert(numMergeTargetScans == expectedNumMergeTargetScans)
+
+ // check target table scans in view as MERGE source (pk < 10)
+ val numViewTargetScans = targetTxnTable.scanEvents.flatten.count {
+ case sources.LessThan("pk", 10L) => true
+ case _ => false
+ }
+ val expectedNumViewTargetScans = if (deltaMerge) 1 else 2
+ assert(numViewTargetScans == expectedNumViewTargetScans)
+
+ // check source table scans in view as MERGE source (no predicate)
+ val sourceTxnTable = txnTables(sourceNameAsString)
+ val expectedNumSourceScans = if (deltaMerge) 1 else 2
+ assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans)
+
+ // check txn state was propagated correctly
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, 150, "hr"), // update
+ Row(2, 200, "software"), // unchanged
+ Row(3, 300, "hr"), // unchanged
+ Row(4, 400, "pending"))) // new
+ }
+ }
+
+ test("merge into table using nested view with transactional checks") {
+ withView("base_view", "nested_view") {
+ withTable(sourceNameAsString) {
+ // create target table
+ createAndInitTable(
+ "pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |""".stripMargin)
+
+ // create source table
+ sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT)")
+ sql(s"INSERT INTO $sourceNameAsString (pk, salary) VALUES (1, 150), (4, 400)")
+
+ // create base view
+ sql(
+ s"""CREATE VIEW base_view AS
+ |SELECT s.pk, s.salary, t.dep
+ |FROM $sourceNameAsString s
+ |LEFT JOIN (
+ | SELECT * FROM $tableNameAsString WHERE pk < 10
+ |) t ON s.pk = t.pk
+ |""".stripMargin)
+
+ // create nested view on top of base view
+ sql(
+ s"""CREATE VIEW nested_view AS
+ |SELECT * FROM base_view
+ |""".stripMargin)
+
+ // merge into target table using nested view
+ val (txn, txnTables) = executeTransaction {
+ sql(
+ s"""MERGE INTO $tableNameAsString t
+ |USING nested_view s
+ |ON t.pk = s.pk AND t.dep = 'hr'
+ |WHEN MATCHED THEN
+ | UPDATE SET salary = s.salary, dep = s.dep
+ |WHEN NOT MATCHED THEN
+ | INSERT (pk, salary, dep) VALUES (s.pk, s.salary, 'pending')
+ |""".stripMargin)
+ }
+
+ // check txn covers both tables and was properly committed and closed
+ assert(txn.currentState == Committed)
+ assert(txn.isClosed)
+ assert(txnTables.size == 2)
+ assert(table.version() == "2")
+
+ // check target table was scanned correctly
+ val targetTxnTable = txnTables(tableNameAsString)
+ val expectedNumTargetScans = if (deltaMerge) 2 else 4
+ assert(targetTxnTable.scanEvents.size == expectedNumTargetScans)
+
+ // check target table scans as MERGE target (dep = 'hr')
+ val numMergeTargetScans = targetTxnTable.scanEvents.flatten.count {
+ case sources.EqualTo("dep", "hr") => true
+ case _ => false
+ }
+ val expectedNumMergeTargetScans = if (deltaMerge) 1 else 2
+ assert(numMergeTargetScans == expectedNumMergeTargetScans)
+
+ // check target table scans in view as MERGE source (pk < 10)
+ val numViewTargetScans = targetTxnTable.scanEvents.flatten.count {
+ case sources.LessThan("pk", 10L) => true
+ case _ => false
+ }
+ val expectedNumViewTargetScans = if (deltaMerge) 1 else 2
+ assert(numViewTargetScans == expectedNumViewTargetScans)
+
+ // check source table scans in view as MERGE source (no predicate)
+ val sourceTxnTable = txnTables(sourceNameAsString)
+ val expectedNumSourceScans = if (deltaMerge) 1 else 2
+ assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans)
+
+ // check txn state was propagated correctly
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, 150, "hr"), // update
+ Row(2, 200, "software"), // unchanged
+ Row(3, 300, "hr"), // unchanged
+ Row(4, 400, "pending"))) // new
+ }
+ }
+ }
+
+ test("merge into table rewritten as INSERT with transactional checks") {
+ withTable(sourceNameAsString) {
+ // create target table
+ createAndInitTable(
+ "pk INT, value STRING, dep STRING",
+ """{ "pk": 1, "value": "a", "dep": "hr" }
+ |{ "pk": 2, "value": "b", "dep": "finance" }
+ |""".stripMargin)
+
+ // create source table
+ sql(s"CREATE TABLE $sourceNameAsString (pk INT, value STRING, dep STRING)")
+ sql(s"INSERT INTO $sourceNameAsString VALUES (3, 'c', 'hr'), (4, 'd', 'software')")
+
+ // merge into target with only WHEN NOT MATCHED clauses (rewritten as insert)
+ val (txn, txnTables) = executeTransaction {
+ sql(
+ s"""MERGE INTO $tableNameAsString t
+ |USING $sourceNameAsString s
+ |ON t.pk = s.pk
+ |WHEN NOT MATCHED AND s.pk < 4 THEN
+ | INSERT (pk, value, dep) VALUES (s.pk, concat(s.value, '_low'), s.dep)
+ |WHEN NOT MATCHED AND s.pk >= 4 THEN
+ | INSERT (pk, value, dep) VALUES (s.pk, concat(s.value, '_high'), s.dep)
+ |""".stripMargin)
+ }
+
+ // check txn covers both tables and was properly committed and closed
+ assert(txn.currentState == Committed)
+ assert(txn.isClosed)
+ assert(txnTables.size == 2)
+ assert(table.version() == "2")
+
+ // check target table was scanned correctly
+ val targetTxnTable = txnTables(tableNameAsString)
+ assert(targetTxnTable.scanEvents.size == 1)
+
+ // check source table was scanned correctly
+ val sourceTxnTable = txnTables(sourceNameAsString)
+ assert(sourceTxnTable.scanEvents.size == 1)
+
+ // check txn state was propagated correctly
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, "a", "hr"), // unchanged
+ Row(2, "b", "finance"), // unchanged
+ Row(3, "c_low", "hr"), // inserted via first NOT MATCHED clause
+ Row(4, "d_high", "software"))) // inserted via second NOT MATCHED clause
+ }
+ }
+
test("merge into table with expression-based default values") {
val columns = Array(
Column.create("pk", IntegerType),
@@ -671,6 +975,131 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
}
}
+ test("merge with CTE with transactional checks") {
+ withTable(sourceNameAsString) {
+ // create target table
+ createAndInitTable(
+ "pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |""".stripMargin)
+
+ // create source table
+ sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)")
+ sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'hr'), (4, 400, 'finance')")
+
+ // merge into target table using CTE
+ val (txn, txnTables) = executeTransaction {
+ sql(
+ s"""WITH cte AS (
+ | SELECT pk, salary + 50 AS salary, dep
+ | FROM $sourceNameAsString
+ | WHERE salary > 100
+ |)
+ |MERGE INTO $tableNameAsString t
+ |USING cte s
+ |ON t.pk = s.pk AND t.dep = 'hr'
+ |WHEN MATCHED THEN
+ | UPDATE SET salary = s.salary
+ |WHEN NOT MATCHED THEN
+ | INSERT (pk, salary, dep) VALUES (s.pk, s.salary, 'pending')
+ |""".stripMargin)
+ }
+
+ // check txn covers both tables and was properly committed and closed
+ assert(txn.currentState == Committed)
+ assert(txn.isClosed)
+ assert(txnTables.size == 2)
+ assert(table.version() == "2")
+
+ // check target table was scanned correctly
+ val targetTxnTable = txnTables(tableNameAsString)
+ val expectedNumTargetScans = if (deltaMerge) 1 else 2
+ assert(targetTxnTable.scanEvents.size == expectedNumTargetScans)
+
+ // check target table scans as MERGE target (dep = 'hr')
+ val numMergeTargetScans = targetTxnTable.scanEvents.flatten.count {
+ case sources.EqualTo("dep", "hr") => true
+ case _ => false
+ }
+ assert(numMergeTargetScans == expectedNumTargetScans)
+
+ // check source table was scanned correctly
+ val sourceTxnTable = txnTables(sourceNameAsString)
+ val expectedNumSourceScans = if (deltaMerge) 1 else 2
+ assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans)
+
+ // check source table scans in CTE (salary > 100)
+ val numCteSourceScans = sourceTxnTable.scanEvents.flatten.count {
+ case sources.GreaterThan("salary", 100) => true
+ case _ => false
+ }
+ assert(numCteSourceScans == expectedNumSourceScans)
+
+ // check txn state was propagated correctly
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, 200, "hr"), // updated (150 + 50)
+ Row(2, 200, "software"), // unchanged
+ Row(3, 300, "hr"), // unchanged
+ Row(4, 450, "pending"))) // inserted (400 + 50)
+ }
+ }
+
+ test("merge with cached source and transactional checks") {
+ withTable(sourceNameAsString) {
+ // create target table
+ createAndInitTable(
+ "pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |""".stripMargin)
+
+ // create and populate source table
+ sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)")
+ sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'support'), (4, 400, 'finance')")
+
+      // Cache source table before the transaction. Make sure that when the transaction is
+      // active, the catalog still creates a transaction table.
+ spark.table(sourceNameAsString).cache()
+
+ try {
+ val (txn, txnTables) = executeTransaction {
+ sql(
+ s"""MERGE INTO $tableNameAsString t
+ |USING $sourceNameAsString s
+ |ON t.pk = s.pk AND t.dep = 'hr'
+ |WHEN MATCHED THEN
+ | UPDATE SET salary = s.salary, dep = s.dep
+ |WHEN NOT MATCHED THEN
+ | INSERT (pk, salary, dep) VALUES (s.pk, s.salary, 'pending')
+ |""".stripMargin)
+ }
+
+ assert(txn.currentState == Committed)
+ assert(txn.isClosed)
+ // both target and source must have been read through the transaction catalog
+ assert(txnTables.size == 2)
+ assert(table.version() == "2")
+ assert(txnTables(sourceNameAsString).scanEvents.nonEmpty)
+ assert(txnTables(tableNameAsString).scanEvents.nonEmpty)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, 150, "support"), // matched and updated
+ Row(2, 200, "software"), // unchanged
+ Row(3, 300, "hr"), // unchanged (no match in source)
+ Row(4, 400, "pending"))) // not matched, inserted
+ } finally {
+ spark.catalog.uncacheTable(sourceNameAsString)
+ }
+ }
+ }
+
test("merge with subquery as source") {
withTempView("source") {
createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
@@ -2223,6 +2652,8 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
sql(query)
}
assert(e.getMessage.contains("ON search condition of the MERGE statement"))
+ assert(catalog.lastTransaction.currentState == Aborted)
+ assert(catalog.lastTransaction.isClosed)
}
private def assertMetric(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala
index 8c51fb17b2cf4..565e209b021a6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala
@@ -28,16 +28,16 @@ import org.apache.spark.sql.catalyst.expressions.{DynamicPruningExpression, Expr
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReplaceData, WriteDelta}
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.METADATA_COL_ATTR_KEY
-import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Delete, Identifier, InMemoryRowLevelOperationTable, InMemoryRowLevelOperationTableCatalog, Insert, MetadataColumn, Operation, Reinsert, TableInfo, Update, Write}
+import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Delete, Identifier, InMemoryRowLevelOperationTable, InMemoryRowLevelOperationTableCatalog, Insert, MetadataColumn, Operation, Reinsert, Table, TableInfo, Txn, TxnTable, Update, Write}
import org.apache.spark.sql.connector.expressions.LogicalExpressions.{identity, reference}
import org.apache.spark.sql.connector.expressions.Transform
-import org.apache.spark.sql.execution.{InSubqueryExec, QueryExecution, SparkPlan}
+import org.apache.spark.sql.connector.write.RowLevelOperationTable
+import org.apache.spark.sql.execution.{InSubqueryExec, SparkPlan}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation, DataSourceV2ScanRelation}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{IntegerType, MetadataBuilder, StringType, StructField, StructType}
-import org.apache.spark.sql.util.QueryExecutionListener
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.ArrayImplicits._
@@ -82,6 +82,7 @@ abstract class RowLevelOperationSuiteBase
protected val namespace: Array[String] = Array("ns1")
protected val ident: Identifier = Identifier.of(namespace, "test_table")
protected val tableNameAsString: String = "cat." + ident.toString
+ protected val sourceNameAsString: String = "cat.ns1.source_table"
protected def extraTableProps: java.util.Map[String, String] = {
Collections.emptyMap[String, String]
@@ -135,22 +136,28 @@ abstract class RowLevelOperationSuiteBase
// executes an operation and keeps the executed plan
protected def executeAndKeepPlan(func: => Unit): SparkPlan = {
- var executedPlan: SparkPlan = null
+ withQueryExecutionsCaptured(spark)(func) match {
+ case Seq(qe) => stripAQEPlan(qe.executedPlan)
+ case other => fail(s"expected only one query execution, but got ${other.size}")
+ }
+ }
- val listener = new QueryExecutionListener {
- override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
- executedPlan = qe.executedPlan
- }
- override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
+ protected def executeTransaction(func: => Unit): (Txn, Map[String, TxnTable]) = {
+ val tables = withQueryExecutionsCaptured(spark)(func).flatMap { qe =>
+ collectWithSubqueries(qe.executedPlan) {
+ case BatchScanExec(_, _, _, _, table: TxnTable, _) => table
+ case BatchScanExec(_, _, _, _, RowLevelOperationTable(table: TxnTable, _), _) => table
}
}
- spark.listenerManager.register(listener)
-
- func
-
- sparkContext.listenerBus.waitUntilEmpty()
+ (catalog.lastTransaction, indexByName(tables))
+ }
- stripAQEPlan(executedPlan)
+ protected def indexByName[T <: Table](tables: Seq[T]): Map[String, T] = {
+ tables.groupBy(_.name).map {
+ case (name, sameNameTables) =>
+ val Seq(table) = sameNameTables.distinct
+ name -> table
+ }
}
// executes an operation and extracts conditions from ReplaceData or WriteDelta
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala
index ac0bf3bdba9ce..2ca4235f96b55 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala
@@ -18,8 +18,8 @@
package org.apache.spark.sql.connector
import org.apache.spark.SparkRuntimeException
-import org.apache.spark.sql.Row
-import org.apache.spark.sql.connector.catalog.{Column, ColumnDefaultValue, TableChange, TableInfo}
+import org.apache.spark.sql.{sources, AnalysisException, Row}
+import org.apache.spark.sql.connector.catalog.{Aborted, Column, ColumnDefaultValue, Committed, TableChange, TableInfo}
import org.apache.spark.sql.connector.expressions.{GeneralScalarExpression, LiteralValue}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{IntegerType, StringType}
@@ -28,6 +28,8 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase {
import testImplicits._
+ protected def deltaUpdate: Boolean = false
+
test("update table containing added column with default value") {
createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
"""{ "pk": 1, "salary": 100, "dep": "hr" }
@@ -762,4 +764,325 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase {
Row(5)))
}
}
+
+ test("update with analysis failure and transactional checks") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |""".stripMargin)
+
+ val exception = intercept[AnalysisException] {
+ sql(s"UPDATE $tableNameAsString SET invalid_column = -1")
+ }
+
+ assert(exception.getMessage.contains("invalid_column"))
+ assert(catalog.lastTransaction.currentState == Aborted)
+ assert(catalog.lastTransaction.isClosed)
+ }
+
+ test("update with CTE and transactional checks") {
+ // create table
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |""".stripMargin)
+
+ // create source table
+ sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)")
+ sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'hr'), (4, 400, 'finance')")
+
+ // update using CTE
+ val (txn, txnTables) = executeTransaction {
+ sql(
+ s"""WITH cte AS (
+ | SELECT pk, salary + 50 AS adjusted_salary, dep
+ | FROM $sourceNameAsString
+ | WHERE salary > 100
+ |)
+ |UPDATE $tableNameAsString t
+ |SET salary = -1
+ |WHERE t.dep = 'hr' AND EXISTS (SELECT 1 FROM cte WHERE cte.pk = t.pk)
+ |""".stripMargin)
+ }
+
+ // check txn was properly committed and closed
+ assert(txn.currentState == Committed)
+ assert(txn.isClosed)
+ assert(txnTables.size == 2)
+ assert(table.version() == "2")
+
+ // check target table was scanned correctly
+ val targetTxnTable = txnTables(tableNameAsString)
+ val expectedNumTargetScans = if (deltaUpdate) 1 else 3
+ assert(targetTxnTable.scanEvents.size == expectedNumTargetScans)
+
+ // check target table scans for UPDATE condition (dep = 'hr')
+ val numUpdateTargetScans = targetTxnTable.scanEvents.flatten.count {
+ case sources.EqualTo("dep", "hr") => true
+ case _ => false
+ }
+ assert(numUpdateTargetScans == expectedNumTargetScans)
+
+ // check source table was scanned correctly
+ val sourceTxnTable = txnTables(sourceNameAsString)
+ val expectedNumSourceScans = if (deltaUpdate) 1 else 4
+ assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans)
+
+ // check source table scans in CTE (salary > 100)
+ val numCteSourceScans = sourceTxnTable.scanEvents.flatten.count {
+ case sources.GreaterThan("salary", 100) => true
+ case _ => false
+ }
+ assert(numCteSourceScans == expectedNumSourceScans)
+
+ // check txn state was propagated correctly
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, -1, "hr"), // updated
+ Row(2, 200, "software"), // unchanged
+ Row(3, 300, "hr"))) // unchanged (no matching pk in source)
+ }
+
+ test("update with subquery on source table and transactional checks") {
+ // create target table
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |""".stripMargin)
+
+ // create source table
+ sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)")
+ sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'hr'), (4, 400, 'finance')")
+
+ // update using an uncorrelated IN subquery that reads from a transactional catalog table
+ val (txn, txnTables) = executeTransaction {
+ sql(
+ s"""UPDATE $tableNameAsString
+ |SET salary = -1
+ |WHERE pk IN (SELECT pk FROM $sourceNameAsString WHERE dep = 'hr')
+ |""".stripMargin)
+ }
+
+ // check txn was properly committed and closed
+ assert(txn.currentState == Committed)
+ assert(txn.isClosed)
+ assert(txnTables.size == 2)
+ assert(table.version() == "2")
+
+ // check source table was scanned correctly (dep = 'hr' filter in the subquery)
+ val sourceTxnTable = txnTables(sourceNameAsString)
+ val expectedNumSourceScans = if (deltaUpdate) 1 else 4
+ assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans)
+
+ val numSubquerySourceScans = sourceTxnTable.scanEvents.flatten.count {
+ case sources.EqualTo("dep", "hr") => true
+ case _ => false
+ }
+ assert(numSubquerySourceScans == expectedNumSourceScans)
+
+ // check target table was scanned correctly
+ val targetTxnTable = txnTables(tableNameAsString)
+ val expectedNumTargetScans = if (deltaUpdate) 1 else 3
+ assert(targetTxnTable.scanEvents.size == expectedNumTargetScans)
+
+ // check txn state was propagated correctly
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, -1, "hr"), // updated (pk 1 is in subquery result)
+ Row(2, 200, "software"), // unchanged
+ Row(3, 300, "hr"))) // unchanged (pk 3 not in subquery result)
+ }
+
+ test("update with uncorrelated scalar subquery on source table and transactional checks") {
+ // create target table
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |""".stripMargin)
+
+ // create source table
+ sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)")
+ sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'hr'), (4, 400, 'finance')")
+
+ // update using an uncorrelated scalar subquery in the SET clause that reads from a
+ // transactional catalog table; scalar subqueries are executed as SubqueryExec at runtime
+ // and cannot be rewritten as joins
+ val (txn, txnTables) = executeTransaction {
+ sql(
+ s"""UPDATE $tableNameAsString
+ |SET salary = (SELECT max(salary) FROM $sourceNameAsString WHERE dep = 'hr')
+ |WHERE dep = 'hr'
+ |""".stripMargin)
+ }
+
+ // check txn was properly committed and closed
+ assert(txn.currentState == Committed)
+ assert(txn.isClosed)
+ assert(txnTables.size == 2)
+ assert(table.version() == "2")
+
+ // check source table was scanned via the transaction catalog
+ val sourceTxnTable = txnTables(sourceNameAsString)
+ assert(sourceTxnTable.scanEvents.nonEmpty)
+ assert(sourceTxnTable.scanEvents.flatten.exists {
+ case sources.EqualTo("dep", "hr") => true
+ case _ => false
+ })
+
+ // check target table was scanned via the transaction catalog
+ val targetTxnTable = txnTables(tableNameAsString)
+ assert(targetTxnTable.scanEvents.nonEmpty)
+
+ // check txn state was propagated correctly
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, 150, "hr"), // updated (max salary in source for 'hr' is 150)
+ Row(2, 200, "software"), // unchanged
+ Row(3, 150, "hr"))) // updated
+ }
+
+ test("update with constraint violation and transactional checks") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |""".stripMargin)
+
+ val exception = intercept[SparkRuntimeException] {
+ executeTransaction {
+ sql(
+ s"""UPDATE $tableNameAsString
+ |SET pk = NULL
+ |WHERE dep = 'hr'
+ |""".stripMargin) // NULL violates NOT NULL constraint
+ }
+ }
+
+ assert(exception.getMessage.contains("NOT_NULL_ASSERT_VIOLATION"))
+ assert(catalog.lastTransaction.currentState == Aborted)
+ assert(catalog.lastTransaction.isClosed)
+ }
+
+ test("update using view with transactional checks") {
+ withView("temp_view") {
+ // create target table
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |""".stripMargin)
+
+ // create source table
+ sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT)")
+ sql(s"INSERT INTO $sourceNameAsString (pk, salary) VALUES (1, 150), (4, 400)")
+
+ // create view on top of source and target tables
+ sql(
+ s"""CREATE VIEW temp_view AS
+ |SELECT s.pk, s.salary, t.dep
+ |FROM $sourceNameAsString s
+ |LEFT JOIN (
+ | SELECT * FROM $tableNameAsString WHERE pk < 10
+ |) t ON s.pk = t.pk
+ |""".stripMargin)
+
+ // update target table using view
+ val (txn, txnTables) = executeTransaction {
+ sql(
+ s"""UPDATE $tableNameAsString t
+ |SET salary = -1
+ |WHERE t.dep = 'hr' AND EXISTS (SELECT 1 FROM temp_view v WHERE v.pk = t.pk)
+ |""".stripMargin)
+ }
+
+ // check txn covers both tables and was properly committed and closed
+ assert(txn.currentState == Committed)
+ assert(txn.isClosed)
+ assert(txnTables.size == 2)
+ assert(table.version() == "2")
+
+ // check target table was scanned correctly
+ val targetTxnTable = txnTables(tableNameAsString)
+ val expectedNumTargetScans = if (deltaUpdate) 2 else 7
+ assert(targetTxnTable.scanEvents.size == expectedNumTargetScans)
+
+ // check target table scans as UPDATE target (dep = 'hr')
+ val numUpdateTargetScans = targetTxnTable.scanEvents.flatten.count {
+ case sources.EqualTo("dep", "hr") => true
+ case _ => false
+ }
+ val expectedNumUpdateTargetScans = if (deltaUpdate) 1 else 3
+ assert(numUpdateTargetScans == expectedNumUpdateTargetScans)
+
+ // check target table scans in view as source (pk < 10)
+ val numViewTargetScans = targetTxnTable.scanEvents.flatten.count {
+ case sources.LessThan("pk", 10L) => true
+ case _ => false
+ }
+ val expectedNumViewTargetScans = if (deltaUpdate) 1 else 4
+ assert(numViewTargetScans == expectedNumViewTargetScans)
+
+ // check source table scans in view
+ val sourceTxnTable = txnTables(sourceNameAsString)
+ val expectedNumSourceScans = if (deltaUpdate) 1 else 4
+ assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans)
+
+ // check txn state was propagated correctly
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, -1, "hr"), // updated from view
+ Row(2, 200, "software"), // unchanged
+ Row(3, 300, "hr"))) // unchanged (no matching pk in source)
+ }
+ }
+
+ test("df.explain() on update with transactional checks") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |""".stripMargin)
+
+ // sql() is lazy, but explain() forces executedPlan.
+ sql(s"UPDATE $tableNameAsString SET salary = -1 WHERE dep = 'hr'").explain()
+
+ assert(catalog.lastTransaction != null)
+ assert(catalog.lastTransaction.currentState == Committed)
+ assert(catalog.lastTransaction.isClosed)
+ assert(table.version() == "2")
+
+ // The UPDATE was actually executed, not just planned.
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, -1, "hr"), // updated
+ Row(2, 200, "software"))) // unchanged
+ }
+
+ test("EXPLAIN UPDATE SQL with transactional checks") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |""".stripMargin)
+
+ // EXPLAIN UPDATE only plans the command, it does not execute the write.
+ sql(s"EXPLAIN UPDATE $tableNameAsString SET salary = -1 WHERE dep = 'hr'")
+
+ // A transaction should not have started at all.
+ assert(catalog.transaction === null)
+
+ // The UPDATE was not executed. Data is unchanged.
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, 100, "hr"),
+ Row(2, 200, "software")))
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AnalyzerBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AnalyzerBenchmark.scala
new file mode 100644
index 0000000000000..141d5966b4b6c
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AnalyzerBenchmark.scala
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.benchmark
+
+import scala.concurrent.duration._
+
+import org.apache.spark.benchmark.Benchmark
+import org.apache.spark.sql.classic
+import org.apache.spark.sql.execution.QueryExecution
+
+/**
+ * Benchmark to measure the overhead of cloning the analyzer for transactional query execution.
+ * Each transactional query creates a new [[Analyzer]] instance via
+ * [[Analyzer.withCatalogManager]], which shares all rules with the original but carries a
+ * per-query [[org.apache.spark.sql.connector.catalog.CatalogManager]]. This benchmark checks
+ * whether the cloning introduces measurable overhead.
+ *
+ * To run this benchmark:
+ * {{{
+ * build/sbt "sql/Test/runMain