Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class AsyncCommitLog(
* the async write of the batch is completed. Future may also be completed exceptionally
* to indicate some write error.
*/
def addAsync(batchId: Long, metadata: CommitMetadata): CompletableFuture[Long] = {
def addAsync(batchId: Long, metadata: CommitMetadataBase): CompletableFuture[Long] = {
require(metadata != null, "'null' metadata cannot be written to a metadata log")
val future: CompletableFuture[Long] = addNewBatchByStreamAsync(batchId) { output =>
serialize(metadata, output)
Expand All @@ -77,7 +77,7 @@ class AsyncCommitLog(
* @param metadata metadata of batch to write
* @return true if operation is successful otherwise false.
*/
def addInMemory(batchId: Long, metadata: CommitMetadata): Boolean = {
def addInMemory(batchId: Long, metadata: CommitMetadataBase): Boolean = {
if (batchCache.containsKey(batchId)) {
false
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import org.json4s.{Formats, NoTypeHints}
import org.json4s.jackson.Serialization

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.SQLConf

/**
Expand All @@ -50,39 +51,119 @@ class CommitLog(
sparkSession: SparkSession,
path: String,
readOnly: Boolean = false)
extends HDFSMetadataLog[CommitMetadata](sparkSession, path, readOnly) {
extends HDFSMetadataLog[CommitMetadataBase](sparkSession, path, readOnly) {

import CommitLog._

private val VERSION: Int = sparkSession.conf.get(
// The configured commit log format version. Used as the default version when callers
// construct metadata through [[createMetadata]].
private[sql] val defaultVersion: Int = sparkSession.conf.get(
SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key).toInt

override protected[sql] def deserialize(in: InputStream): CommitMetadata = {
// called inside a try-finally where the underlying stream is closed in the caller
val lines = IOSource.fromInputStream(in, UTF_8.name()).getLines()
if (!lines.hasNext) {
throw new IllegalStateException("Incomplete log file in the offset commit log")
}
// TODO [SPARK-49462] This validation should be relaxed for a stateless query.
// TODO [SPARK-50653] This validation should be relaxed to support reading
// a V1 log file when VERSION is V2
validateVersionExactMatch(lines.next().trim, VERSION)
val metadataJson = if (lines.hasNext) lines.next() else EMPTY_JSON
CommitMetadata(metadataJson)
override protected[sql] def deserialize(in: InputStream): CommitMetadataBase = {
CommitLog.readCommitMetadata(in)
}

override protected[sql] def serialize(metadata: CommitMetadata, out: OutputStream): Unit = {
override protected[sql] def serialize(metadata: CommitMetadataBase, out: OutputStream): Unit = {
// called inside a try-finally where the underlying stream is closed in the caller
out.write(s"v${VERSION}".getBytes(UTF_8))
out.write(s"v${metadata.version}".getBytes(UTF_8))
out.write('\n')

// write metadata
out.write(metadata.json.getBytes(UTF_8))
}

/**
 * Builds the [[CommitMetadataBase]] implementation that corresponds to a commit log wire
 * format version.
 *
 * @param nextBatchWatermarkMs watermark to record for the next batch
 * @param stateUniqueIds state store unique ids to record; only version 2 can carry a non-empty
 *                       value
 * @param commitLogFormatVersion wire format version to build; when omitted, [[defaultVersion]]
 *                               is used
 * @return a [[CommitMetadataV2]] for version 2, or a [[CommitMetadata]] for version 1
 * @throws IllegalArgumentException if version 1 is requested together with non-empty
 *                                  `stateUniqueIds` (raised by `withStateUniqueIds`)
 */
def createMetadata(
    nextBatchWatermarkMs: Long = 0,
    stateUniqueIds: Option[Map[Long, Array[Array[String]]]] = None,
    commitLogFormatVersion: Int = defaultVersion): CommitMetadataBase = {
  if (commitLogFormatVersion == VERSION_2) {
    CommitMetadataV2(nextBatchWatermarkMs, stateUniqueIds)
  } else if (commitLogFormatVersion == VERSION_1) {
    // Version 1 has nowhere to store stateUniqueIds, so route through withStateUniqueIds, which
    // rejects a non-empty value instead of silently dropping it.
    CommitMetadata(nextBatchWatermarkMs).withStateUniqueIds(stateUniqueIds)
  } else {
    // Any other version is reported through the shared "unsupported log version" error.
    throw QueryExecutionErrors.logVersionGreaterThanSupported(
      commitLogFormatVersion, CommitLog.MAX_VERSION)
  }
}
}

object CommitLog {
  // Payload assumed when a log file contains only the version header line.
  private val EMPTY_JSON = "{}"
  val VERSION_1 = 1
  val VERSION_2 = 2
  val MAX_VERSION: Int = VERSION_2

  /**
   * Parses one commit log entry from `in`. The first line is the version header and the
   * optional second line is the JSON payload; the header decides which
   * [[CommitMetadataBase]] implementation is instantiated.
   *
   * This method does not close `in`; the stream remains owned by the caller.
   *
   * @param in stream positioned at the start of a commit log entry
   * @return [[CommitMetadataV2]] for a version 2 header, [[CommitMetadata]] for version 1
   * @throws IllegalStateException if the stream contains no lines at all
   */
  private[spark] def readCommitMetadata(in: InputStream): CommitMetadataBase = {
    val lineIter = IOSource.fromInputStream(in, UTF_8.name()).getLines()
    if (!lineIter.hasNext) {
      throw new IllegalStateException("Incomplete log file in the offset commit log")
    }

    // The header is validated against MAX_VERSION before the payload line is consumed.
    val header = lineIter.next().trim
    val parsedVersion = MetadataVersionUtil.validateVersion(header, MAX_VERSION)
    val payload = if (lineIter.hasNext) lineIter.next() else EMPTY_JSON

    if (parsedVersion == VERSION_2) {
      CommitMetadataV2(payload)
    } else if (parsedVersion == VERSION_1) {
      CommitMetadata(payload)
    } else {
      throw QueryExecutionErrors.logVersionGreaterThanSupported(parsedVersion, MAX_VERSION)
    }
  }
}

/**
* Base trait for commit log metadata. Concrete subclasses correspond to wire format versions
* and override [[version]] accordingly.
*
* [[CommitLog.serialize]] writes `v` followed by [[version]] as the header line and then
* [[json]] as the payload, so together these two members define the on-disk form of an entry.
*/
trait CommitMetadataBase extends Serializable {
/** Wire format version of this entry, e.g. [[CommitLog.VERSION_1]] or [[CommitLog.VERSION_2]]. */
def version: Int

/** The watermark of the next batch. */
def nextBatchWatermarkMs: Long

/**
* State store unique ids, keyed by operator id, if this version records them. The version 1
* implementation always returns `None`.
*/
def stateUniqueIds: Option[Map[Long, Array[Array[String]]]]

/**
* Returns metadata carrying the given state store unique ids, preserving the
* concrete subclass and all of its other fields. Deriving a new commit from an existing one
* should go through this method (rather than reconstructing via [[CommitLog.createMetadata]])
* so that version-specific fields are not silently dropped when new metadata versions are
* introduced.
*
* Implementations that cannot store unique ids may reject a non-empty value instead of
* returning a modified copy.
*/
def withStateUniqueIds(
stateUniqueIds: Option[Map[Long, Array[Array[String]]]]): CommitMetadataBase

/** JSON payload for this entry, produced by json4s with the formats in [[CommitMetadata.format]]. */
def json: String = Serialization.write(this)(CommitMetadata.format)
}

/**
 * Version 1 ([[CommitLog.VERSION_1]]) commit log entry, which carries nothing but the watermark
 * of the next batch.
 *
 * @param nextBatchWatermarkMs The watermark of the next batch.
 */
case class CommitMetadata(nextBatchWatermarkMs: Long = 0) extends CommitMetadataBase {

  // Version 1 has no field for state store unique ids, so none are ever reported.
  override def stateUniqueIds: Option[Map[Long, Array[Array[String]]]] = None

  override def version: Int = CommitLog.VERSION_1

  /**
   * Accepts only an absent or empty set of ids and returns this same instance, because
   * version 1 has nowhere to keep them.
   *
   * @throws IllegalArgumentException if `stateUniqueIds` holds a non-empty map
   */
  override def withStateUniqueIds(
      stateUniqueIds: Option[Map[Long, Array[Array[String]]]]): CommitMetadata = {
    val carriesIds = stateUniqueIds.exists(_.nonEmpty)
    require(!carriesIds,
      s"stateUniqueIds cannot be set for commit log format version ${CommitLog.VERSION_1}; " +
        s"use version ${CommitLog.VERSION_2} to persist state store checkpoint ids.")
    this
  }
}

/**
* Companion of the version 1 [[CommitMetadata]]. Holds the json4s formats used when commit
* metadata is written as JSON, plus a factory that parses a version 1 JSON payload.
*/
object CommitMetadata {
// json4s formats without type hints. Passed explicitly by CommitMetadataBase.json and picked up
// implicitly by the Serialization.read call below.
implicit val format: Formats = Serialization.formats(NoTypeHints)

/** Parses a JSON payload into a [[CommitMetadata]] using [[format]]. */
def apply(json: String): CommitMetadata = Serialization.read[CommitMetadata](json)
}

/**
Expand All @@ -104,19 +185,23 @@ object CommitLog {
* +--- ......
* In the commit log, in addition to nextBatchWatermarkMs, we also store the unique ids of the
* state store files.
*
* @param nextBatchWatermarkMs The watermark of the next batch.
* @param stateUniqueIds Map of operatorId -> (partitionId -> array of unique ids), encoded as
* Map[Long, Array[Array[String]]]
*/

case class CommitMetadata(
case class CommitMetadataV2(
nextBatchWatermarkMs: Long = 0,
stateUniqueIds: Option[Map[Long, Array[Array[String]]]] = None) {
def json: String = Serialization.write(this)(CommitMetadata.format)
stateUniqueIds: Option[Map[Long, Array[Array[String]]]] = None) extends CommitMetadataBase {
override def version: Int = CommitLog.VERSION_2

override def withStateUniqueIds(
stateUniqueIds: Option[Map[Long, Array[Array[String]]]]): CommitMetadataV2 =
copy(stateUniqueIds = stateUniqueIds)
}

object CommitMetadata {
implicit val format: Formats = Serialization.formats(NoTypeHints)
object CommitMetadataV2 {
import CommitMetadata.format

def apply(json: String): CommitMetadata = Serialization.read[CommitMetadata](json)
def apply(json: String): CommitMetadataV2 = Serialization.read[CommitMetadataV2](json)
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, RealTimeStreamScanExec, StreamingDataSourceV2Relation, StreamingDataSourceV2ScanRelation, StreamWriterCommitProgress, WriteToDataSourceV2Exec}
import org.apache.spark.sql.execution.streaming.{AvailableNowTrigger, Offset, OneTimeTrigger, ProcessingTimeTrigger, RealTimeModeAllowlist, RealTimeTrigger, Sink, Source, StreamingQueryPlanTraverseHelper}
import org.apache.spark.sql.execution.streaming.checkpointing.{CheckpointFileManager, CommitMetadata, OffsetSeqBase, OffsetSeqLog, OffsetSeqMetadata, OffsetSeqMetadataV2}
import org.apache.spark.sql.execution.streaming.checkpointing.{CheckpointFileManager, OffsetSeqBase, OffsetSeqLog, OffsetSeqMetadata, OffsetSeqMetadataV2}
import org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOperatorStateInfo, StatefulOpStateStoreCheckpointInfo, StateStoreWriter}
import org.apache.spark.sql.execution.streaming.runtime.StreamingCheckpointConstants.{DIR_NAME_COMMITS, DIR_NAME_OFFSETS, DIR_NAME_STATE}
import org.apache.spark.sql.execution.streaming.sources.{ForeachBatchSink, WriteToMicroBatchDataSource, WriteToMicroBatchDataSourceV1}
Expand Down Expand Up @@ -1464,7 +1464,9 @@ class MicroBatchExecution(
None
}
if (!commitLog.add(execCtx.batchId,
CommitMetadata(watermarkTracker.currentWatermark, stateStoreCkptId))) {
commitLog.createMetadata(
nextBatchWatermarkMs = watermarkTracker.currentWatermark,
stateUniqueIds = stateStoreCkptId))) {
throw QueryExecutionErrors.concurrentStreamLogUpdate(execCtx.batchId)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,9 @@ class OfflineStateRepartitionRunner(
lastCommittedBatchId: Long,
opIdToStateStoreCkptInfo: Option[Map[Long, Array[Array[String]]]]): Unit = {
val latestCommit = checkpointMetadata.commitLog.get(lastCommittedBatchId).get
val commitMetadata = latestCommit.copy(stateUniqueIds = opIdToStateStoreCkptInfo)
// Derive the new commit from the latest one so version-specific fields are preserved and the
// wire format version stays consistent with the source checkpoint.
val commitMetadata = latestCommit.withStateUniqueIds(opIdToStateStoreCkptInfo)

if (!checkpointMetadata.commitLog.add(newBatchId, commitMetadata)) {
throw QueryExecutionErrors.concurrentStreamLogUpdate(newBatchId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import java.util.UUID
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

import org.apache.spark.{SparkIllegalStateException, SparkThrowable, TaskContext}
import org.apache.spark.{SparkIllegalStateException, TaskContext}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.internal.LogKeys._
Expand Down Expand Up @@ -376,27 +376,19 @@ class StateRewriter(
}

private def verifyCheckpointFormatVersion(): Unit = {
// Verify checkpoint version in sqlConf based on commitLog for readCheckpoint
// in case user forgot to set STATE_STORE_CHECKPOINT_FORMAT_VERSION.
// Using read batch commit since the latest commit could be a skipped batch.
// If SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION is wrong, readCheckpoint.commitLog
// will throw an exception, and we will propagate this exception upstream.
// This prevents the StateRewriter from failing to write the correct state files
try {
readCheckpoint.commitLog.get(readBatchId)
} catch {
case e: IllegalStateException if e.getCause != null &&
e.getCause.isInstanceOf[SparkThrowable] =>
val sparkThrowable = e.getCause.asInstanceOf[SparkThrowable]
if (sparkThrowable.getCondition == "INVALID_LOG_VERSION.EXACT_MATCH_VERSION") {
val params = sparkThrowable.getMessageParameters
val expectedVersion = params.get("version")
val actualVersion = params.get("matchVersion")
throw StateRewriterErrors.stateCheckpointFormatVersionMismatchError(
checkpointLocationForRead, expectedVersion, actualVersion)
}
throw e
// Verify checkpoint version in sqlConf matches the version recorded in the read commit log,
// in case the user forgot to set STATE_STORE_CHECKPOINT_FORMAT_VERSION. This prevents the
// StateRewriter from writing state files in a format that disagrees with the source
// checkpoint. Using the read batch commit since the latest commit could be a skipped batch.
readCheckpoint.commitLog.get(readBatchId).foreach { metadata =>
val configuredVersion = readCheckpoint.commitLog.defaultVersion
if (metadata.version != configuredVersion) {
throw StateRewriterErrors.stateCheckpointFormatVersionMismatchError(
checkpointLocationForRead,
expectedVersion = metadata.version.toString,
actualVersion = configuredVersion.toString)
}
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import org.apache.hadoop.conf.Configuration
import org.scalatest.Assertions

import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.streaming.checkpointing.{CommitLog, CommitMetadata}
import org.apache.spark.sql.execution.streaming.checkpointing.{CommitLog, CommitMetadata, CommitMetadataV2}
import org.apache.spark.sql.execution.streaming.runtime.{MemoryStream, StreamExecution}
import org.apache.spark.sql.execution.streaming.state._
import org.apache.spark.sql.functions.{col, window}
Expand Down Expand Up @@ -237,11 +237,11 @@ abstract class StateDataSourceChangeDataReaderSuite extends StateDataSourceTestB
new File(tempDir.getAbsolutePath, "commits").getAbsolutePath)

// Start version: treated as v1 (no operator unique ids)
val startMetadata = CommitMetadata(0, None)
val startMetadata = CommitMetadata(0)
assert(commitLog.add(0, startMetadata))

// End version: treated as v2 (operator 0 has unique ids)
val endMetadata = CommitMetadata(0,
val endMetadata = CommitMetadataV2(0,
Some(Map[Long, Array[Array[String]]](0L -> Array(Array("uid")))))
assert(commitLog.add(1, endMetadata))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import java.util.UUID
import org.apache.hadoop.conf.Configuration
import org.scalatest.Assertions

import org.apache.spark.{SparkException, SparkThrowable, SparkUnsupportedOperationException}
import org.apache.spark.{SparkException, SparkUnsupportedOperationException}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.sql.{AnalysisException, DataFrame, Encoders, Row}
import org.apache.spark.sql.catalyst.expressions.{BoundReference, GenericInternalRow}
Expand Down Expand Up @@ -589,8 +589,6 @@ class RocksDBWithCheckpointV2StateDataSourceReaderSuite extends StateDataSourceR
override protected def newStateStoreProvider(): RocksDBStateStoreProvider =
new RocksDBStateStoreProvider

import testImplicits._

override def beforeAll(): Unit = {
super.beforeAll()
spark.conf.set(SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION, 2)
Expand All @@ -600,34 +598,57 @@ class RocksDBWithCheckpointV2StateDataSourceReaderSuite extends StateDataSourceR
"true")
}

// TODO: Remove this test once we allow migrations from checkpoint v1 to v2
test("reading checkpoint v2 store with version 1 should fail") {
withTempDir { tmpDir =>
val inputData = MemoryStream[(Int, Long)]
val query = getStreamStreamJoinQuery(inputData)
testStream(query)(
StartStream(checkpointLocation = tmpDir.getCanonicalPath),
AddData(inputData, (1, 1L), (2, 2L), (3, 3L), (4, 4L), (5, 5L)),
ProcessAllAvailable(),
Execute { _ => Thread.sleep(2000) },
StopStream
)
// Rows expected when the state written by runLargeDataStreamingAggregationQuery is read back
// from batch 2 / operator 0.
// Column order: key_groupKey, value_cnt, value_sum, value_max, value_min.
private val expectedLargeAggregationState: Seq[Row] = Seq(
  Row(0, 5, 60, 30, 0),
  Row(1, 5, 65, 31, 1),
  Row(2, 5, 70, 32, 2),
  Row(3, 4, 72, 33, 3),
  Row(4, 4, 76, 34, 4),
  Row(5, 4, 80, 35, 5),
  Row(6, 4, 84, 36, 6),
  Row(7, 4, 88, 37, 7),
  Row(8, 4, 92, 38, 8),
  Row(9, 4, 96, 39, 9))

// Loads the state of operator 0 at batch 2 from the checkpoint under `checkpointDir` through the
// "statestore" data source and flattens the key/value structs into five top-level columns.
private def readLargeAggregationState(checkpointDir: String): DataFrame = {
  val stateReader = spark.read.format("statestore")
    .option(StateSourceOptions.PATH, checkpointDir)
    .option(StateSourceOptions.BATCH_ID, 2)
    .option(StateSourceOptions.OPERATOR_ID, 0)

  val flattenedColumns = Seq(
    "key.groupKey AS key_groupKey",
    "value.count AS value_cnt",
    "value.sum AS value_sum",
    "value.max AS value_max",
    "value.min AS value_min")

  stateReader.load().selectExpr(flattenedColumns: _*)
}

// SPARK-56970: The commit log wire format version is now discovered from the file header
// rather than required to match STATE_STORE_CHECKPOINT_FORMAT_VERSION. As a result a V1 commit
// log can be read under a V2-configured session (and vice versa). Note this only applies to the
// commit log layer; reading a V2 state store still requires version 2 to be configured because
// the state store files are named with checkpoint unique ids.
test("SPARK-56970: reading a v1 checkpoint with commit log version 2 configured succeeds") {
withTempDir { tempDir =>
// Override the suite default to write a V1 checkpoint (no checkpoint unique ids).
withSQLConf(SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> "1") {
// Verify reading state throws error when reading checkpoint v2 with version 1
val exc = intercept[IllegalStateException] {
val stateDf = spark.read.format("statestore")
.option(StateSourceOptions.BATCH_ID, 0)
.option(StateSourceOptions.OPERATOR_ID, 0)
.load(tmpDir.getCanonicalPath)
stateDf.collect()
}
runLargeDataStreamingAggregationQuery(tempDir.getAbsolutePath)
}

// The suite default reads with version 2 configured; the V1 commit log must still be read.
checkAnswer(
readLargeAggregationState(tempDir.getAbsolutePath), expectedLargeAggregationState)
}
}

checkError(exc.getCause.asInstanceOf[SparkThrowable],
"INVALID_LOG_VERSION.EXACT_MATCH_VERSION", "KD002",
Map(
"version" -> "2",
"matchVersion" -> "1"))
test("SPARK-56970: reading a v2 checkpoint with commit log version 1 configured fails on the " +
  "state store, not the commit log") {
  withTempDir { tempDir =>
    // The suite configures commit log format version 2, so this writes a V2 checkpoint whose
    // state store files are named with checkpoint unique ids.
    runLargeDataStreamingAggregationQuery(tempDir.getAbsolutePath)

    withSQLConf(SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> "1") {
      // The commit log now deserializes across versions, so this no longer fails with
      // INVALID_LOG_VERSION at the commit-log layer. Reading the V2 state store itself still
      // requires version 2 to be configured: with version 1 the reader looks for non-unique
      // state file names and cannot locate the unique-id-named files.
      val ex = intercept[SparkException] {
        readLargeAggregationState(tempDir.getAbsolutePath).collect()
      }

      // Inspect the exception and its direct cause. Throwable.getMessage and getCause may both
      // be null, so drop nulls first: otherwise a message-less cause surfaces as a
      // NullPointerException instead of a readable assertion failure.
      val messages = Seq[Throwable](ex, ex.getCause)
        .filter(_ != null)
        .flatMap(t => Option(t.getMessage))
      assert(messages.exists(_.contains("CANNOT_LOAD_STATE_STORE")),
        s"Expected CANNOT_LOAD_STATE_STORE in the exception or its cause, got: $messages")
    }
  }
}
Expand Down
Loading