diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/AsyncCommitLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/AsyncCommitLog.scala index 7a6c26b249e96..d13affb82dbbd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/AsyncCommitLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/AsyncCommitLog.scala @@ -53,7 +53,7 @@ class AsyncCommitLog( * the async write of the batch is completed. Future may also be completed exceptionally * to indicate some write error. */ - def addAsync(batchId: Long, metadata: CommitMetadata): CompletableFuture[Long] = { + def addAsync(batchId: Long, metadata: CommitMetadataBase): CompletableFuture[Long] = { require(metadata != null, "'null' metadata cannot be written to a metadata log") val future: CompletableFuture[Long] = addNewBatchByStreamAsync(batchId) { output => serialize(metadata, output) @@ -77,7 +77,7 @@ class AsyncCommitLog( * @param metadata metadata of batch to write * @return true if operation is successful otherwise false. 
*/ - def addInMemory(batchId: Long, metadata: CommitMetadata): Boolean = { + def addInMemory(batchId: Long, metadata: CommitMetadataBase): Boolean = { if (batchCache.containsKey(batchId)) { false } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/CommitLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/CommitLog.scala index b73020b6060c6..820aecf70d0ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/CommitLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/CommitLog.scala @@ -26,6 +26,7 @@ import org.json4s.{Formats, NoTypeHints} import org.json4s.jackson.Serialization import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf /** @@ -50,39 +51,119 @@ class CommitLog( sparkSession: SparkSession, path: String, readOnly: Boolean = false) - extends HDFSMetadataLog[CommitMetadata](sparkSession, path, readOnly) { + extends HDFSMetadataLog[CommitMetadataBase](sparkSession, path, readOnly) { import CommitLog._ - private val VERSION: Int = sparkSession.conf.get( + // The configured commit log format version. Used as the default version when callers + // construct metadata through [[createMetadata]]. + private[sql] val defaultVersion: Int = sparkSession.conf.get( SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key).toInt - override protected[sql] def deserialize(in: InputStream): CommitMetadata = { - // called inside a try-finally where the underlying stream is closed in the caller - val lines = IOSource.fromInputStream(in, UTF_8.name()).getLines() - if (!lines.hasNext) { - throw new IllegalStateException("Incomplete log file in the offset commit log") - } - // TODO [SPARK-49462] This validation should be relaxed for a stateless query. 
- // TODO [SPARK-50653] This validation should be relaxed to support reading - // a V1 log file when VERSION is V2 - validateVersionExactMatch(lines.next().trim, VERSION) - val metadataJson = if (lines.hasNext) lines.next() else EMPTY_JSON - CommitMetadata(metadataJson) + override protected[sql] def deserialize(in: InputStream): CommitMetadataBase = { + CommitLog.readCommitMetadata(in) } - override protected[sql] def serialize(metadata: CommitMetadata, out: OutputStream): Unit = { + override protected[sql] def serialize(metadata: CommitMetadataBase, out: OutputStream): Unit = { // called inside a try-finally where the underlying stream is closed in the caller - out.write(s"v${VERSION}".getBytes(UTF_8)) + out.write(s"v${metadata.version}".getBytes(UTF_8)) out.write('\n') // write metadata out.write(metadata.json.getBytes(UTF_8)) } + + /** + * Factory for creating a [[CommitMetadataBase]] for the requested wire format version, which + * defaults to [[SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION]]. Unsupported versions throw. + */ + def createMetadata( + nextBatchWatermarkMs: Long = 0, + stateUniqueIds: Option[Map[Long, Array[Array[String]]]] = None, + commitLogFormatVersion: Int = defaultVersion): CommitMetadataBase = { + commitLogFormatVersion match { + case VERSION_2 => + CommitMetadataV2(nextBatchWatermarkMs, stateUniqueIds) + case VERSION_1 => + // VERSION_1 cannot persist stateUniqueIds; withStateUniqueIds throws if any are supplied. + CommitMetadata(nextBatchWatermarkMs).withStateUniqueIds(stateUniqueIds) + case v if v > CommitLog.MAX_VERSION => + throw QueryExecutionErrors.logVersionGreaterThanSupported(v, CommitLog.MAX_VERSION) + case v => throw new IllegalArgumentException(s"Invalid commit log format version: $v") + } + } } object CommitLog { private val EMPTY_JSON = "{}" + val VERSION_1 = 1 + val VERSION_2 = 2 + val MAX_VERSION: Int = VERSION_2 + + /** + * Reads a single commit log entry and dispatches to the matching + * [[CommitMetadataBase]] subclass based on the wire format version recorded in the file.
+ */ + private[spark] def readCommitMetadata(in: InputStream): CommitMetadataBase = { + val lines = IOSource.fromInputStream(in, UTF_8.name()).getLines() + if (!lines.hasNext) { + throw new IllegalStateException("Incomplete log file in the offset commit log") + } + val version = MetadataVersionUtil.validateVersion(lines.next().trim, MAX_VERSION) + val metadataJson = if (lines.hasNext) lines.next() else EMPTY_JSON + version match { + case VERSION_2 => CommitMetadataV2(metadataJson) + case VERSION_1 => CommitMetadata(metadataJson) + case v => throw QueryExecutionErrors.logVersionGreaterThanSupported(v, MAX_VERSION) + } + } +} + +/** + * Base trait for commit log metadata. Concrete subclasses correspond to wire format versions + * and override [[version]] accordingly. + */ +trait CommitMetadataBase extends Serializable { + def version: Int + def nextBatchWatermarkMs: Long + def stateUniqueIds: Option[Map[Long, Array[Array[String]]]] + + /** + * Returns a copy of this metadata with the given state store unique ids, preserving the + * concrete subclass and all of its other fields. Deriving a new commit from an existing one + * should go through this method (rather than reconstructing via [[CommitLog.createMetadata]]) + * so that version-specific fields are not silently dropped when new metadata versions are + * introduced. + */ + def withStateUniqueIds( + stateUniqueIds: Option[Map[Long, Array[Array[String]]]]): CommitMetadataBase + + def json: String = Serialization.write(this)(CommitMetadata.format) +} + +/** + * Commit log metadata for [[CommitLog.VERSION_1]]. Records the watermark for the next batch only. + * + * @param nextBatchWatermarkMs The watermark of the next batch. 
+ */ +case class CommitMetadata( + nextBatchWatermarkMs: Long = 0) extends CommitMetadataBase { + override def version: Int = CommitLog.VERSION_1 + override def stateUniqueIds: Option[Map[Long, Array[Array[String]]]] = None + + override def withStateUniqueIds( + stateUniqueIds: Option[Map[Long, Array[Array[String]]]]): CommitMetadata = { + require(stateUniqueIds.forall(_.isEmpty), + s"stateUniqueIds cannot be set for commit log format version ${CommitLog.VERSION_1}; " + + s"use version ${CommitLog.VERSION_2} to persist state store checkpoint ids.") + this + } +} + +object CommitMetadata { + implicit val format: Formats = Serialization.formats(NoTypeHints) + + def apply(json: String): CommitMetadata = Serialization.read[CommitMetadata](json) } /** @@ -104,19 +185,23 @@ object CommitLog { * +--- ...... * In the commit log, in addition to nextBatchWatermarkMs, we also store the unique ids of the * state store files. + * * @param nextBatchWatermarkMs The watermark of the next batch. * @param stateUniqueIds Map[Long, Array[Array[String]]] of map * OperatorId -> (partitionID -> array of uniqueID) */ - -case class CommitMetadata( +case class CommitMetadataV2( nextBatchWatermarkMs: Long = 0, - stateUniqueIds: Option[Map[Long, Array[Array[String]]]] = None) { - def json: String = Serialization.write(this)(CommitMetadata.format) + stateUniqueIds: Option[Map[Long, Array[Array[String]]]] = None) extends CommitMetadataBase { + override def version: Int = CommitLog.VERSION_2 + + override def withStateUniqueIds( + stateUniqueIds: Option[Map[Long, Array[Array[String]]]]): CommitMetadataV2 = + copy(stateUniqueIds = stateUniqueIds) } -object CommitMetadata { - implicit val format: Formats = Serialization.formats(NoTypeHints) +object CommitMetadataV2 { + import CommitMetadata.format - def apply(json: String): CommitMetadata = Serialization.read[CommitMetadata](json) + def apply(json: String): CommitMetadataV2 = Serialization.read[CommitMetadataV2](json) } diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala index 84f0373ca5d48..94143799a8c41 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala @@ -46,7 +46,7 @@ import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, RealTimeStreamScanExec, StreamingDataSourceV2Relation, StreamingDataSourceV2ScanRelation, StreamWriterCommitProgress, WriteToDataSourceV2Exec} import org.apache.spark.sql.execution.streaming.{AvailableNowTrigger, Offset, OneTimeTrigger, ProcessingTimeTrigger, RealTimeModeAllowlist, RealTimeTrigger, Sink, Source, StreamingQueryPlanTraverseHelper} -import org.apache.spark.sql.execution.streaming.checkpointing.{CheckpointFileManager, CommitMetadata, OffsetSeqBase, OffsetSeqLog, OffsetSeqMetadata, OffsetSeqMetadataV2} +import org.apache.spark.sql.execution.streaming.checkpointing.{CheckpointFileManager, OffsetSeqBase, OffsetSeqLog, OffsetSeqMetadata, OffsetSeqMetadataV2} import org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOperatorStateInfo, StatefulOpStateStoreCheckpointInfo, StateStoreWriter} import org.apache.spark.sql.execution.streaming.runtime.StreamingCheckpointConstants.{DIR_NAME_COMMITS, DIR_NAME_OFFSETS, DIR_NAME_STATE} import org.apache.spark.sql.execution.streaming.sources.{ForeachBatchSink, WriteToMicroBatchDataSource, WriteToMicroBatchDataSourceV1} @@ -1464,7 +1464,9 @@ class MicroBatchExecution( None } if (!commitLog.add(execCtx.batchId, - CommitMetadata(watermarkTracker.currentWatermark, stateStoreCkptId))) { + commitLog.createMetadata( + nextBatchWatermarkMs = 
watermarkTracker.currentWatermark, + stateUniqueIds = stateStoreCkptId))) { throw QueryExecutionErrors.concurrentStreamLogUpdate(execCtx.batchId) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala index 1491d26989062..dc13fa1030a0a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala @@ -294,7 +294,9 @@ class OfflineStateRepartitionRunner( lastCommittedBatchId: Long, opIdToStateStoreCkptInfo: Option[Map[Long, Array[Array[String]]]]): Unit = { val latestCommit = checkpointMetadata.commitLog.get(lastCommittedBatchId).get - val commitMetadata = latestCommit.copy(stateUniqueIds = opIdToStateStoreCkptInfo) + // Derive the new commit from the latest one so version-specific fields are preserved and the + // wire format version stays consistent with the source checkpoint. 
+ val commitMetadata = latestCommit.withStateUniqueIds(opIdToStateStoreCkptInfo) if (!checkpointMetadata.commitLog.add(newBatchId, commitMetadata)) { throw QueryExecutionErrors.concurrentStreamLogUpdate(newBatchId) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateRewriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateRewriter.scala index fd890161caafd..546a9a6019647 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateRewriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateRewriter.scala @@ -22,7 +22,7 @@ import java.util.UUID import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.spark.{SparkIllegalStateException, SparkThrowable, TaskContext} +import org.apache.spark.{SparkIllegalStateException, TaskContext} import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.internal.LogKeys._ @@ -376,27 +376,19 @@ class StateRewriter( } private def verifyCheckpointFormatVersion(): Unit = { - // Verify checkpoint version in sqlConf based on commitLog for readCheckpoint - // in case user forgot to set STATE_STORE_CHECKPOINT_FORMAT_VERSION. - // Using read batch commit since the latest commit could be a skipped batch. - // If SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION is wrong, readCheckpoint.commitLog - // will throw an exception, and we will propagate this exception upstream. 
- // This prevents the StateRewriter from failing to write the correct state files - try { - readCheckpoint.commitLog.get(readBatchId) - } catch { - case e: IllegalStateException if e.getCause != null && - e.getCause.isInstanceOf[SparkThrowable] => - val sparkThrowable = e.getCause.asInstanceOf[SparkThrowable] - if (sparkThrowable.getCondition == "INVALID_LOG_VERSION.EXACT_MATCH_VERSION") { - val params = sparkThrowable.getMessageParameters - val expectedVersion = params.get("version") - val actualVersion = params.get("matchVersion") - throw StateRewriterErrors.stateCheckpointFormatVersionMismatchError( - checkpointLocationForRead, expectedVersion, actualVersion) - } - throw e + // Verify checkpoint version in sqlConf matches the version recorded in the read commit log, + // in case the user forgot to set STATE_STORE_CHECKPOINT_FORMAT_VERSION. This prevents the + // StateRewriter from writing state files in a format that disagrees with the source + // checkpoint. Using the read batch commit since the latest commit could be a skipped batch. 
+ readCheckpoint.commitLog.get(readBatchId).foreach { metadata => + val configuredVersion = readCheckpoint.commitLog.defaultVersion + if (metadata.version != configuredVersion) { + throw StateRewriterErrors.stateCheckpointFormatVersionMismatchError( + checkpointLocationForRead, + expectedVersion = metadata.version.toString, + actualVersion = configuredVersion.toString) } + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala index bae78f0b4762f..4e9f6cca2ffc4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.conf.Configuration import org.scalatest.Assertions import org.apache.spark.sql.Row -import org.apache.spark.sql.execution.streaming.checkpointing.{CommitLog, CommitMetadata} +import org.apache.spark.sql.execution.streaming.checkpointing.{CommitLog, CommitMetadata, CommitMetadataV2} import org.apache.spark.sql.execution.streaming.runtime.{MemoryStream, StreamExecution} import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.functions.{col, window} @@ -237,11 +237,11 @@ abstract class StateDataSourceChangeDataReaderSuite extends StateDataSourceTestB new File(tempDir.getAbsolutePath, "commits").getAbsolutePath) // Start version: treated as v1 (no operator unique ids) - val startMetadata = CommitMetadata(0, None) + val startMetadata = CommitMetadata(0) assert(commitLog.add(0, startMetadata)) // End version: treated as v2 (operator 0 has unique ids) - val endMetadata = CommitMetadata(0, + val endMetadata = CommitMetadataV2(0, Some(Map[Long, Array[Array[String]]](0L -> Array(Array("uid"))))) 
assert(commitLog.add(1, endMetadata)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala index 2def79828fac1..4a2a454077a7c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala @@ -23,7 +23,7 @@ import java.util.UUID import org.apache.hadoop.conf.Configuration import org.scalatest.Assertions -import org.apache.spark.{SparkException, SparkThrowable, SparkUnsupportedOperationException} +import org.apache.spark.{SparkException, SparkUnsupportedOperationException} import org.apache.spark.io.CompressionCodec import org.apache.spark.sql.{AnalysisException, DataFrame, Encoders, Row} import org.apache.spark.sql.catalyst.expressions.{BoundReference, GenericInternalRow} @@ -589,8 +589,6 @@ class RocksDBWithCheckpointV2StateDataSourceReaderSuite extends StateDataSourceR override protected def newStateStoreProvider(): RocksDBStateStoreProvider = new RocksDBStateStoreProvider - import testImplicits._ - override def beforeAll(): Unit = { super.beforeAll() spark.conf.set(SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION, 2) @@ -600,34 +598,57 @@ class RocksDBWithCheckpointV2StateDataSourceReaderSuite extends StateDataSourceR "true") } - // TODO: Remove this test once we allow migrations from checkpoint v1 to v2 - test("reading checkpoint v2 store with version 1 should fail") { - withTempDir { tmpDir => - val inputData = MemoryStream[(Int, Long)] - val query = getStreamStreamJoinQuery(inputData) - testStream(query)( - StartStream(checkpointLocation = tmpDir.getCanonicalPath), - AddData(inputData, (1, 1L), (2, 2L), (3, 3L), (4, 4L), (5, 5L)), - ProcessAllAvailable(), - Execute { _ => Thread.sleep(2000) }, - StopStream - ) + // 
Expected state after runLargeDataStreamingAggregationQuery, read from batch 2 / operator 0. + private val expectedLargeAggregationState: Seq[Row] = Seq( + Row(0, 5, 60, 30, 0), Row(1, 5, 65, 31, 1), Row(2, 5, 70, 32, 2), + Row(3, 4, 72, 33, 3), Row(4, 4, 76, 34, 4), Row(5, 4, 80, 35, 5), + Row(6, 4, 84, 36, 6), Row(7, 4, 88, 37, 7), Row(8, 4, 92, 38, 8), + Row(9, 4, 96, 39, 9)) + + private def readLargeAggregationState(checkpointDir: String): DataFrame = + spark.read.format("statestore") + .option(StateSourceOptions.PATH, checkpointDir) + .option(StateSourceOptions.BATCH_ID, 2) + .option(StateSourceOptions.OPERATOR_ID, 0) + .load() + .selectExpr("key.groupKey AS key_groupKey", "value.count AS value_cnt", + "value.sum AS value_sum", "value.max AS value_max", "value.min AS value_min") + // SPARK-56970: The commit log wire format version is now discovered from the file header + // rather than required to match STATE_STORE_CHECKPOINT_FORMAT_VERSION. As a result a V1 commit + // log can be read under a V2-configured session (and vice versa). Note this only applies to the + // commit log layer; reading a V2 state store still requires version 2 to be configured because + // the state store files are named with checkpoint unique ids. + test("SPARK-56970: reading a v1 checkpoint with commit log version 2 configured succeeds") { + withTempDir { tempDir => + // Override the suite default to write a V1 checkpoint (no checkpoint unique ids). 
withSQLConf(SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> "1") { - // Verify reading state throws error when reading checkpoint v2 with version 1 - val exc = intercept[IllegalStateException] { - val stateDf = spark.read.format("statestore") - .option(StateSourceOptions.BATCH_ID, 0) - .option(StateSourceOptions.OPERATOR_ID, 0) - .load(tmpDir.getCanonicalPath) - stateDf.collect() - } + runLargeDataStreamingAggregationQuery(tempDir.getAbsolutePath) + } + + // The suite default reads with version 2 configured; the V1 commit log must still be read. + checkAnswer( + readLargeAggregationState(tempDir.getAbsolutePath), expectedLargeAggregationState) + } + } - checkError(exc.getCause.asInstanceOf[SparkThrowable], - "INVALID_LOG_VERSION.EXACT_MATCH_VERSION", "KD002", - Map( - "version" -> "2", - "matchVersion" -> "1")) + test("SPARK-56970: reading a v2 checkpoint with commit log version 1 configured fails on the " + + "state store, not the commit log") { + withTempDir { tempDir => + // The suite configures commit log format version 2, so this writes a V2 checkpoint whose + // state store files are named with checkpoint unique ids. + runLargeDataStreamingAggregationQuery(tempDir.getAbsolutePath) + + withSQLConf(SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> "1") { + // The commit log now deserializes across versions, so this no longer fails with + // INVALID_LOG_VERSION at the commit-log layer. Reading the V2 state store itself still + // requires version 2 to be configured: with version 1 the reader looks for non-unique + // state file names and cannot locate the unique-id-named files. 
+ val ex = intercept[SparkException] { + readLargeAggregationState(tempDir.getAbsolutePath).collect() + } + assert(Seq(ex, ex.getCause).flatMap(Option(_)).flatMap(t => Option(t.getMessage)) + .exists(_.contains("CANNOT_LOAD_STATE_STORE"))) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala index be7874e806cd8..22d0af0a77fd6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala @@ -99,7 +99,7 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase // Commit to commitLog with checkpoint IDs val latestCommit = targetCheckpointMetadata.commitLog.get(lastBatch).get - val commitMetadata = latestCommit.copy(stateUniqueIds = checkpointInfos) + val commitMetadata = latestCommit.withStateUniqueIds(checkpointInfos) targetCheckpointMetadata.commitLog.add(writeBatchId, commitMetadata) val versionToCheck = writeBatchId + 1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/CommitLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/CommitLogSuite.scala index 332de78e7cbf9..e4becd0571680 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/CommitLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/CommitLogSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.streaming import java.io.{ByteArrayInputStream, FileInputStream, FileOutputStream} import java.nio.file.Path -import org.apache.spark.sql.execution.streaming.checkpointing.{CommitLog, CommitMetadata} +import org.apache.spark.sql.execution.streaming.checkpointing.{CommitLog, CommitMetadata,
CommitMetadataBase, CommitMetadataV2} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession @@ -62,7 +62,7 @@ class CommitLogSuite extends SharedSparkSession { ) } - private def testSerde(commitMetadata: CommitMetadata, path: Path): Unit = { + private def testSerde(commitMetadata: CommitMetadataBase, path: Path): Unit = { if (regenerateGoldenFiles) { val commitLog = new CommitLog(spark, path.toString) val outputStream = new FileOutputStream(path.resolve("testCommitLog").toFile) @@ -102,19 +102,21 @@ class CommitLogSuite extends SharedSparkSession { 0L -> Array(Array("unique_id1", "unique_id2"), Array("unique_id3", "unique_id4")), 1L -> Array(Array("unique_id5", "unique_id6"), Array("unique_id7", "unique_id8")) ) - val testMetadataV2 = CommitMetadata(0, Some(testStateUniqueIds)) + val testMetadataV2 = CommitMetadataV2(0, Some(testStateUniqueIds)) testSerde(testMetadataV2, testCommitLogV2FilePath) } } test("Basic Commit Log V2 SerDe - empty stateUniqueIds") { withSQLConf(SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> "2") { - val testMetadataV2 = CommitMetadata(0, Some(Map[Long, Array[Array[String]]]())) + val testMetadataV2 = CommitMetadataV2(0, Some(Map[Long, Array[Array[String]]]())) testSerde(testMetadataV2, testCommitLogV2FilePathEmptyUniqueId) } } - // Old metadata structure with no state unique ids should not affect the deserialization + // SPARK-50653: When the configured commit log version is V2, a V1 file on disk should still + // deserialize successfully into a V1 [[CommitMetadata]] because the wire format version is now + // discovered from the file header rather than enforced to match the conf. 
test("Cross-version V1 SerDe") { withSQLConf(SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> "2") { val commitlogV1 = """v1 @@ -122,18 +124,41 @@ class CommitLogSuite extends SharedSparkSession { val inputStream: ByteArrayInputStream = new ByteArrayInputStream(commitlogV1.getBytes("UTF-8")) - // TODO [SPARK-50653]: Uncomment the below when v2 -> v1 backward compatibility is added - // val commitMetadata: CommitMetadata = new CommitLog( - // spark, testCommitLogV1FilePath.toString).deserialize(inputStream) - // assert(commitMetadata.nextBatchWatermarkMs === 233) - // assert(commitMetadata.stateUniqueIds === Map.empty) + val commitMetadata = new CommitLog( + spark, testCommitLogV1FilePath.toString).deserialize(inputStream) + assert(commitMetadata.version === CommitLog.VERSION_1) + assert(commitMetadata.nextBatchWatermarkMs === 233) + assert(commitMetadata.stateUniqueIds.isEmpty) + } + } + + test("SPARK-56970: creating a V1 commit with stateUniqueIds should fail") { + withTempDir { tmpDir => + val commitLog = new CommitLog(spark, tmpDir.getCanonicalPath) + val stateUniqueIds: Map[Long, Array[Array[String]]] = + Map(0L -> Array(Array("unique_id1", "unique_id2"))) + + // Through the createMetadata factory with an explicit V1 format version. + val e1 = intercept[IllegalArgumentException] { + commitLog.createMetadata( + nextBatchWatermarkMs = 1, + stateUniqueIds = Some(stateUniqueIds), + commitLogFormatVersion = CommitLog.VERSION_1) + } + assert(e1.getMessage.contains("stateUniqueIds cannot be set")) - // TODO [SPARK-50653]: remove the below when v2 -> v1 backward compatibility is added - val e = intercept[IllegalStateException] { - new CommitLog(spark, testCommitLogV1FilePath.toString).deserialize(inputStream) + // Directly through withStateUniqueIds on a V1 metadata. 
+ val e2 = intercept[IllegalArgumentException] { + CommitMetadata(1).withStateUniqueIds(Some(stateUniqueIds)) } + assert(e2.getMessage.contains("stateUniqueIds cannot be set")) - assert (e.getMessage.contains("only supported log version")) + // None and an empty map are allowed for V1 (no unique ids to persist). + assert(CommitMetadata(1).withStateUniqueIds(None).stateUniqueIds.isEmpty) + assert(commitLog.createMetadata( + nextBatchWatermarkMs = 1, + stateUniqueIds = Some(Map.empty[Long, Array[Array[String]]]), + commitLogFormatVersion = CommitLog.VERSION_1).version === CommitLog.VERSION_1) } } }