diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 8c032d31cff61..49d941cfb6be1 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -8192,6 +8192,23 @@ ], "sqlState" : "0A000" }, + "UNSUPPORTED_STREAMING_SCHEMA_EVOLUTION" : { + "message" : [ + "Schema evolution is not supported for this streaming write:" + ], + "subClass" : { + "CONTINUOUS_TRIGGER" : { + "message" : [ + "Continuous triggers are not supported. Use a micro-batch trigger instead." + ] + }, + "NOT_V2_TABLE" : { + "message" : [ + "The sink is not a V2 table. Schema evolution requires a V2 table that supports the AUTOMATIC_SCHEMA_EVOLUTION capability." + ] + } + } + }, "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY" : { "message" : [ "Unsupported subquery expression:" diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index cb5ecc728c441..c0ff3d19035b8 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -170,6 +170,19 @@ abstract class DataStreamWriter[T] extends WriteConfigMethods[DataStreamWriter[T foreachBatch((batchDs: Dataset[T], batchId: Long) => function.call(batchDs, batchId)) } + /** + * Enables automatic schema evolution for the streaming write. When enabled, if the source + * schema has columns not present in the sink table (or type changes), the sink table schema + * will be evolved to accommodate the new schema before data is written. The sink table must + * support the `AUTOMATIC_SCHEMA_EVOLUTION` capability. + * + * Schema evolution is applied at query analysis time: when the streaming query is started + * (or restarted after failure), the table schema is evolved if needed. + * + * @since 4.2.0 + */ + def withSchemaEvolution(): this.type + /** * Starts the execution of the streaming query, which will continually output results to the * given path as new data arrives. The returned diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/WriteToStream.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/WriteToStream.scala index 884a4165d077e..44ae76611483a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/WriteToStream.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/WriteToStream.scala @@ -25,6 +25,8 @@ import org.apache.spark.sql.streaming.OutputMode /** * Used to create a [[StreamExecution]]. + * + * @param withSchemaEvolution Whether to evolve the sink table schema to match the source. */ case class WriteToStream( name: String, @@ -34,7 +36,8 @@ case class WriteToStream( deleteCheckpointOnStop: Boolean, inputQuery: LogicalPlan, catalogAndIdent: Option[(TableCatalog, Identifier)] = None, - catalogTable: Option[CatalogTable]) extends UnaryNode { + catalogTable: Option[CatalogTable], + withSchemaEvolution: Boolean) extends UnaryNode { override def isStreaming: Boolean = true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/WriteToStreamStatement.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/WriteToStreamStatement.scala index 7015d0dd3b2cc..99c0c158fcc15 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/WriteToStreamStatement.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/WriteToStreamStatement.scala @@ -44,6 +44,7 @@ import org.apache.spark.sql.streaming.{OutputMode, Trigger} * for unsupported operations, which happens during resolution. * @param inputQuery The analyzed query plan from the streaming DataFrame. * @param catalogAndIdent Catalog and identifier for the sink, set when it is a V2 catalog table + * @param withSchemaEvolution Whether to evolve the sink table schema to match the source. */ case class WriteToStreamStatement( userSpecifiedName: Option[String], @@ -55,8 +56,9 @@ case class WriteToStreamStatement( hadoopConf: Configuration, trigger: Trigger, inputQuery: LogicalPlan, - catalogAndIdent: Option[(TableCatalog, Identifier)] = None, - catalogTable: Option[CatalogTable] = None) extends UnaryNode { + catalogAndIdent: Option[(TableCatalog, Identifier)], + catalogTable: Option[CatalogTable], + withSchemaEvolution: Boolean) extends UnaryNode { override def isStreaming: Boolean = true diff --git a/sql/connect/common/src/main/protobuf/spark/connect/commands.proto b/sql/connect/common/src/main/protobuf/spark/connect/commands.proto index c22e76e3542f5..29220a44c16bd 100644 --- a/sql/connect/common/src/main/protobuf/spark/connect/commands.proto +++ b/sql/connect/common/src/main/protobuf/spark/connect/commands.proto @@ -249,6 +249,9 @@ message WriteStreamOperationStart { // (Optional) Columns used for clustering the table. repeated string clustering_column_names = 15; + + // (Optional) Enable automatic schema evolution for the streaming write. + bool with_schema_evolution = 16; } message StreamingForeachFunction { diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataStreamWriter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataStreamWriter.scala index ffa11b5d7ab0d..a1d4f6d2eb5e1 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataStreamWriter.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataStreamWriter.scala @@ -82,6 +82,12 @@ final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) this } + /** @inheritdoc */ + def withSchemaEvolution(): this.type = { + sinkBuilder.setWithSchemaEvolution(true) + this + } + /** @inheritdoc */ def format(source: String): this.type = { sinkBuilder.setFormat(source) diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 1b47f2f56a476..46b5ccc41293c 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -3491,6 +3491,10 @@ class SparkConnectPlanner( writer.queryName(writeOp.getQueryName) } + if (writeOp.getWithSchemaEvolution) { + writer.withSchemaEvolution() + } + if (writeOp.hasForeachWriter) { if (writeOp.getForeachWriter.hasPythonFunction) { val foreach = writeOp.getForeachWriter.getPythonFunction diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataStreamWriter.scala index 38483395ec8c5..a80c136fdf000 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataStreamWriter.scala @@ -83,6 +83,12 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) extends streaming.D this } + /** @inheritdoc */ + def withSchemaEvolution(): this.type = { + this.schemaEvolution = true + this + } + /** @inheritdoc */ def format(source: String): this.type = { this.source = source @@ -205,7 +211,8 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) extends streaming.D import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ tableInstance match { case t: SupportsWrite if t.supports(STREAMING_WRITE) => - startQuery(t, extraOptions, catalogAndIdent = Some(catalog.asTableCatalog, identifier)) + startQuery(t, extraOptions, catalogAndIdent = Some(catalog.asTableCatalog, identifier), + withSchemaEvolution = schemaEvolution) case t: V2TableWithV1Fallback => writeToV1Table(t.v1Table) case t: V1Table => @@ -244,7 +251,8 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) extends streaming.D throw QueryCompilationErrors.sourceNotSupportedWithContinuousTriggerError(source) } val sink = new ForeachBatchSink[T](foreachBatchWriter, ds.exprEnc) - startQuery(sink, extraOptions, catalogTable = catalogTable) + startQuery(sink, extraOptions, catalogTable = catalogTable, + withSchemaEvolution = schemaEvolution) } else { val cls = DataSource.lookupDataSource(source, ds.sparkSession.sessionState.conf) val disabledSources = @@ -290,7 +298,8 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) extends streaming.D createV1Sink(optionsWithPath) } - startQuery(sink, optionsWithPath, catalogTable = catalogTable) + startQuery(sink, optionsWithPath, catalogTable = catalogTable, + withSchemaEvolution = schemaEvolution) } } @@ -299,7 +308,8 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) extends streaming.D newOptions: CaseInsensitiveMap[String], recoverFromCheckpoint: Boolean = true, catalogAndIdent: Option[(TableCatalog, Identifier)] = None, - catalogTable: Option[CatalogTable] = None): StreamingQuery = { + catalogTable: Option[CatalogTable] = None, + withSchemaEvolution: Boolean = false): StreamingQuery = { if (trigger.isInstanceOf[RealTimeTrigger]) { RealTimeModeAllowlist.checkAllowedSink( sink, @@ -321,7 +331,8 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) extends streaming.D recoverFromCheckpointLocation = recoverFromCheckpoint, trigger = trigger, catalogAndIdent = catalogAndIdent, - catalogTable = catalogTable) + catalogTable = catalogTable, + withSchemaEvolution = withSchemaEvolution) } private def createV1Sink(optionsWithPath: CaseInsensitiveMap[String]): Sink = { @@ -444,6 +455,8 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) extends streaming.D private var partitioningColumns: Option[Seq[String]] = None private var clusteringColumns: Option[Seq[String]] = None + + private var schemaEvolution: Boolean = false } object DataStreamWriter { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/StreamingQueryManager.scala index 72ae3b21d662a..1313eb05187a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/StreamingQueryManager.scala @@ -24,7 +24,7 @@ import javax.annotation.concurrent.GuardedBy import scala.collection.mutable import scala.jdk.CollectionConverters._ -import org.apache.spark.SparkIllegalArgumentException +import org.apache.spark.{SparkIllegalArgumentException, SparkUnsupportedOperationException} import org.apache.spark.annotation.Evolving import org.apache.spark.internal.Logging import org.apache.spark.internal.LogKeys.{CLASS_NAME, QUERY_ID, RUN_ID} @@ -186,7 +186,8 @@ class StreamingQueryManager private[sql] ( trigger: Trigger, triggerClock: Clock, catalogAndIdent: Option[(TableCatalog, Identifier)] = None, - catalogTable: Option[CatalogTable] = None): StreamingQueryWrapper = { + catalogTable: Option[CatalogTable] = None, + withSchemaEvolution: Boolean = false): StreamingQueryWrapper = { val analyzedPlan = df.queryExecution.analyzed df.queryExecution.assertAnalyzed() @@ -216,7 +217,8 @@ class StreamingQueryManager private[sql] ( trigger, analyzedPlan, catalogAndIdent, - catalogTable) + catalogTable, + withSchemaEvolution) val analyzedStreamWritePlan = sparkSession.sessionState.executePlan(dataStreamWritePlan).analyzed @@ -224,6 +226,12 @@ class StreamingQueryManager private[sql] ( (sink, trigger) match { case (_: SupportsWrite, trigger: ContinuousTrigger) => + if (withSchemaEvolution) { + throw new SparkUnsupportedOperationException( + errorClass = + "UNSUPPORTED_STREAMING_SCHEMA_EVOLUTION.CONTINUOUS_TRIGGER", + messageParameters = Map.empty[String, String]) + } new StreamingQueryWrapper(new ContinuousExecution( sparkSession, trigger, @@ -287,7 +295,8 @@ class StreamingQueryManager private[sql] ( trigger: Trigger = Trigger.ProcessingTime(0), triggerClock: Clock = new SystemClock(), catalogAndIdent: Option[(TableCatalog, Identifier)] = None, - catalogTable: Option[CatalogTable] = None): StreamingQuery = { + catalogTable: Option[CatalogTable] = None, + withSchemaEvolution: Boolean = false): StreamingQuery = { val query = createQuery( userSpecifiedName, userSpecifiedCheckpointLocation, @@ -300,7 +309,8 @@ class StreamingQueryManager private[sql] ( trigger, triggerClock, catalogAndIdent, - catalogTable) + catalogTable, + withSchemaEvolution) // scalastyle:on argcount // The following code block checks if a stream with the same name or id is running. Then it diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala index d8e871bcf4824..83f9bc5e17de1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala @@ -91,7 +91,7 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper { o.copy(write = Some(write), query = newQuery) case WriteToMicroBatchDataSource( - relationOpt, table, query, queryId, options, outputMode, Some(batchId)) => + relationOpt, table, query, queryId, options, outputMode, _, Some(batchId)) => val writeOptions = mergeOptions( options, relationOpt.map(r => r.options.asCaseSensitiveMap.asScala.toMap).getOrElse(Map.empty)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala index 973af04e04307..b4724c3571b1f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala @@ -26,7 +26,7 @@ import scala.util.control.NonFatal import org.apache.hadoop.fs.Path -import org.apache.spark.{SparkIllegalArgumentException, SparkIllegalStateException} +import org.apache.spark.{SparkIllegalArgumentException, SparkIllegalStateException, SparkUnsupportedOperationException} import org.apache.spark.internal.LogKeys import org.apache.spark.internal.LogKeys._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -354,13 +354,20 @@ class MicroBatchExecution( } WriteToMicroBatchDataSource( relationOpt, - table = s, + sinkTable = s, query = _logicalPlan, queryId = id.toString, extraOptions, - outputMode) + outputMode, + withSchemaEvolution = plan.withSchemaEvolution) case s: Sink => + if (plan.withSchemaEvolution) { + throw new SparkUnsupportedOperationException( + errorClass = + "UNSUPPORTED_STREAMING_SCHEMA_EVOLUTION.NOT_V2_TABLE", + messageParameters = Map.empty[String, String]) + } // SinkV1 is not compatible with Real-Time Mode due to API limitations. // SinkV1 does not support writing outputs row by row. if (trigger.isInstanceOf[RealTimeTrigger]) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/ResolveWriteToStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/ResolveWriteToStream.scala index ff0d71d0f0759..d7815c43055c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/ResolveWriteToStream.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/ResolveWriteToStream.scala @@ -72,7 +72,8 @@ object ResolveWriteToStream extends Rule[LogicalPlan] { deleteCheckpointOnStop, s.inputQuery, s.catalogAndIdent, - s.catalogTable) + s.catalogTable, + s.withSchemaEvolution) } def resolveCheckpointLocation(s: WriteToStreamStatement): (String, Boolean) = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala index 0a33093dcbcea..57cb1f3e15562 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala @@ -17,10 +17,13 @@ package org.apache.spark.sql.execution.streaming.sources +import org.apache.spark.sql.catalyst.analysis.{NamedRelation, ResolveSchemaEvolution} import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode} -import org.apache.spark.sql.connector.catalog.SupportsWrite -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode, WriteWithSchemaEvolution} +import org.apache.spark.sql.catalyst.trees.TreePattern.COMMAND +import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.connector.catalog.{SupportsWrite, TableChange, TableWritePrivilege} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, ExtractV2CatalogAndIdentifier} import org.apache.spark.sql.streaming.OutputMode /** @@ -29,19 +32,59 @@ import org.apache.spark.sql.streaming.OutputMode * Note that this logical plan does not have a corresponding physical plan, as it will be converted * to [[org.apache.spark.sql.execution.datasources.v2.WriteToDataSourceV2 WriteToDataSourceV2]] * with [[MicroBatchWrite]] before execution. + * + * @param withSchemaEvolution Whether to evolve the sink table schema to match the source. */ case class WriteToMicroBatchDataSource( relation: Option[DataSourceV2Relation], - table: SupportsWrite, + sinkTable: SupportsWrite, query: LogicalPlan, queryId: String, writeOptions: Map[String, String], outputMode: OutputMode, + override val withSchemaEvolution: Boolean, batchId: Option[Long] = None) - extends UnaryNode { + extends UnaryNode with WriteWithSchemaEvolution { override def child: LogicalPlan = query override def output: Seq[Attribute] = Nil + final override val nodePatterns = Seq(COMMAND) + + override def table: LogicalPlan = relation.getOrElse { + throw new IllegalStateException( + "Cannot access table for schema evolution: no DataSourceV2Relation is set.") + } + + override lazy val schemaEvolutionReady: Boolean = + relation.exists(_.resolved) && query.resolved + + override def pendingSchemaChanges: Seq[TableChange] = { + if (relation.isEmpty || !schemaEvolutionEnabled || !schemaEvolutionReady) { + return Seq.empty + } + + val currentRelation = relation.get match { + case r @ ExtractV2CatalogAndIdentifier(catalog, ident) => + // Loading the current table from the catalog ensures we don't use a stale schema. + val currentTable = catalog.loadTable(ident) + r.copy( + table = currentTable, + output = DataTypeUtils.toAttributes(currentTable.columns)) + case r => r + } + ResolveSchemaEvolution.computeSupportedSchemaChanges( + currentRelation, query.schema, isByName = true).toSeq + } + + override val writePrivileges: Set[TableWritePrivilege] = Set(TableWritePrivilege.INSERT) + + override def withNewTable(newTable: NamedRelation): WriteToMicroBatchDataSource = { + val newRelation = newTable.asInstanceOf[DataSourceV2Relation] + copy( + relation = Some(newRelation), + sinkTable = newRelation.table.asInstanceOf[SupportsWrite]) + } + def withNewBatchId(batchId: Long): WriteToMicroBatchDataSource = { copy(batchId = Some(batchId)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/StreamingSchemaEvolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/StreamingSchemaEvolutionSuite.scala new file mode 100644 index 0000000000000..058a796eeb904 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/StreamingSchemaEvolutionSuite.scala @@ -0,0 +1,696 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.SparkUnsupportedOperationException +import org.apache.spark.sql.Row +import org.apache.spark.sql.connector.catalog.InMemoryTableCatalog +import org.apache.spark.sql.execution.streaming.runtime.MemoryStream +import org.apache.spark.sql.streaming.{OutputMode, StreamTest, Trigger} +import org.apache.spark.sql.types._ + +/** + * Tests for schema evolution in streaming writes using DataSourceV2. + * + * Schema evolution happens at query analysis time: when a streaming query is started (or + * restarted) and the source schema has columns not present in the sink table, the table is + * evolved to include those columns before any data is written. + */ +class StreamingSchemaEvolutionSuite + extends StreamTest with BeforeAndAfter { + + import testImplicits._ + + private val catalogName = "testcat" + private val namespace = "ns" + private val tableIdent = s"$catalogName.$namespace.test_table" + + before { + spark.conf.set( + s"spark.sql.catalog.$catalogName", classOf[InMemoryTableCatalog].getName) + sql(s"CREATE NAMESPACE IF NOT EXISTS $catalogName.$namespace") + } + + after { + spark.sessionState.catalogManager.reset() + spark.sessionState.conf.unsetConf(s"spark.sql.catalog.$catalogName") + } + + test("streaming write with extra source column adds column to table") { + withTable(tableIdent) { + withTempDir { checkpointDir => + sql(s"CREATE TABLE $tableIdent (id INT, data STRING)") + + val input = MemoryStream[(Int, String, Double)] + val df = input.toDF().toDF("id", "data", "amount") + + val query = df.writeStream + .withSchemaEvolution() + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .outputMode(OutputMode.Append()) + .toTable(tableIdent) + + try { + input.addData((1, "a", 10.0), (2, "b", 20.0)) + query.processAllAvailable() + } finally { + query.stop() + } + + val result = spark.table(tableIdent) + checkAnswer(result, Seq(Row(1, "a", 10.0), Row(2, "b", 20.0))) + assert(result.schema == StructType(Seq( + StructField("id", IntegerType), + StructField("data", StringType), + StructField("amount", DoubleType)))) + } + } + } + + test("streaming write with matching schema - no evolution needed") { + withTable(tableIdent) { + withTempDir { checkpointDir => + sql(s"CREATE TABLE $tableIdent (id INT, data STRING)") + + val input = MemoryStream[(Int, String)] + val df = input.toDF().toDF("id", "data") + + val query = df.writeStream + .withSchemaEvolution() + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .outputMode(OutputMode.Append()) + .toTable(tableIdent) + + try { + input.addData((1, "a"), (2, "b")) + query.processAllAvailable() + } finally { + query.stop() + } + + val result = spark.table(tableIdent) + checkAnswer(result, Seq(Row(1, "a"), Row(2, "b"))) + assert(result.schema == StructType(Seq( + StructField("id", IntegerType), + StructField("data", StringType)))) + } + } + } + + test("streaming write evolves schema then processes multiple batches") { + withTable(tableIdent) { + withTempDir { checkpointDir => + sql(s"CREATE TABLE $tableIdent (id INT, data STRING)") + + val input = MemoryStream[(Int, String, Double)] + val df = input.toDF().toDF("id", "data", "amount") + + val query = df.writeStream + .withSchemaEvolution() + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .outputMode(OutputMode.Append()) + .toTable(tableIdent) + + try { + // First batch triggers schema evolution. + input.addData((1, "a", 10.0)) + query.processAllAvailable() + + // Second batch should work without re-triggering evolution. + // Note: InMemoryTableCatalog creates new table instances on alterTable, + // so the second batch writes to the old instance. We just verify + // that the second batch completes without errors. + input.addData((2, "b", 20.0)) + query.processAllAvailable() + } finally { + query.stop() + } + + val result = spark.table(tableIdent) + assert(result.schema == StructType(Seq( + StructField("id", IntegerType), + StructField("data", StringType), + StructField("amount", DoubleType)))) + } + } + } + + test("streaming write with multiple extra columns") { + withTable(tableIdent) { + withTempDir { checkpointDir => + sql(s"CREATE TABLE $tableIdent (id INT)") + + val input = MemoryStream[(Int, String, Double)] + val df = input.toDF().toDF("id", "data", "amount") + + val query = df.writeStream + .withSchemaEvolution() + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .outputMode(OutputMode.Append()) + .toTable(tableIdent) + + try { + input.addData((1, "a", 10.0)) + query.processAllAvailable() + } finally { + query.stop() + } + + val result = spark.table(tableIdent) + checkAnswer(result, Seq(Row(1, "a", 10.0))) + assert(result.schema == StructType(Seq( + StructField("id", IntegerType), + StructField("data", StringType), + StructField("amount", DoubleType)))) + } + } + } + + test("table without AUTOMATIC_SCHEMA_EVOLUTION capability - schema not evolved") { + withTable(tableIdent) { + withTempDir { checkpointDir => + sql( + s"""CREATE TABLE $tableIdent (id INT, data STRING) + |TBLPROPERTIES ('auto-schema-evolution' = 'false')""".stripMargin) + + val input = MemoryStream[(Int, String, Double)] + val df = input.toDF().toDF("id", "data", "amount") + + val query = df.writeStream + .withSchemaEvolution() + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .outputMode(OutputMode.Append()) + .toTable(tableIdent) + + try { + input.addData((1, "a", 10.0)) + query.processAllAvailable() + } finally { + query.stop() + } + + // Schema should NOT have been evolved since the table lacks the capability. + val result = spark.table(tableIdent) + assert(result.schema == StructType(Seq( + StructField("id", IntegerType), + StructField("data", StringType)))) + } + } + } + + test("streaming write without withSchemaEvolution - schema not evolved") { + withTable(tableIdent) { + withTempDir { checkpointDir => + sql(s"CREATE TABLE $tableIdent (id INT, data STRING)") + + val input = MemoryStream[(Int, String, Double)] + val df = input.toDF().toDF("id", "data", "amount") + + // No .withSchemaEvolution() call. + val query = df.writeStream + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .outputMode(OutputMode.Append()) + .toTable(tableIdent) + + try { + input.addData((1, "a", 10.0)) + query.processAllAvailable() + } finally { + query.stop() + } + + // Schema should NOT have been evolved since withSchemaEvolution was not called. + val result = spark.table(tableIdent) + assert(result.schema == StructType(Seq( + StructField("id", IntegerType), + StructField("data", StringType)))) + } + } + } + + test("streaming restart after schema evolution preserves data") { + withTable(tableIdent) { + withTempDir { checkpointDir => + sql(s"CREATE TABLE $tableIdent (id INT, data STRING)") + + // First query: write with matching schema. + val input1 = MemoryStream[(Int, String)] + val df1 = input1.toDF().toDF("id", "data") + + val query1 = df1.writeStream + .withSchemaEvolution() + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .outputMode(OutputMode.Append()) + .toTable(tableIdent) + + try { + input1.addData((1, "a"), (2, "b")) + query1.processAllAvailable() + } finally { + query1.stop() + } + + checkAnswer(spark.table(tableIdent), Seq(Row(1, "a"), Row(2, "b"))) + + // Second query: write with extra column, triggers schema evolution. + val input2 = MemoryStream[(Int, String, Double)] + val df2 = input2.toDF().toDF("id", "data", "amount") + + val query2 = df2.writeStream + .withSchemaEvolution() + .option( + "checkpointLocation", + s"${checkpointDir.getCanonicalPath}_2") + .outputMode(OutputMode.Append()) + .toTable(tableIdent) + + try { + input2.addData((3, "c", 30.0)) + query2.processAllAvailable() + } finally { + query2.stop() + } + + val result = spark.table(tableIdent) + checkAnswer(result, Seq( + Row(1, "a", null), + Row(2, "b", null), + Row(3, "c", 30.0))) + assert(result.schema == StructType(Seq( + StructField("id", IntegerType), + StructField("data", StringType), + StructField("amount", DoubleType)))) + } + } + } + + test("schema evolution on restart after sink table altered between batches") { + withTable(tableIdent) { + withTempDir { checkpointDir => + sql(s"CREATE TABLE $tableIdent (id INT, data STRING)") + + // First query: write with matching schema. + val input1 = MemoryStream[(Int, String)] + val df1 = input1.toDF().toDF("id", "data") + + val query1 = df1.writeStream + .withSchemaEvolution() + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .outputMode(OutputMode.Append()) + .toTable(tableIdent) + + try { + input1.addData((1, "a")) + query1.processAllAvailable() + } finally { + query1.stop() + } + + // Alter the table externally to add a column, simulating a schema change + // that happened while the query was down. + sql(s"ALTER TABLE $tableIdent ADD COLUMN amount DOUBLE") + + val input2 = MemoryStream[(Int, String, Double)] + val df2 = input2.toDF().toDF("id", "data", "amount") + + val query2 = df2.writeStream + .withSchemaEvolution() + .option( + "checkpointLocation", + s"${checkpointDir.getCanonicalPath}_2") + .outputMode(OutputMode.Append()) + .toTable(tableIdent) + + try { + input2.addData((2, "b", 20.0)) + query2.processAllAvailable() + } finally { + query2.stop() + } + + val result = spark.table(tableIdent) + assert(result.schema == StructType(Seq( + StructField("id", IntegerType), + StructField("data", StringType), + StructField("amount", DoubleType)))) + } + } + } + + test("schema evolution with Trigger.Once") { + withTable(tableIdent) { + withTempDir { checkpointDir => + sql(s"CREATE TABLE $tableIdent (id INT, data STRING)") + + val input = MemoryStream[(Int, String, Double)] + val df = input.toDF().toDF("id", "data", "amount") + input.addData((1, "a", 10.0), (2, "b", 20.0)) + + val query = df.writeStream + .withSchemaEvolution() + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .trigger(Trigger.Once()) + .toTable(tableIdent) + + try { + query.processAllAvailable() + } finally { + query.stop() + } + + val result = spark.table(tableIdent) + checkAnswer(result, Seq(Row(1, "a", 10.0), Row(2, "b", 20.0))) + assert(result.schema == StructType(Seq( + StructField("id", IntegerType), + StructField("data", StringType), + StructField("amount", DoubleType)))) + } + } + } + + test("schema evolution with Trigger.AvailableNow") { + withTable(tableIdent) { + withTempDir { checkpointDir => + sql(s"CREATE TABLE $tableIdent (id INT, data STRING)") + + val input = MemoryStream[(Int, String, Double)] + val df = input.toDF().toDF("id", "data", "amount") + input.addData((1, "a", 10.0)) + + val query = df.writeStream + .withSchemaEvolution() + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .trigger(Trigger.AvailableNow()) + .toTable(tableIdent) + + try { + query.processAllAvailable() + } finally { + query.stop() + } + + val result = spark.table(tableIdent) + checkAnswer(result, Seq(Row(1, "a", 10.0))) + assert(result.schema == StructType(Seq( + StructField("id", IntegerType), + StructField("data", StringType), + StructField("amount", DoubleType)))) + } + } + } + + test("incremental schema evolution across multiple restarts") { + withTable(tableIdent) { + withTempDir { checkpointDir => + sql(s"CREATE TABLE $tableIdent (id INT)") + + // First query: add "data" column. + val input1 = MemoryStream[(Int, String)] + val df1 = input1.toDF().toDF("id", "data") + + val query1 = df1.writeStream + .withSchemaEvolution() + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .outputMode(OutputMode.Append()) + .toTable(tableIdent) + + try { + input1.addData((1, "a")) + query1.processAllAvailable() + } finally { + query1.stop() + } + + assert(spark.table(tableIdent).schema == StructType(Seq( + StructField("id", IntegerType), + StructField("data", StringType)))) + + // Second query: add "amount" column on top of the already-evolved schema. + val input2 = MemoryStream[(Int, String, Double)] + val df2 = input2.toDF().toDF("id", "data", "amount") + + val query2 = df2.writeStream + .withSchemaEvolution() + .option( + "checkpointLocation", + s"${checkpointDir.getCanonicalPath}_2") + .outputMode(OutputMode.Append()) + .toTable(tableIdent) + + try { + input2.addData((2, "b", 20.0)) + query2.processAllAvailable() + } finally { + query2.stop() + } + + val result = spark.table(tableIdent) + checkAnswer(result, Seq( + Row(1, "a", null), + Row(2, "b", 20.0))) + assert(result.schema == StructType(Seq( + StructField("id", IntegerType), + StructField("data", StringType), + StructField("amount", DoubleType)))) + + // Third query: no new columns - schema should stay the same. + val input3 = MemoryStream[(Int, String, Double)] + val df3 = input3.toDF().toDF("id", "data", "amount") + + val query3 = df3.writeStream + .withSchemaEvolution() + .option( + "checkpointLocation", + s"${checkpointDir.getCanonicalPath}_3") + .outputMode(OutputMode.Append()) + .toTable(tableIdent) + + try { + input3.addData((3, "c", 30.0)) + query3.processAllAvailable() + } finally { + query3.stop() + } + + val result2 = spark.table(tableIdent) + assert(result2.schema == StructType(Seq( + StructField("id", IntegerType), + StructField("data", StringType), + StructField("amount", DoubleType)))) + } + } + } + + test("stop and restart same query - schema evolved on restart") { + withTable(tableIdent) { + withTempDir { checkpointDir => + sql(s"CREATE TABLE $tableIdent (id INT, data STRING)") + + // First run: matching schema, no evolution. + val input1 = MemoryStream[(Int, String)] + val df1 = input1.toDF().toDF("id", "data") + + val query1 = df1.writeStream + .withSchemaEvolution() + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .outputMode(OutputMode.Append()) + .toTable(tableIdent) + + try { + input1.addData((1, "a")) + query1.processAllAvailable() + input1.addData((2, "b")) + query1.processAllAvailable() + } finally { + query1.stop() + } + + checkAnswer(spark.table(tableIdent), Seq(Row(1, "a"), Row(2, "b"))) + + // Second run: new source schema with extra column. + // Uses a different checkpoint since MemoryStream state can't be + // reused across restarts with a different schema. + val input2 = MemoryStream[(Int, String, Double)] + val df2 = input2.toDF().toDF("id", "data", "amount") + + val query2 = df2.writeStream + .withSchemaEvolution() + .option( + "checkpointLocation", + s"${checkpointDir.getCanonicalPath}_2") + .outputMode(OutputMode.Append()) + .toTable(tableIdent) + + try { + input2.addData((3, "c", 30.0)) + query2.processAllAvailable() + } finally { + query2.stop() + } + + val result = spark.table(tableIdent) + checkAnswer(result, Seq( + Row(1, "a", null), + Row(2, "b", null), + Row(3, "c", 30.0))) + } + } + } + + test("schema evolution with Trigger.Once across restart") { + withTable(tableIdent) { + withTempDir { checkpointDir => + sql(s"CREATE TABLE $tableIdent (id INT)") + + // First run: Trigger.Once with extra column. + val input1 = MemoryStream[(Int, String)] + val df1 = input1.toDF().toDF("id", "data") + input1.addData((1, "a")) + + val query1 = df1.writeStream + .withSchemaEvolution() + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .trigger(Trigger.Once()) + .toTable(tableIdent) + + try { + query1.processAllAvailable() + } finally { + query1.stop() + } + + assert(spark.table(tableIdent).schema == StructType(Seq( + StructField("id", IntegerType), + StructField("data", StringType)))) + + // Second run: Trigger.Once with yet another extra column. + val input2 = MemoryStream[(Int, String, Double)] + val df2 = input2.toDF().toDF("id", "data", "amount") + input2.addData((2, "b", 20.0)) + + val query2 = df2.writeStream + .withSchemaEvolution() + .option( + "checkpointLocation", + s"${checkpointDir.getCanonicalPath}_2") + .trigger(Trigger.Once()) + .toTable(tableIdent) + + try { + query2.processAllAvailable() + } finally { + query2.stop() + } + + val result = spark.table(tableIdent) + checkAnswer(result, Seq(Row(1, "a", null), Row(2, "b", 20.0))) + assert(result.schema == StructType(Seq( + StructField("id", IntegerType), + StructField("data", StringType), + StructField("amount", DoubleType)))) + } + } + } + + + test("withSchemaEvolution rejected with continuous trigger") { + withTable(tableIdent) { + withTempDir { checkpointDir => + sql(s"CREATE TABLE $tableIdent (id INT, data STRING)") + + val input = MemoryStream[(Int, String, Double)] + val df = input.toDF().toDF("id", "data", "amount") + + val e = intercept[SparkUnsupportedOperationException] { + df.writeStream + .withSchemaEvolution() + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .trigger(Trigger.Continuous("1 second")) + .toTable(tableIdent) + } + assert(e.getCondition == + "UNSUPPORTED_STREAMING_SCHEMA_EVOLUTION.CONTINUOUS_TRIGGER") + } + } + } + + test("withSchemaEvolution rejected for V1 sink") { + withTempDir { checkpointDir => + val input = MemoryStream[(Int, String)] + val df = input.toDF().toDF("id", "data") + input.addData((1, "a")) + + // foreachBatch creates a V1 ForeachBatchSink. The error surfaces + // when the streaming thread evaluates MicroBatchExecution.logicalPlan. + val query = df.writeStream + .withSchemaEvolution() + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .foreachBatch { (batch: org.apache.spark.sql.Dataset[Row], _: Long) => + () + } + .start() + + try { + val e = intercept[ + org.apache.spark.sql.streaming.StreamingQueryException] { + query.processAllAvailable() + } + assert(e.getCause + .isInstanceOf[SparkUnsupportedOperationException]) + assert(e.getCause + .asInstanceOf[SparkUnsupportedOperationException] + .getCondition == + "UNSUPPORTED_STREAMING_SCHEMA_EVOLUTION.NOT_V2_TABLE") + } finally { + query.stop() + } + } + } + + test("streaming write with type widening") { + withTable(tableIdent) { + withTempDir { checkpointDir => + sql(s"CREATE TABLE $tableIdent (id INT, value INT)") + + val input = MemoryStream[(Int, Long)] + val df = input.toDF().toDF("id", "value") + + val query = df.writeStream + .withSchemaEvolution() + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .outputMode(OutputMode.Append()) + .toTable(tableIdent) + + try { + input.addData((1, 100L), (2, 200L)) + query.processAllAvailable() + } finally { + query.stop() + } + + val result = spark.table(tableIdent) + checkAnswer(result, Seq(Row(1, 100L), Row(2, 200L))) + assert(result.schema == StructType(Seq( + StructField("id", IntegerType), + StructField("value", LongType)))) + } + } + } +}