From 137c25b6bf38a5da7ba0bd2d34449df6a6def96c Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Thu, 2 Apr 2026 11:08:25 +0900 Subject: [PATCH 01/10] WIP introduce scan to scope the iteration in State Store --- .../join/SymmetricHashJoinStateManager.scala | 47 ++-- .../execution/streaming/state/RocksDB.scala | 80 +++++++ .../state/RocksDBStateStoreProvider.scala | 57 +++++ .../streaming/state/StateStore.scala | 57 +++++ ...sDBStateStoreCheckpointFormatV2Suite.scala | 14 ++ ...cksDBTimestampEncoderOperationsSuite.scala | 225 ++++++++++++++++++ 6 files changed, 457 insertions(+), 23 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala index fc2a69312fe79..b0b019bac5860 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala @@ -27,7 +27,7 @@ import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.internal.LogKeys.{END_INDEX, START_INDEX, STATE_STORE_ID} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, JoinedRow, Literal, SafeProjection, SpecificInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, GenericInternalRow, JoinedRow, Literal, SafeProjection, SpecificInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOperatorStateInfo, 
StatefulOpStateStoreCheckpointInfo, WatermarkSupport} @@ -647,17 +647,25 @@ class SymmetricHashJoinStateManagerV4( /** * Returns entries where minTs <= timestamp <= maxTs (both inclusive), grouped by timestamp. - * Skips entries before minTs and stops iterating past maxTs (timestamps are sorted). + * + * When a bounded range is provided, leverages RocksDB's native seek and upper bound via + * [[StateStore.scanWithMultiValues]] to avoid reading entries outside the range. + * Falls back to [[StateStore.prefixScanWithMultiValues]] when the full range is requested. */ def getValuesInRange( key: UnsafeRow, minTs: Long, maxTs: Long): Iterator[GetValuesResult] = { val reusableGetValuesResult = new GetValuesResult() new NextIterator[GetValuesResult] { - private val iter = stateStore.prefixScanWithMultiValues(key, colFamilyName) + private val iter = if (minTs == Long.MinValue && maxTs == Long.MaxValue) { + stateStore.prefixScanWithMultiValues(key, colFamilyName) + } else { + val startKeyRow = createKeyRow(key, minTs) + val endKeyRow = createKeyRow(key, maxTs + 1) + stateStore.scanWithMultiValues(Some(startKeyRow), endKeyRow, colFamilyName) + } private var currentTs = -1L - private var pastUpperBound = false private val valueAndMatchPairs = scala.collection.mutable.ArrayBuffer[ValueAndMatchPair]() private def flushAccumulated(): GetValuesResult = { @@ -675,18 +683,13 @@ class SymmetricHashJoinStateManagerV4( @tailrec override protected def getNext(): GetValuesResult = { - if (pastUpperBound || !iter.hasNext) { + if (!iter.hasNext) { flushAccumulated() } else { val unsafeRowPair = iter.next() val ts = TimestampKeyStateEncoder.extractTimestamp(unsafeRowPair.key) - if (ts > maxTs) { - pastUpperBound = true - getNext() - } else if (ts < minTs) { - getNext() - } else if (currentTs == -1L || currentTs == ts) { + if (currentTs == -1L || currentTs == ts) { currentTs = ts valueAndMatchPairs += valueRowConverter.convertValue(unsafeRowPair.value) getNext() @@ -770,11 +773,17 @@ class 
SymmetricHashJoinStateManagerV4( stateStore.remove(createKeyRow(key, timestamp), colFamilyName) } + private lazy val dummyKeyRow: UnsafeRow = { + val projection = UnsafeProjection.create(keySchema) + projection(new GenericInternalRow(keySchema.length)).copy() + } + case class EvictedKeysResult(key: UnsafeRow, timestamp: Long, numValues: Int) // NOTE: This assumes we consume the whole iterator to trigger completion. def scanEvictedKeys(endTimestamp: Long): Iterator[EvictedKeysResult] = { - val evictIterator = stateStore.iteratorWithMultiValues(colFamilyName) + val endKeyRow = createKeyRow(dummyKeyRow, endTimestamp + 1) + val evictIterator = stateStore.scanWithMultiValues(None, endKeyRow, colFamilyName) new NextIterator[EvictedKeysResult]() { var currentKeyRow: UnsafeRow = null var currentEventTime: Long = -1L @@ -789,25 +798,19 @@ class SymmetricHashJoinStateManagerV4( val ts = TimestampKeyStateEncoder.extractTimestamp(kv.key) if (keyRow == currentKeyRow && ts == currentEventTime) { - // new value with same (key, ts) count += 1 } else if (ts > endTimestamp) { - // we found the timestamp beyond the range - we shouldn't continue further + // Safety check for boundary edge case: a small number of entries at exactly + // endTimestamp + 1 may leak through the upper bound because the encoded end key + // includes a join key suffix. isBeyondUpperBound = true - - // We don't need to construct the last (key, ts) into EvictedKeysResult - the code - // after loop will handle that if there is leftover. That said, we do not reset the - // current (key, ts) info here. 
} else if (currentKeyRow == null && currentEventTime == -1L) { - // first value to process currentKeyRow = keyRow.copy() currentEventTime = ts count = 1 } else { - // construct the last (key, ts) into EvictedKeysResult ret = EvictedKeysResult(currentKeyRow, currentEventTime, count) - // register the next (key, ts) to process currentKeyRow = keyRow.copy() currentEventTime = ts count = 1 @@ -817,10 +820,8 @@ class SymmetricHashJoinStateManagerV4( if (ret != null) { ret } else if (count > 0) { - // there is a final leftover (key, ts) to return ret = EvictedKeysResult(currentKeyRow, currentEventTime, count) - // we shouldn't continue further currentKeyRow = null currentEventTime = -1L count = 0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index a24a76269828f..ec879a44f213e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -1662,6 +1662,86 @@ class RocksDB( } } + /** + * Scan key-value pairs in the range [startKey, endKey). + * + * @param startKey None to seek to the beginning of the column family, + * or Some(key) to seek to the given start position (inclusive). + * @param endKey The exclusive upper bound for the scan (encoded key bytes). + * @param cfName The column family name. + * @return An iterator of ByteArrayPairs in the given range. 
+ */ + def scan( + startKey: Option[Array[Byte]], + endKey: Array[Byte], + cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME): NextIterator[ByteArrayPair] = { + updateMemoryUsageIfNeeded() + + val updatedEndKey = if (useColumnFamilies) { + encodeStateRowWithPrefix(endKey, cfName) + } else { + endKey + } + + val seekTarget = startKey match { + case Some(key) => + if (useColumnFamilies) encodeStateRowWithPrefix(key, cfName) else key + case None => + if (useColumnFamilies) encodeStateRowWithPrefix(Array.emptyByteArray, cfName) + else null + } + + val upperBoundSlice = new Slice(updatedEndKey) + val scanReadOptions = new ReadOptions() + scanReadOptions.setIterateUpperBound(upperBoundSlice) + + val iter = db.newIterator(scanReadOptions) + if (seekTarget != null) { + iter.seek(seekTarget) + } else { + iter.seekToFirst() + } + + def closeResources(): Unit = { + iter.close() + scanReadOptions.close() + upperBoundSlice.close() + } + + Option(TaskContext.get()).foreach { tc => + tc.addTaskCompletionListener[Unit] { _ => closeResources() } + } + + new NextIterator[ByteArrayPair] { + override protected def getNext(): ByteArrayPair = { + if (iter.isValid) { + val key = if (useColumnFamilies) { + decodeStateRowWithPrefix(iter.key)._1 + } else { + iter.key + } + + val value = if (conf.rowChecksumEnabled) { + KeyValueChecksumEncoder.decodeAndVerifyValueRowWithChecksum( + readVerifier, iter.key, iter.value, delimiterSize) + } else { + iter.value + } + + byteArrayPair.set(key, value) + iter.next() + byteArrayPair + } else { + finished = true + closeResources() + null + } + } + + override protected def close(): Unit = closeResources() + } + } + def release(): Unit = {} /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 8b023a0e9f9fc..7485d679d2333 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -549,6 +549,63 @@ private[sql] class RocksDBStateStoreProvider new StateStoreIterator(iter, rocksDbIter.closeIfNeeded) } + override def scan( + startKey: Option[UnsafeRow], + endKey: UnsafeRow, + colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { + validateAndTransitionState(UPDATE) + verifyColFamilyOperations("scan", colFamilyName) + + val kvEncoder = keyValueEncoderMap.get(colFamilyName) + val encodedStartKey = startKey.map(kvEncoder._1.encodeKey) + val encodedEndKey = kvEncoder._1.encodeKey(endKey) + + val rowPair = new UnsafeRowPair() + val rocksDbIter = rocksDB.scan(encodedStartKey, encodedEndKey, colFamilyName) + val iter = rocksDbIter.map { kv => + rowPair.withRows(kvEncoder._1.decodeKey(kv.key), + kvEncoder._2.decodeValue(kv.value)) + rowPair + } + + new StateStoreIterator(iter, rocksDbIter.closeIfNeeded) + } + + override def scanWithMultiValues( + startKey: Option[UnsafeRow], + endKey: UnsafeRow, + colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { + validateAndTransitionState(UPDATE) + verifyColFamilyOperations("scanWithMultiValues", colFamilyName) + + val kvEncoder = keyValueEncoderMap.get(colFamilyName) + verify( + kvEncoder._2.supportsMultipleValuesPerKey, + "Multi-value iterator operation requires an encoder" + + " which supports multiple values for a single key") + + val encodedStartKey = startKey.map(kvEncoder._1.encodeKey) + val encodedEndKey = kvEncoder._1.encodeKey(endKey) + val rocksDbIter = rocksDB.scan(encodedStartKey, encodedEndKey, colFamilyName) + + val rowPair = new UnsafeRowPair() + val iter = rocksDbIter.flatMap { kv => + val keyRow = kvEncoder._1.decodeKey(kv.key) + val valueRows = kvEncoder._2.decodeValues(kv.value) + valueRows.iterator.map { valueRow => + rowPair.withRows(keyRow, valueRow) + if 
(!isValidated && rowPair.value != null && !useColumnFamilies) { + StateStoreProvider.validateStateRowFormat( + rowPair.key, keySchema, rowPair.value, valueSchema, stateStoreId, storeConf) + isValidated = true + } + rowPair + } + } + + new StateStoreIterator(iter, rocksDbIter.closeIfNeeded) + } + var checkpointInfo: Option[StateStoreCheckpointInfo] = None private var storedMetrics: Option[RocksDBMetrics] = None diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 6e08c10476ce7..89b540ee1d178 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -183,6 +183,49 @@ trait ReadStateStore { prefixKey: UnsafeRow, colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): StateStoreIterator[UnsafeRowPair] + /** + * Scan key-value pairs in the range [startKey, endKey). + * + * @param startKey None to scan from the beginning of the column family, + * or Some(key) to seek to the given start position (inclusive). + * @param endKey The exclusive upper bound for the scan. + * @param colFamilyName The column family name. + * + * Callers must ensure the column family's key encoder produces lexicographically ordered + * bytes for the scan range to be meaningful (e.g., timestamp-based encoders or + * RangeKeyScanStateEncoder). + */ + def scan( + startKey: Option[UnsafeRow], + endKey: UnsafeRow, + colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME) + : StateStoreIterator[UnsafeRowPair] = { + throw StateStoreErrors.unsupportedOperationException("scan", "") + } + + /** + * Scan key-value pairs in the range [startKey, endKey), expanding multi-valued entries. 
+ * + * @param startKey None to scan from the beginning of the column family, + * or Some(key) to seek to the given start position (inclusive). + * @param endKey The exclusive upper bound for the scan. + * @param colFamilyName The column family name. + * + * Callers must ensure the column family's key encoder produces lexicographically ordered + * bytes for the scan range to be meaningful (e.g., timestamp-based encoders or + * RangeKeyScanStateEncoder). + * + * It is expected to throw exception if Spark calls this method without setting + * multipleValuesPerKey as true for the column family. + */ + def scanWithMultiValues( + startKey: Option[UnsafeRow], + endKey: UnsafeRow, + colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME) + : StateStoreIterator[UnsafeRowPair] = { + throw StateStoreErrors.unsupportedOperationException("scanWithMultiValues", "") + } + /** Return an iterator containing all the key-value pairs in the StateStore. */ def iterator( colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): StateStoreIterator[UnsafeRowPair] @@ -411,6 +454,20 @@ class WrappedReadStateStore(store: StateStore) extends ReadStateStore { store.prefixScanWithMultiValues(prefixKey, colFamilyName) } + override def scan( + startKey: Option[UnsafeRow], + endKey: UnsafeRow, + colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { + store.scan(startKey, endKey, colFamilyName) + } + + override def scanWithMultiValues( + startKey: Option[UnsafeRow], + endKey: UnsafeRow, + colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { + store.scanWithMultiValues(startKey, endKey, colFamilyName) + } + override def iteratorWithMultiValues( colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { store.iteratorWithMultiValues(colFamilyName) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala index 14aa43d3234f7..71984a380f9b4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala @@ -172,6 +172,20 @@ case class CkptIdCollectingStateStoreWrapper(innerStore: StateStore) extends Sta innerStore.prefixScanWithMultiValues(prefixKey, colFamilyName) } + override def scan( + startKey: Option[UnsafeRow], + endKey: UnsafeRow, + colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { + innerStore.scan(startKey, endKey, colFamilyName) + } + + override def scanWithMultiValues( + startKey: Option[UnsafeRow], + endKey: UnsafeRow, + colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { + innerStore.scanWithMultiValues(startKey, endKey, colFamilyName) + } + override def iteratorWithMultiValues( colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { innerStore.iteratorWithMultiValues(colFamilyName) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala index a9540a4ad623e..0e001cfedf58c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala @@ -566,6 +566,231 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession } } + Seq("unsaferow", "avro").foreach { encoding => + test(s"scan with postfix encoder: bounded range (encoding = $encoding)") { + tryWithProviderResource( + newStoreProviderWithTimestampEncoder( + encoderType = "postfix", + 
useMultipleValuesPerKey = true, + dataEncoding = encoding) + ) { provider => + val store = provider.getStore(0) + + try { + // Insert entries for key1 at various timestamps + diverseTimestamps.zipWithIndex.foreach { case (ts, idx) => + val keyRow = keyAndTimestampToRow("key1", 1, ts) + store.putList(keyRow, Array(valueToRow(idx * 10), valueToRow(idx * 10 + 1))) + } + + // Insert entries for key2 to verify isolation + store.putList(keyAndTimestampToRow("key2", 1, 500L), + Array(valueToRow(999))) + + // Scan key1 in range [0, 1000] (inclusive via endKey = 1001) + val prefixKey = keyToRow("key1", 1) + val startKey = keyAndTimestampToRow("key1", 1, 0L) + val endKey = keyAndTimestampToRow("key1", 1, 1001L) + val iter = store.scanWithMultiValues(Some(startKey), endKey) + + val results = iter.map { pair => + (pair.key.getLong(2), pair.value.getInt(0)) + }.toList + iter.close() + + // Timestamps in [0, 1000]: 0, 1, 2, 5, 6, 8, 9, 32, 35, 64, 90, 931 + val expectedTimestamps = diverseTimestamps.filter(ts => ts >= 0 && ts <= 1000).sorted + val resultTimestamps = results.map(_._1).distinct + assert(resultTimestamps === expectedTimestamps) + assert(results.length === expectedTimestamps.length * 2) // 2 values per timestamp + } finally { + store.abort() + } + } + } + + test(s"scan with postfix encoder: full range falls back correctly (encoding = $encoding)") { + tryWithProviderResource( + newStoreProviderWithTimestampEncoder( + encoderType = "postfix", + useMultipleValuesPerKey = true, + dataEncoding = encoding) + ) { provider => + val store = provider.getStore(0) + + try { + val timestamps = Seq(100L, 200L, 300L) + timestamps.foreach { ts => + store.putList(keyAndTimestampToRow("key1", 1, ts), + Array(valueToRow(ts.toInt))) + } + + // Scan with Some(startKey) covering full range + val startKey = keyAndTimestampToRow("key1", 1, Long.MinValue) + val endKey = keyAndTimestampToRow("key1", 1, Long.MaxValue) + val iter = store.scanWithMultiValues(Some(startKey), endKey) + val 
results = iter.map(_.key.getLong(2)).toList + iter.close() + + assert(results === timestamps) + } finally { + store.abort() + } + } + } + + test(s"scan with postfix encoder: empty range (encoding = $encoding)") { + tryWithProviderResource( + newStoreProviderWithTimestampEncoder( + encoderType = "postfix", + useMultipleValuesPerKey = true, + dataEncoding = encoding) + ) { provider => + val store = provider.getStore(0) + + try { + store.putList(keyAndTimestampToRow("key1", 1, 1000L), + Array(valueToRow(100))) + store.putList(keyAndTimestampToRow("key1", 1, 2000L), + Array(valueToRow(200))) + + // Scan range that contains no entries + val startKey = keyAndTimestampToRow("key1", 1, 1500L) + val endKey = keyAndTimestampToRow("key1", 1, 1600L) + val iter = store.scanWithMultiValues(Some(startKey), endKey) + assert(!iter.hasNext) + iter.close() + } finally { + store.abort() + } + } + } + + test(s"scan with prefix encoder: None startKey scans from beginning " + + s"(encoding = $encoding)") { + tryWithProviderResource( + newStoreProviderWithTimestampEncoder( + encoderType = "prefix", + useMultipleValuesPerKey = true, + dataEncoding = encoding) + ) { provider => + val store = provider.getStore(0) + + try { + // Insert entries at various timestamps + val timestamps = Seq(100L, 200L, 300L, 400L, 500L) + timestamps.foreach { ts => + store.merge(keyAndTimestampToRow("key1", 1, ts), valueToRow(ts.toInt)) + } + + // Scan from beginning (None) up to 301 (exclusive), covering [..300] + val endKey = keyAndTimestampToRow("key1", 1, 301L) + val iter = store.scanWithMultiValues(None, endKey) + val results = iter.map(_.key.getLong(2)).toList + iter.close() + + assert(results === Seq(100L, 200L, 300L)) + } finally { + store.abort() + } + } + } + + test(s"scan with prefix encoder: boundary safety check (encoding = $encoding)") { + tryWithProviderResource( + newStoreProviderWithTimestampEncoder( + encoderType = "prefix", + useMultipleValuesPerKey = true, + dataEncoding = encoding) + ) { 
provider => + val store = provider.getStore(0) + + try { + val timestamps = Seq(100L, 200L, 300L) + timestamps.foreach { ts => + store.merge(keyAndTimestampToRow("key1", 1, ts), valueToRow(ts.toInt)) + } + store.merge(keyAndTimestampToRow("key2", 2, 150L), valueToRow(150)) + + // Scan with endKey at timestamp 201 with dummyKey - should include + // everything up to timestamp 200 regardless of join key + val endKey = keyAndTimestampToRow("key1", 1, 201L) + val iter = store.scanWithMultiValues(None, endKey) + val results = iter.map { pair => + (pair.key.getString(0), pair.key.getLong(2)) + }.toList + iter.close() + + // Should include key1@100, key2@150, key1@200 + assert(results.length === 3) + assert(results.map(_._2) === Seq(100L, 150L, 200L)) + } finally { + store.abort() + } + } + } + + test(s"scan single-value variant (encoding = $encoding)") { + tryWithProviderResource( + newStoreProviderWithTimestampEncoder( + encoderType = "postfix", + useMultipleValuesPerKey = false, + dataEncoding = encoding) + ) { provider => + val store = provider.getStore(0) + + try { + val timestamps = Seq(100L, 200L, 300L, 400L, 500L) + timestamps.foreach { ts => + store.put(keyAndTimestampToRow("key1", 1, ts), valueToRow(ts.toInt)) + } + + val startKey = keyAndTimestampToRow("key1", 1, 200L) + val endKey = keyAndTimestampToRow("key1", 1, 401L) + val iter = store.scan(Some(startKey), endKey) + val results = iter.map { pair => + (pair.key.getLong(2), pair.value.getInt(0)) + }.toList + iter.close() + + assert(results.map(_._1) === Seq(200L, 300L, 400L)) + assert(results.map(_._2) === Seq(200, 300, 400)) + } finally { + store.abort() + } + } + } + + test(s"scan with diverse timestamps and bounded range (encoding = $encoding)") { + tryWithProviderResource( + newStoreProviderWithTimestampEncoder( + encoderType = "postfix", + useMultipleValuesPerKey = false, + dataEncoding = encoding) + ) { provider => + val store = provider.getStore(0) + + try { + diverseTimestamps.zipWithIndex.foreach { 
case (ts, idx) => + store.put(keyAndTimestampToRow("key1", 1, ts), valueToRow(idx)) + } + + // Scan negative range only: [-300, 0) + val startKey = keyAndTimestampToRow("key1", 1, -300L) + val endKey = keyAndTimestampToRow("key1", 1, 0L) + val iter = store.scan(Some(startKey), endKey) + val results = iter.map(_.key.getLong(2)).toList + iter.close() + + val expected = diverseTimestamps.filter(ts => ts >= -300 && ts < 0).sorted + assert(results === expected) + } finally { + store.abort() + } + } + } + } + // Helper methods to create test data private val keyProjection = UnsafeProjection.create(keySchema) private val keyAndTimestampProjection = UnsafeProjection.create( From 07e241d082dcf2d55a28945f7c8cee1829d9e0eb Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Fri, 3 Apr 2026 16:32:04 +0900 Subject: [PATCH 02/10] fix --- .../stateful/join/SymmetricHashJoinStateManager.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala index b0b019bac5860..71f29d7d69623 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala @@ -660,7 +660,7 @@ class SymmetricHashJoinStateManagerV4( private val iter = if (minTs == Long.MinValue && maxTs == Long.MaxValue) { stateStore.prefixScanWithMultiValues(key, colFamilyName) } else { - val startKeyRow = createKeyRow(key, minTs) + val startKeyRow = createKeyRow(key, minTs).copy() val endKeyRow = createKeyRow(key, maxTs + 1) stateStore.scanWithMultiValues(Some(startKeyRow), endKeyRow, colFamilyName) } @@ -774,8 +774,9 @@ class SymmetricHashJoinStateManagerV4( } 
private lazy val dummyKeyRow: UnsafeRow = { + val defaultValues = keySchema.fields.map(f => Literal.default(f.dataType).eval()) val projection = UnsafeProjection.create(keySchema) - projection(new GenericInternalRow(keySchema.length)).copy() + projection(new GenericInternalRow(defaultValues)).copy() } case class EvictedKeysResult(key: UnsafeRow, timestamp: Long, numValues: Int) From ef84c49ef84850b8a4170f2cabd6a3801e98794f Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Sat, 4 Apr 2026 08:26:26 +0900 Subject: [PATCH 03/10] change the scan's endKey to option --- .../join/SymmetricHashJoinStateManager.scala | 4 +-- .../execution/streaming/state/RocksDB.scala | 25 ++++++++++++------- .../state/RocksDBStateStoreProvider.scala | 8 +++--- .../streaming/state/StateStore.scala | 14 ++++++----- ...sDBStateStoreCheckpointFormatV2Suite.scala | 4 +-- ...cksDBTimestampEncoderOperationsSuite.scala | 14 +++++------ 6 files changed, 39 insertions(+), 30 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala index 71f29d7d69623..bcb341b519caf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala @@ -662,7 +662,7 @@ class SymmetricHashJoinStateManagerV4( } else { val startKeyRow = createKeyRow(key, minTs).copy() val endKeyRow = createKeyRow(key, maxTs + 1) - stateStore.scanWithMultiValues(Some(startKeyRow), endKeyRow, colFamilyName) + stateStore.scanWithMultiValues(Some(startKeyRow), Some(endKeyRow), colFamilyName) } private var currentTs = -1L @@ -784,7 +784,7 @@ class SymmetricHashJoinStateManagerV4( // NOTE: This assumes we consume the whole 
iterator to trigger completion. def scanEvictedKeys(endTimestamp: Long): Iterator[EvictedKeysResult] = { val endKeyRow = createKeyRow(dummyKeyRow, endTimestamp + 1) - val evictIterator = stateStore.scanWithMultiValues(None, endKeyRow, colFamilyName) + val evictIterator = stateStore.scanWithMultiValues(None, Some(endKeyRow), colFamilyName) new NextIterator[EvictedKeysResult]() { var currentKeyRow: UnsafeRow = null var currentEventTime: Long = -1L diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index ec879a44f213e..4ad3a662b4d51 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -1667,20 +1667,27 @@ class RocksDB( * * @param startKey None to seek to the beginning of the column family, * or Some(key) to seek to the given start position (inclusive). - * @param endKey The exclusive upper bound for the scan (encoded key bytes). + * @param endKey None to scan to the end of the column family, + * or Some(key) as the exclusive upper bound for the scan (encoded key bytes). * @param cfName The column family name. * @return An iterator of ByteArrayPairs in the given range. 
*/ def scan( startKey: Option[Array[Byte]], - endKey: Array[Byte], + endKey: Option[Array[Byte]], cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME): NextIterator[ByteArrayPair] = { updateMemoryUsageIfNeeded() - val updatedEndKey = if (useColumnFamilies) { - encodeStateRowWithPrefix(endKey, cfName) - } else { - endKey + val upperBoundBytes: Option[Array[Byte]] = endKey match { + case Some(key) => + Some(if (useColumnFamilies) encodeStateRowWithPrefix(key, cfName) else key) + case None => + if (useColumnFamilies) { + val cfPrefix = encodeStateRowWithPrefix(Array.emptyByteArray, cfName) + RocksDB.prefixUpperBound(cfPrefix) + } else { + None + } } val seekTarget = startKey match { @@ -1691,9 +1698,9 @@ class RocksDB( else null } - val upperBoundSlice = new Slice(updatedEndKey) + val upperBoundSlice = upperBoundBytes.map(new Slice(_)) val scanReadOptions = new ReadOptions() - scanReadOptions.setIterateUpperBound(upperBoundSlice) + upperBoundSlice.foreach(scanReadOptions.setIterateUpperBound) val iter = db.newIterator(scanReadOptions) if (seekTarget != null) { @@ -1705,7 +1712,7 @@ class RocksDB( def closeResources(): Unit = { iter.close() scanReadOptions.close() - upperBoundSlice.close() + upperBoundSlice.foreach(_.close()) } Option(TaskContext.get()).foreach { tc => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 7485d679d2333..7629ffe7d4325 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -551,14 +551,14 @@ private[sql] class RocksDBStateStoreProvider override def scan( startKey: Option[UnsafeRow], - endKey: UnsafeRow, + endKey: Option[UnsafeRow], colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { 
validateAndTransitionState(UPDATE) verifyColFamilyOperations("scan", colFamilyName) val kvEncoder = keyValueEncoderMap.get(colFamilyName) val encodedStartKey = startKey.map(kvEncoder._1.encodeKey) - val encodedEndKey = kvEncoder._1.encodeKey(endKey) + val encodedEndKey = endKey.map(kvEncoder._1.encodeKey) val rowPair = new UnsafeRowPair() val rocksDbIter = rocksDB.scan(encodedStartKey, encodedEndKey, colFamilyName) @@ -573,7 +573,7 @@ private[sql] class RocksDBStateStoreProvider override def scanWithMultiValues( startKey: Option[UnsafeRow], - endKey: UnsafeRow, + endKey: Option[UnsafeRow], colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { validateAndTransitionState(UPDATE) verifyColFamilyOperations("scanWithMultiValues", colFamilyName) @@ -585,7 +585,7 @@ private[sql] class RocksDBStateStoreProvider " which supports multiple values for a single key") val encodedStartKey = startKey.map(kvEncoder._1.encodeKey) - val encodedEndKey = kvEncoder._1.encodeKey(endKey) + val encodedEndKey = endKey.map(kvEncoder._1.encodeKey) val rocksDbIter = rocksDB.scan(encodedStartKey, encodedEndKey, colFamilyName) val rowPair = new UnsafeRowPair() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 89b540ee1d178..9c4e2f06587a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -188,7 +188,8 @@ trait ReadStateStore { * * @param startKey None to scan from the beginning of the column family, * or Some(key) to seek to the given start position (inclusive). - * @param endKey The exclusive upper bound for the scan. + * @param endKey None to scan to the end of the column family, + * or Some(key) as the exclusive upper bound for the scan. * @param colFamilyName The column family name. 
* * Callers must ensure the column family's key encoder produces lexicographically ordered @@ -197,7 +198,7 @@ trait ReadStateStore { */ def scan( startKey: Option[UnsafeRow], - endKey: UnsafeRow, + endKey: Option[UnsafeRow], colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME) : StateStoreIterator[UnsafeRowPair] = { throw StateStoreErrors.unsupportedOperationException("scan", "") @@ -208,7 +209,8 @@ trait ReadStateStore { * * @param startKey None to scan from the beginning of the column family, * or Some(key) to seek to the given start position (inclusive). - * @param endKey The exclusive upper bound for the scan. + * @param endKey None to scan to the end of the column family, + * or Some(key) as the exclusive upper bound for the scan. * @param colFamilyName The column family name. * * Callers must ensure the column family's key encoder produces lexicographically ordered @@ -220,7 +222,7 @@ trait ReadStateStore { */ def scanWithMultiValues( startKey: Option[UnsafeRow], - endKey: UnsafeRow, + endKey: Option[UnsafeRow], colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME) : StateStoreIterator[UnsafeRowPair] = { throw StateStoreErrors.unsupportedOperationException("scanWithMultiValues", "") @@ -456,14 +458,14 @@ class WrappedReadStateStore(store: StateStore) extends ReadStateStore { override def scan( startKey: Option[UnsafeRow], - endKey: UnsafeRow, + endKey: Option[UnsafeRow], colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { store.scan(startKey, endKey, colFamilyName) } override def scanWithMultiValues( startKey: Option[UnsafeRow], - endKey: UnsafeRow, + endKey: Option[UnsafeRow], colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { store.scanWithMultiValues(startKey, endKey, colFamilyName) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala index 71984a380f9b4..dca9677a6f7bc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala @@ -174,14 +174,14 @@ case class CkptIdCollectingStateStoreWrapper(innerStore: StateStore) extends Sta override def scan( startKey: Option[UnsafeRow], - endKey: UnsafeRow, + endKey: Option[UnsafeRow], colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { innerStore.scan(startKey, endKey, colFamilyName) } override def scanWithMultiValues( startKey: Option[UnsafeRow], - endKey: UnsafeRow, + endKey: Option[UnsafeRow], colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { innerStore.scanWithMultiValues(startKey, endKey, colFamilyName) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala index 0e001cfedf58c..392c71eee432b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala @@ -591,7 +591,7 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession val prefixKey = keyToRow("key1", 1) val startKey = keyAndTimestampToRow("key1", 1, 0L) val endKey = keyAndTimestampToRow("key1", 1, 1001L) - val iter = store.scanWithMultiValues(Some(startKey), endKey) + val iter = store.scanWithMultiValues(Some(startKey), Some(endKey)) val results = iter.map { pair => (pair.key.getLong(2), pair.value.getInt(0)) @@ -628,7 +628,7 @@ class 
RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession // Scan with Some(startKey) covering full range val startKey = keyAndTimestampToRow("key1", 1, Long.MinValue) val endKey = keyAndTimestampToRow("key1", 1, Long.MaxValue) - val iter = store.scanWithMultiValues(Some(startKey), endKey) + val iter = store.scanWithMultiValues(Some(startKey), Some(endKey)) val results = iter.map(_.key.getLong(2)).toList iter.close() @@ -657,7 +657,7 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession // Scan range that contains no entries val startKey = keyAndTimestampToRow("key1", 1, 1500L) val endKey = keyAndTimestampToRow("key1", 1, 1600L) - val iter = store.scanWithMultiValues(Some(startKey), endKey) + val iter = store.scanWithMultiValues(Some(startKey), Some(endKey)) assert(!iter.hasNext) iter.close() } finally { @@ -685,7 +685,7 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession // Scan from beginning (None) up to 301 (exclusive), covering [..300] val endKey = keyAndTimestampToRow("key1", 1, 301L) - val iter = store.scanWithMultiValues(None, endKey) + val iter = store.scanWithMultiValues(None, Some(endKey)) val results = iter.map(_.key.getLong(2)).toList iter.close() @@ -715,7 +715,7 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession // Scan with endKey at timestamp 201 with dummyKey - should include // everything up to timestamp 200 regardless of join key val endKey = keyAndTimestampToRow("key1", 1, 201L) - val iter = store.scanWithMultiValues(None, endKey) + val iter = store.scanWithMultiValues(None, Some(endKey)) val results = iter.map { pair => (pair.key.getString(0), pair.key.getLong(2)) }.toList @@ -747,7 +747,7 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession val startKey = keyAndTimestampToRow("key1", 1, 200L) val endKey = keyAndTimestampToRow("key1", 1, 401L) - val iter = store.scan(Some(startKey), endKey) + val iter = store.scan(Some(startKey), Some(endKey)) val 
results = iter.map { pair => (pair.key.getLong(2), pair.value.getInt(0)) }.toList @@ -778,7 +778,7 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession // Scan negative range only: [-300, 0) val startKey = keyAndTimestampToRow("key1", 1, -300L) val endKey = keyAndTimestampToRow("key1", 1, 0L) - val iter = store.scan(Some(startKey), endKey) + val iter = store.scan(Some(startKey), Some(endKey)) val results = iter.map(_.key.getLong(2)).toList iter.close() From 158dff8d6dce01323fe5e4304454055155127de4 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Mon, 6 Apr 2026 14:43:24 +0900 Subject: [PATCH 04/10] Roll back the usage on stream-stream join, add test for range scan --- .../join/SymmetricHashJoinStateManager.scala | 48 ++- .../state/RocksDBStateStoreSuite.scala | 333 ++++++++++++++++++ 2 files changed, 356 insertions(+), 25 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala index bcb341b519caf..fc2a69312fe79 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala @@ -27,7 +27,7 @@ import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.internal.LogKeys.{END_INDEX, START_INDEX, STATE_STORE_ID} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, GenericInternalRow, JoinedRow, Literal, SafeProjection, SpecificInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, JoinedRow, Literal, SafeProjection, 
SpecificInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOperatorStateInfo, StatefulOpStateStoreCheckpointInfo, WatermarkSupport} @@ -647,25 +647,17 @@ class SymmetricHashJoinStateManagerV4( /** * Returns entries where minTs <= timestamp <= maxTs (both inclusive), grouped by timestamp. - * - * When a bounded range is provided, leverages RocksDB's native seek and upper bound via - * [[StateStore.scanWithMultiValues]] to avoid reading entries outside the range. - * Falls back to [[StateStore.prefixScanWithMultiValues]] when the full range is requested. + * Skips entries before minTs and stops iterating past maxTs (timestamps are sorted). */ def getValuesInRange( key: UnsafeRow, minTs: Long, maxTs: Long): Iterator[GetValuesResult] = { val reusableGetValuesResult = new GetValuesResult() new NextIterator[GetValuesResult] { - private val iter = if (minTs == Long.MinValue && maxTs == Long.MaxValue) { - stateStore.prefixScanWithMultiValues(key, colFamilyName) - } else { - val startKeyRow = createKeyRow(key, minTs).copy() - val endKeyRow = createKeyRow(key, maxTs + 1) - stateStore.scanWithMultiValues(Some(startKeyRow), Some(endKeyRow), colFamilyName) - } + private val iter = stateStore.prefixScanWithMultiValues(key, colFamilyName) private var currentTs = -1L + private var pastUpperBound = false private val valueAndMatchPairs = scala.collection.mutable.ArrayBuffer[ValueAndMatchPair]() private def flushAccumulated(): GetValuesResult = { @@ -683,13 +675,18 @@ class SymmetricHashJoinStateManagerV4( @tailrec override protected def getNext(): GetValuesResult = { - if (!iter.hasNext) { + if (pastUpperBound || !iter.hasNext) { flushAccumulated() } else { val unsafeRowPair = iter.next() val ts = TimestampKeyStateEncoder.extractTimestamp(unsafeRowPair.key) - if (currentTs == -1L || currentTs 
== ts) { + if (ts > maxTs) { + pastUpperBound = true + getNext() + } else if (ts < minTs) { + getNext() + } else if (currentTs == -1L || currentTs == ts) { currentTs = ts valueAndMatchPairs += valueRowConverter.convertValue(unsafeRowPair.value) getNext() @@ -773,18 +770,11 @@ class SymmetricHashJoinStateManagerV4( stateStore.remove(createKeyRow(key, timestamp), colFamilyName) } - private lazy val dummyKeyRow: UnsafeRow = { - val defaultValues = keySchema.fields.map(f => Literal.default(f.dataType).eval()) - val projection = UnsafeProjection.create(keySchema) - projection(new GenericInternalRow(defaultValues)).copy() - } - case class EvictedKeysResult(key: UnsafeRow, timestamp: Long, numValues: Int) // NOTE: This assumes we consume the whole iterator to trigger completion. def scanEvictedKeys(endTimestamp: Long): Iterator[EvictedKeysResult] = { - val endKeyRow = createKeyRow(dummyKeyRow, endTimestamp + 1) - val evictIterator = stateStore.scanWithMultiValues(None, Some(endKeyRow), colFamilyName) + val evictIterator = stateStore.iteratorWithMultiValues(colFamilyName) new NextIterator[EvictedKeysResult]() { var currentKeyRow: UnsafeRow = null var currentEventTime: Long = -1L @@ -799,19 +789,25 @@ class SymmetricHashJoinStateManagerV4( val ts = TimestampKeyStateEncoder.extractTimestamp(kv.key) if (keyRow == currentKeyRow && ts == currentEventTime) { + // new value with same (key, ts) count += 1 } else if (ts > endTimestamp) { - // Safety check for boundary edge case: a small number of entries at exactly - // endTimestamp + 1 may leak through the upper bound because the encoded end key - // includes a join key suffix. + // we found the timestamp beyond the range - we shouldn't continue further isBeyondUpperBound = true + + // We don't need to construct the last (key, ts) into EvictedKeysResult - the code + // after loop will handle that if there is leftover. That said, we do not reset the + // current (key, ts) info here. 
} else if (currentKeyRow == null && currentEventTime == -1L) { + // first value to process currentKeyRow = keyRow.copy() currentEventTime = ts count = 1 } else { + // construct the last (key, ts) into EvictedKeysResult ret = EvictedKeysResult(currentKeyRow, currentEventTime, count) + // register the next (key, ts) to process currentKeyRow = keyRow.copy() currentEventTime = ts count = 1 @@ -821,8 +817,10 @@ class SymmetricHashJoinStateManagerV4( if (ret != null) { ret } else if (count > 0) { + // there is a final leftover (key, ts) to return ret = EvictedKeysResult(currentKeyRow, currentEventTime, count) + // we shouldn't continue further currentKeyRow = null currentEventTime = -1L count = 0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala index 3e4b4b7320f53..386e283ccc67f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala @@ -1629,6 +1629,339 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } } + private val diverseTimestamps = Seq(931L, 8000L, 452300L, 4200L, -1L, 90L, 1L, 2L, 8L, + -230L, -14569L, -92L, -7434253L, 35L, 6L, 9L, -323L, 5L, + -32L, -64L, -256L, 64L, 32L, 1024L, 4096L, 0L) + + testWithColumnFamiliesAndEncodingTypes("rocksdb range scan - scan bounded range", + TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => + + tryWithProviderResource(newStoreProvider(keySchemaWithRangeScan, + RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)), + colFamiliesEnabled)) { provider => + val store = provider.getStore(0) + try { + val cfName = if (colFamiliesEnabled) "testColFamily" else "default" + if (colFamiliesEnabled) { + store.createColFamilyIfAbsent(cfName, + 
keySchemaWithRangeScan, valueSchema, + RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0))) + } + + val timestamps = Seq(100L, 200L, 300L, 400L, 500L) + timestamps.foreach { ts => + store.put(dataToKeyRowWithRangeScan(ts, "a"), dataToValueRow(ts.toInt), cfName) + } + + val startKey = dataToKeyRowWithRangeScan(200L, "a") + val endKey = dataToKeyRowWithRangeScan(401L, "a") + val iter = store.scan(Some(startKey), Some(endKey), cfName) + val results = iter.map { pair => + (pair.key.getLong(0), pair.value.getInt(0)) + }.toList + iter.close() + + assert(results.map(_._1) === Seq(200L, 300L, 400L)) + assert(results.map(_._2) === Seq(200, 300, 400)) + } finally { + if (!store.hasCommitted) store.abort() + } + } + } + + testWithColumnFamiliesAndEncodingTypes( + "rocksdb range scan - scan with None startKey scans from beginning", + TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => + + tryWithProviderResource(newStoreProvider(keySchemaWithRangeScan, + RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)), + colFamiliesEnabled)) { provider => + val store = provider.getStore(0) + try { + val cfName = if (colFamiliesEnabled) "testColFamily" else "default" + if (colFamiliesEnabled) { + store.createColFamilyIfAbsent(cfName, + keySchemaWithRangeScan, valueSchema, + RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0))) + } + + val timestamps = Seq(100L, 200L, 300L, 400L, 500L) + timestamps.foreach { ts => + store.put(dataToKeyRowWithRangeScan(ts, "a"), dataToValueRow(ts.toInt), cfName) + } + + val endKey = dataToKeyRowWithRangeScan(301L, "a") + val iter = store.scan(None, Some(endKey), cfName) + val results = iter.map(_.key.getLong(0)).toList + iter.close() + + assert(results === Seq(100L, 200L, 300L)) + } finally { + if (!store.hasCommitted) store.abort() + } + } + } + + testWithColumnFamiliesAndEncodingTypes( + "rocksdb range scan - scan with None endKey scans to end", + TestWithBothChangelogCheckpointingEnabledAndDisabled) { 
colFamiliesEnabled => + + tryWithProviderResource(newStoreProvider(keySchemaWithRangeScan, + RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)), + colFamiliesEnabled)) { provider => + val store = provider.getStore(0) + try { + val cfName = if (colFamiliesEnabled) "testColFamily" else "default" + if (colFamiliesEnabled) { + store.createColFamilyIfAbsent(cfName, + keySchemaWithRangeScan, valueSchema, + RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0))) + } + + val timestamps = Seq(100L, 200L, 300L, 400L, 500L) + timestamps.foreach { ts => + store.put(dataToKeyRowWithRangeScan(ts, "a"), dataToValueRow(ts.toInt), cfName) + } + + val startKey = dataToKeyRowWithRangeScan(300L, "a") + val iter = store.scan(Some(startKey), None, cfName) + val results = iter.map(_.key.getLong(0)).toList + iter.close() + + assert(results === Seq(300L, 400L, 500L)) + } finally { + if (!store.hasCommitted) store.abort() + } + } + } + + testWithColumnFamiliesAndEncodingTypes("rocksdb range scan - scan empty range", + TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => + + tryWithProviderResource(newStoreProvider(keySchemaWithRangeScan, + RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)), + colFamiliesEnabled)) { provider => + val store = provider.getStore(0) + try { + val cfName = if (colFamiliesEnabled) "testColFamily" else "default" + if (colFamiliesEnabled) { + store.createColFamilyIfAbsent(cfName, + keySchemaWithRangeScan, valueSchema, + RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0))) + } + + store.put(dataToKeyRowWithRangeScan(100L, "a"), dataToValueRow(100), cfName) + store.put(dataToKeyRowWithRangeScan(500L, "a"), dataToValueRow(500), cfName) + + val startKey = dataToKeyRowWithRangeScan(200L, "a") + val endKey = dataToKeyRowWithRangeScan(300L, "a") + val iter = store.scan(Some(startKey), Some(endKey), cfName) + assert(!iter.hasNext) + iter.close() + } finally { + if (!store.hasCommitted) store.abort() + } + } + } + + 
testWithColumnFamiliesAndEncodingTypes( + "rocksdb range scan - scan with diverse timestamps bounded range", + TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => + + tryWithProviderResource(newStoreProvider(keySchemaWithRangeScan, + RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)), + colFamiliesEnabled)) { provider => + val store = provider.getStore(0) + try { + val cfName = if (colFamiliesEnabled) "testColFamily" else "default" + if (colFamiliesEnabled) { + store.createColFamilyIfAbsent(cfName, + keySchemaWithRangeScan, valueSchema, + RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0))) + } + + diverseTimestamps.zipWithIndex.foreach { case (ts, idx) => + store.put(dataToKeyRowWithRangeScan(ts, "a"), dataToValueRow(idx), cfName) + } + + // Scan negative range: [-300, 0) + val startKey = dataToKeyRowWithRangeScan(-300L, "a") + val endKey = dataToKeyRowWithRangeScan(0L, "a") + val iter = store.scan(Some(startKey), Some(endKey), cfName) + val results = iter.map(_.key.getLong(0)).toList + iter.close() + + val expected = diverseTimestamps.filter(ts => ts >= -300 && ts < 0).sorted + assert(results === expected) + } finally { + if (!store.hasCommitted) store.abort() + } + } + } + + testWithColumnFamiliesAndEncodingTypes( + "rocksdb range scan - scan with multiple key2 values within same key1 range", + TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => + + tryWithProviderResource(newStoreProvider(keySchemaWithRangeScan, + RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)), + colFamiliesEnabled)) { provider => + val store = provider.getStore(0) + try { + val cfName = if (colFamiliesEnabled) "testColFamily" else "default" + if (colFamiliesEnabled) { + store.createColFamilyIfAbsent(cfName, + keySchemaWithRangeScan, valueSchema, + RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0))) + } + + Seq("a", "b", "c").foreach { key2 => + Seq(100L, 200L, 300L).foreach { ts => + 
store.put(dataToKeyRowWithRangeScan(ts, key2), dataToValueRow(ts.toInt), cfName) + } + } + + val startKey = dataToKeyRowWithRangeScan(100L, "a") + val endKey = dataToKeyRowWithRangeScan(201L, "a") + val iter = store.scan(Some(startKey), Some(endKey), cfName) + val results = iter.map { pair => + (pair.key.getLong(0), pair.key.getUTF8String(1).toString) + }.toList + iter.close() + + assert(results.map(_._1).distinct.sorted === Seq(100L, 200L)) + assert(results.length === 6) // 3 key2 values x 2 key1 values + } finally { + if (!store.hasCommitted) store.abort() + } + } + } + + testWithColumnFamiliesAndEncodingTypes( + "rocksdb range scan - scanWithMultiValues bounded range", + TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => + + // Multiple values per key requires column families + if (colFamiliesEnabled) { + tryWithProviderResource(newStoreProvider( + StateStoreId(newDir(), Random.nextInt(), 0), + RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)), + keySchema = keySchemaWithRangeScan, + useColumnFamilies = colFamiliesEnabled, + useMultipleValuesPerKey = true)) { provider => + val store = provider.getStore(0) + try { + val cfName = "testColFamily" + store.createColFamilyIfAbsent(cfName, + keySchemaWithRangeScan, valueSchema, + RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)), + useMultipleValuesPerKey = true) + + val timestamps = Seq(100L, 200L, 300L, 400L, 500L) + timestamps.foreach { ts => + store.putList(dataToKeyRowWithRangeScan(ts, "a"), + Array(dataToValueRow(ts.toInt), dataToValueRow(ts.toInt + 1)), cfName) + } + + val startKey = dataToKeyRowWithRangeScan(200L, "a") + val endKey = dataToKeyRowWithRangeScan(401L, "a") + val iter = store.scanWithMultiValues(Some(startKey), Some(endKey), cfName) + val results = iter.map { pair => + (pair.key.getLong(0), pair.value.getInt(0)) + }.toList + iter.close() + + val resultTimestamps = results.map(_._1).distinct + assert(resultTimestamps === Seq(200L, 300L, 400L)) + 
assert(results.length === 6) // 3 timestamps x 2 values each + } finally { + if (!store.hasCommitted) store.abort() + } + } + } + } + + testWithColumnFamiliesAndEncodingTypes( + "rocksdb range scan - scanWithMultiValues with None startKey", + TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => + + if (colFamiliesEnabled) { + tryWithProviderResource(newStoreProvider( + StateStoreId(newDir(), Random.nextInt(), 0), + RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)), + keySchema = keySchemaWithRangeScan, + useColumnFamilies = colFamiliesEnabled, + useMultipleValuesPerKey = true)) { provider => + val store = provider.getStore(0) + try { + val cfName = "testColFamily" + store.createColFamilyIfAbsent(cfName, + keySchemaWithRangeScan, valueSchema, + RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)), + useMultipleValuesPerKey = true) + + val timestamps = Seq(100L, 200L, 300L, 400L, 500L) + timestamps.foreach { ts => + store.merge(dataToKeyRowWithRangeScan(ts, "a"), dataToValueRow(ts.toInt), cfName) + } + + val endKey = dataToKeyRowWithRangeScan(301L, "a") + val iter = store.scanWithMultiValues(None, Some(endKey), cfName) + val results = iter.map(_.key.getLong(0)).toList + iter.close() + + assert(results === Seq(100L, 200L, 300L)) + } finally { + if (!store.hasCommitted) store.abort() + } + } + } + } + + testWithColumnFamiliesAndEncodingTypes( + "rocksdb range scan - scanWithMultiValues with diverse timestamps", + TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => + + if (colFamiliesEnabled) { + tryWithProviderResource(newStoreProvider( + StateStoreId(newDir(), Random.nextInt(), 0), + RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)), + keySchema = keySchemaWithRangeScan, + useColumnFamilies = colFamiliesEnabled, + useMultipleValuesPerKey = true)) { provider => + val store = provider.getStore(0) + try { + val cfName = "testColFamily" + store.createColFamilyIfAbsent(cfName, + 
keySchemaWithRangeScan, valueSchema, + RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)), + useMultipleValuesPerKey = true) + + diverseTimestamps.zipWithIndex.foreach { case (ts, idx) => + store.putList(dataToKeyRowWithRangeScan(ts, "a"), + Array(dataToValueRow(idx * 10), dataToValueRow(idx * 10 + 1)), cfName) + } + + // Scan [0, 1000] (inclusive via endKey = 1001) + val startKey = dataToKeyRowWithRangeScan(0L, "a") + val endKey = dataToKeyRowWithRangeScan(1001L, "a") + val iter = store.scanWithMultiValues(Some(startKey), Some(endKey), cfName) + val results = iter.map { pair => + (pair.key.getLong(0), pair.value.getInt(0)) + }.toList + iter.close() + + val expectedTimestamps = diverseTimestamps.filter(ts => ts >= 0 && ts <= 1000).sorted + val resultTimestamps = results.map(_._1).distinct + assert(resultTimestamps === expectedTimestamps) + assert(results.length === expectedTimestamps.length * 2) + } finally { + if (!store.hasCommitted) store.abort() + } + } + } + } + testWithColumnFamiliesAndEncodingTypes( "rocksdb key and value schema encoders for column families", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => From a494b318a1942c9f3a64316fbfdebd20bdce1fd6 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Mon, 6 Apr 2026 20:43:40 +0900 Subject: [PATCH 05/10] Reflect self-review comments --- .../state/RocksDBStateStoreSuite.scala | 293 ++++-------------- ...cksDBTimestampEncoderOperationsSuite.scala | 210 ++++--------- 2 files changed, 124 insertions(+), 379 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala index 386e283ccc67f..7e373a263ab74 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala @@ -1633,7 +1633,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid -230L, -14569L, -92L, -7434253L, 35L, 6L, 9L, -323L, 5L, -32L, -64L, -256L, 64L, 32L, 1024L, 4096L, 0L) - testWithColumnFamiliesAndEncodingTypes("rocksdb range scan - scan bounded range", + testWithColumnFamiliesAndEncodingTypes("rocksdb range scan", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => tryWithProviderResource(newStoreProvider(keySchemaWithRangeScan, @@ -1648,151 +1648,50 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0))) } - val timestamps = Seq(100L, 200L, 300L, 400L, 500L) - timestamps.foreach { ts => + diverseTimestamps.foreach { ts => store.put(dataToKeyRowWithRangeScan(ts, "a"), dataToValueRow(ts.toInt), cfName) } - val startKey = dataToKeyRowWithRangeScan(200L, "a") - val endKey = dataToKeyRowWithRangeScan(401L, "a") - val iter = store.scan(Some(startKey), Some(endKey), cfName) - val results = iter.map { pair => + // Bounded positive range [0, 100) + val boundedIter = store.scan( + Some(dataToKeyRowWithRangeScan(0L, "a")), + Some(dataToKeyRowWithRangeScan(100L, "a")), cfName) + val boundedResults = boundedIter.map { pair => (pair.key.getLong(0), pair.value.getInt(0)) }.toList - iter.close() - - assert(results.map(_._1) === Seq(200L, 300L, 400L)) - assert(results.map(_._2) === Seq(200, 300, 400)) - } finally { - if (!store.hasCommitted) store.abort() - } - } - } - - testWithColumnFamiliesAndEncodingTypes( - "rocksdb range scan - scan with None startKey scans from beginning", - TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => - - tryWithProviderResource(newStoreProvider(keySchemaWithRangeScan, - RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)), - colFamiliesEnabled)) { provider => - 
val store = provider.getStore(0) - try { - val cfName = if (colFamiliesEnabled) "testColFamily" else "default" - if (colFamiliesEnabled) { - store.createColFamilyIfAbsent(cfName, - keySchemaWithRangeScan, valueSchema, - RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0))) - } - - val timestamps = Seq(100L, 200L, 300L, 400L, 500L) - timestamps.foreach { ts => - store.put(dataToKeyRowWithRangeScan(ts, "a"), dataToValueRow(ts.toInt), cfName) - } - - val endKey = dataToKeyRowWithRangeScan(301L, "a") - val iter = store.scan(None, Some(endKey), cfName) - val results = iter.map(_.key.getLong(0)).toList - iter.close() - - assert(results === Seq(100L, 200L, 300L)) - } finally { - if (!store.hasCommitted) store.abort() - } - } - } - - testWithColumnFamiliesAndEncodingTypes( - "rocksdb range scan - scan with None endKey scans to end", - TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => - - tryWithProviderResource(newStoreProvider(keySchemaWithRangeScan, - RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)), - colFamiliesEnabled)) { provider => - val store = provider.getStore(0) - try { - val cfName = if (colFamiliesEnabled) "testColFamily" else "default" - if (colFamiliesEnabled) { - store.createColFamilyIfAbsent(cfName, - keySchemaWithRangeScan, valueSchema, - RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0))) - } - - val timestamps = Seq(100L, 200L, 300L, 400L, 500L) - timestamps.foreach { ts => - store.put(dataToKeyRowWithRangeScan(ts, "a"), dataToValueRow(ts.toInt), cfName) - } - - val startKey = dataToKeyRowWithRangeScan(300L, "a") - val iter = store.scan(Some(startKey), None, cfName) - val results = iter.map(_.key.getLong(0)).toList - iter.close() - - assert(results === Seq(300L, 400L, 500L)) - } finally { - if (!store.hasCommitted) store.abort() - } - } - } - - testWithColumnFamiliesAndEncodingTypes("rocksdb range scan - scan empty range", - TestWithBothChangelogCheckpointingEnabledAndDisabled) { 
colFamiliesEnabled => - - tryWithProviderResource(newStoreProvider(keySchemaWithRangeScan, - RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)), - colFamiliesEnabled)) { provider => - val store = provider.getStore(0) - try { - val cfName = if (colFamiliesEnabled) "testColFamily" else "default" - if (colFamiliesEnabled) { - store.createColFamilyIfAbsent(cfName, - keySchemaWithRangeScan, valueSchema, - RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0))) - } - - store.put(dataToKeyRowWithRangeScan(100L, "a"), dataToValueRow(100), cfName) - store.put(dataToKeyRowWithRangeScan(500L, "a"), dataToValueRow(500), cfName) - - val startKey = dataToKeyRowWithRangeScan(200L, "a") - val endKey = dataToKeyRowWithRangeScan(300L, "a") - val iter = store.scan(Some(startKey), Some(endKey), cfName) - assert(!iter.hasNext) - iter.close() - } finally { - if (!store.hasCommitted) store.abort() - } - } - } - - testWithColumnFamiliesAndEncodingTypes( - "rocksdb range scan - scan with diverse timestamps bounded range", - TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => - - tryWithProviderResource(newStoreProvider(keySchemaWithRangeScan, - RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)), - colFamiliesEnabled)) { provider => - val store = provider.getStore(0) - try { - val cfName = if (colFamiliesEnabled) "testColFamily" else "default" - if (colFamiliesEnabled) { - store.createColFamilyIfAbsent(cfName, - keySchemaWithRangeScan, valueSchema, - RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0))) - } - - diverseTimestamps.zipWithIndex.foreach { case (ts, idx) => - store.put(dataToKeyRowWithRangeScan(ts, "a"), dataToValueRow(idx), cfName) - } - - // Scan negative range: [-300, 0) - val startKey = dataToKeyRowWithRangeScan(-300L, "a") - val endKey = dataToKeyRowWithRangeScan(0L, "a") - val iter = store.scan(Some(startKey), Some(endKey), cfName) - val results = iter.map(_.key.getLong(0)).toList - iter.close() - - val expected = 
diverseTimestamps.filter(ts => ts >= -300 && ts < 0).sorted - assert(results === expected) + boundedIter.close() + val expectedBoundedTs = diverseTimestamps.filter(ts => ts >= 0 && ts < 100).sorted + assert(boundedResults.map(_._1) === expectedBoundedTs) + assert(boundedResults.map(_._2) === expectedBoundedTs.map(_.toInt)) + + // None startKey scans from beginning to 0 + val noneStartIter = store.scan( + None, Some(dataToKeyRowWithRangeScan(0L, "a")), cfName) + val noneStartResults = noneStartIter.map(_.key.getLong(0)).toList + noneStartIter.close() + assert(noneStartResults === diverseTimestamps.filter(_ < 0).sorted) + + // None endKey scans from 1000 to end + val noneEndIter = store.scan( + Some(dataToKeyRowWithRangeScan(1000L, "a")), None, cfName) + val noneEndResults = noneEndIter.map(_.key.getLong(0)).toList + noneEndIter.close() + assert(noneEndResults === diverseTimestamps.filter(_ >= 1000).sorted) + + // Empty range [10, 31) - no entries between 9 and 32 + val emptyIter = store.scan( + Some(dataToKeyRowWithRangeScan(10L, "a")), + Some(dataToKeyRowWithRangeScan(31L, "a")), cfName) + assert(!emptyIter.hasNext) + emptyIter.close() + + // Bounded negative range [-300, 0) + val negIter = store.scan( + Some(dataToKeyRowWithRangeScan(-300L, "a")), + Some(dataToKeyRowWithRangeScan(0L, "a")), cfName) + val negResults = negIter.map(_.key.getLong(0)).toList + negIter.close() + assert(negResults === diverseTimestamps.filter(ts => ts >= -300 && ts < 0).sorted) } finally { if (!store.hasCommitted) store.abort() } @@ -1838,88 +1737,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } testWithColumnFamiliesAndEncodingTypes( - "rocksdb range scan - scanWithMultiValues bounded range", - TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => - - // Multiple values per key requires column families - if (colFamiliesEnabled) { - tryWithProviderResource(newStoreProvider( - StateStoreId(newDir(), Random.nextInt(), 0), - 
RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)), - keySchema = keySchemaWithRangeScan, - useColumnFamilies = colFamiliesEnabled, - useMultipleValuesPerKey = true)) { provider => - val store = provider.getStore(0) - try { - val cfName = "testColFamily" - store.createColFamilyIfAbsent(cfName, - keySchemaWithRangeScan, valueSchema, - RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)), - useMultipleValuesPerKey = true) - - val timestamps = Seq(100L, 200L, 300L, 400L, 500L) - timestamps.foreach { ts => - store.putList(dataToKeyRowWithRangeScan(ts, "a"), - Array(dataToValueRow(ts.toInt), dataToValueRow(ts.toInt + 1)), cfName) - } - - val startKey = dataToKeyRowWithRangeScan(200L, "a") - val endKey = dataToKeyRowWithRangeScan(401L, "a") - val iter = store.scanWithMultiValues(Some(startKey), Some(endKey), cfName) - val results = iter.map { pair => - (pair.key.getLong(0), pair.value.getInt(0)) - }.toList - iter.close() - - val resultTimestamps = results.map(_._1).distinct - assert(resultTimestamps === Seq(200L, 300L, 400L)) - assert(results.length === 6) // 3 timestamps x 2 values each - } finally { - if (!store.hasCommitted) store.abort() - } - } - } - } - - testWithColumnFamiliesAndEncodingTypes( - "rocksdb range scan - scanWithMultiValues with None startKey", - TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => - - if (colFamiliesEnabled) { - tryWithProviderResource(newStoreProvider( - StateStoreId(newDir(), Random.nextInt(), 0), - RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)), - keySchema = keySchemaWithRangeScan, - useColumnFamilies = colFamiliesEnabled, - useMultipleValuesPerKey = true)) { provider => - val store = provider.getStore(0) - try { - val cfName = "testColFamily" - store.createColFamilyIfAbsent(cfName, - keySchemaWithRangeScan, valueSchema, - RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)), - useMultipleValuesPerKey = true) - - val timestamps = Seq(100L, 200L, 300L, 400L, 500L) - 
timestamps.foreach { ts => - store.merge(dataToKeyRowWithRangeScan(ts, "a"), dataToValueRow(ts.toInt), cfName) - } - - val endKey = dataToKeyRowWithRangeScan(301L, "a") - val iter = store.scanWithMultiValues(None, Some(endKey), cfName) - val results = iter.map(_.key.getLong(0)).toList - iter.close() - - assert(results === Seq(100L, 200L, 300L)) - } finally { - if (!store.hasCommitted) store.abort() - } - } - } - } - - testWithColumnFamiliesAndEncodingTypes( - "rocksdb range scan - scanWithMultiValues with diverse timestamps", + "rocksdb range scan - scanWithMultiValues", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => if (colFamiliesEnabled) { @@ -1942,19 +1760,30 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid Array(dataToValueRow(idx * 10), dataToValueRow(idx * 10 + 1)), cfName) } - // Scan [0, 1000] (inclusive via endKey = 1001) - val startKey = dataToKeyRowWithRangeScan(0L, "a") - val endKey = dataToKeyRowWithRangeScan(1001L, "a") - val iter = store.scanWithMultiValues(Some(startKey), Some(endKey), cfName) - val results = iter.map { pair => + // Bounded range [0, 1001) + val boundedIter = store.scanWithMultiValues( + Some(dataToKeyRowWithRangeScan(0L, "a")), + Some(dataToKeyRowWithRangeScan(1001L, "a")), cfName) + val boundedResults = boundedIter.map { pair => (pair.key.getLong(0), pair.value.getInt(0)) }.toList - iter.close() + boundedIter.close() val expectedTimestamps = diverseTimestamps.filter(ts => ts >= 0 && ts <= 1000).sorted - val resultTimestamps = results.map(_._1).distinct - assert(resultTimestamps === expectedTimestamps) - assert(results.length === expectedTimestamps.length * 2) + assert(boundedResults.map(_._1).distinct === expectedTimestamps) + val expectedValues = diverseTimestamps.zipWithIndex + .filter { case (ts, _) => ts >= 0 && ts <= 1000 } + .sortBy(_._1) + .flatMap { case (_, idx) => Seq(idx * 10, idx * 10 + 1) } + assert(boundedResults.map(_._2) === expectedValues) + + // 
None startKey scans from beginning to 0 + val noneStartIter = store.scanWithMultiValues( + None, Some(dataToKeyRowWithRangeScan(0L, "a")), cfName) + val noneStartResults = noneStartIter.map(_.key.getLong(0)).toList + noneStartIter.close() + + assert(noneStartResults.distinct === diverseTimestamps.filter(_ < 0).sorted) } finally { if (!store.hasCommitted) store.abort() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala index 392c71eee432b..580d6e4fd12c4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala @@ -567,7 +567,7 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession } Seq("unsaferow", "avro").foreach { encoding => - test(s"scan with postfix encoder: bounded range (encoding = $encoding)") { + test(s"scan with postfix encoder (encoding = $encoding)") { tryWithProviderResource( newStoreProviderWithTimestampEncoder( encoderType = "postfix", @@ -577,97 +577,54 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession val store = provider.getStore(0) try { - // Insert entries for key1 at various timestamps diverseTimestamps.zipWithIndex.foreach { case (ts, idx) => val keyRow = keyAndTimestampToRow("key1", 1, ts) store.putList(keyRow, Array(valueToRow(idx * 10), valueToRow(idx * 10 + 1))) } - // Insert entries for key2 to verify isolation + // key2 entry to verify prefix isolation store.putList(keyAndTimestampToRow("key2", 1, 500L), Array(valueToRow(999))) - // Scan key1 in range [0, 1000] (inclusive via endKey = 1001) - val prefixKey = keyToRow("key1", 1) - val startKey = keyAndTimestampToRow("key1", 1, 0L) - val endKey = 
keyAndTimestampToRow("key1", 1, 1001L) - val iter = store.scanWithMultiValues(Some(startKey), Some(endKey)) - - val results = iter.map { pair => + // Bounded range [0, 1001) + val boundedIter = store.scanWithMultiValues( + Some(keyAndTimestampToRow("key1", 1, 0L)), + Some(keyAndTimestampToRow("key1", 1, 1001L))) + val boundedResults = boundedIter.map { pair => (pair.key.getLong(2), pair.value.getInt(0)) }.toList - iter.close() + boundedIter.close() - // Timestamps in [0, 1000]: 0, 1, 2, 5, 6, 8, 9, 32, 35, 64, 90, 931 val expectedTimestamps = diverseTimestamps.filter(ts => ts >= 0 && ts <= 1000).sorted - val resultTimestamps = results.map(_._1).distinct - assert(resultTimestamps === expectedTimestamps) - assert(results.length === expectedTimestamps.length * 2) // 2 values per timestamp - } finally { - store.abort() - } - } - } - - test(s"scan with postfix encoder: full range falls back correctly (encoding = $encoding)") { - tryWithProviderResource( - newStoreProviderWithTimestampEncoder( - encoderType = "postfix", - useMultipleValuesPerKey = true, - dataEncoding = encoding) - ) { provider => - val store = provider.getStore(0) - - try { - val timestamps = Seq(100L, 200L, 300L) - timestamps.foreach { ts => - store.putList(keyAndTimestampToRow("key1", 1, ts), - Array(valueToRow(ts.toInt))) - } - - // Scan with Some(startKey) covering full range - val startKey = keyAndTimestampToRow("key1", 1, Long.MinValue) - val endKey = keyAndTimestampToRow("key1", 1, Long.MaxValue) - val iter = store.scanWithMultiValues(Some(startKey), Some(endKey)) - val results = iter.map(_.key.getLong(2)).toList - iter.close() - - assert(results === timestamps) + assert(boundedResults.map(_._1).distinct === expectedTimestamps) + val expectedValues = diverseTimestamps.zipWithIndex + .filter { case (ts, _) => ts >= 0 && ts <= 1000 } + .sortBy(_._1) + .flatMap { case (_, idx) => Seq(idx * 10, idx * 10 + 1) } + assert(boundedResults.map(_._2) === expectedValues) + + // Full range [MinValue, 
MaxValue) + val fullIter = store.scanWithMultiValues( + Some(keyAndTimestampToRow("key1", 1, Long.MinValue)), + Some(keyAndTimestampToRow("key1", 1, Long.MaxValue))) + val fullResults = fullIter.map(_.key.getLong(2)).toList + fullIter.close() + + assert(fullResults.distinct === diverseTimestamps.sorted) + + // Empty range [10, 31) - no diverseTimestamps entries between 9 and 32 + val emptyIter = store.scanWithMultiValues( + Some(keyAndTimestampToRow("key1", 1, 10L)), + Some(keyAndTimestampToRow("key1", 1, 31L))) + assert(!emptyIter.hasNext) + emptyIter.close() } finally { store.abort() } } } - test(s"scan with postfix encoder: empty range (encoding = $encoding)") { - tryWithProviderResource( - newStoreProviderWithTimestampEncoder( - encoderType = "postfix", - useMultipleValuesPerKey = true, - dataEncoding = encoding) - ) { provider => - val store = provider.getStore(0) - - try { - store.putList(keyAndTimestampToRow("key1", 1, 1000L), - Array(valueToRow(100))) - store.putList(keyAndTimestampToRow("key1", 1, 2000L), - Array(valueToRow(200))) - - // Scan range that contains no entries - val startKey = keyAndTimestampToRow("key1", 1, 1500L) - val endKey = keyAndTimestampToRow("key1", 1, 1600L) - val iter = store.scanWithMultiValues(Some(startKey), Some(endKey)) - assert(!iter.hasNext) - iter.close() - } finally { - store.abort() - } - } - } - - test(s"scan with prefix encoder: None startKey scans from beginning " + - s"(encoding = $encoding)") { + test(s"scan with prefix encoder (encoding = $encoding)") { tryWithProviderResource( newStoreProviderWithTimestampEncoder( encoderType = "prefix", @@ -677,53 +634,30 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession val store = provider.getStore(0) try { - // Insert entries at various timestamps val timestamps = Seq(100L, 200L, 300L, 400L, 500L) timestamps.foreach { ts => store.merge(keyAndTimestampToRow("key1", 1, ts), valueToRow(ts.toInt)) } + store.merge(keyAndTimestampToRow("key2", 2, 150L), 
valueToRow(150)) - // Scan from beginning (None) up to 301 (exclusive), covering [..300] - val endKey = keyAndTimestampToRow("key1", 1, 301L) - val iter = store.scanWithMultiValues(None, Some(endKey)) - val results = iter.map(_.key.getLong(2)).toList - iter.close() - - assert(results === Seq(100L, 200L, 300L)) - } finally { - store.abort() - } - } - } - - test(s"scan with prefix encoder: boundary safety check (encoding = $encoding)") { - tryWithProviderResource( - newStoreProviderWithTimestampEncoder( - encoderType = "prefix", - useMultipleValuesPerKey = true, - dataEncoding = encoding) - ) { provider => - val store = provider.getStore(0) + // None startKey scans from beginning up to 301 (exclusive) + val iter1 = store.scanWithMultiValues(None, + Some(keyAndTimestampToRow("key1", 1, 301L))) + val results1 = iter1.map(_.key.getLong(2)).toList + iter1.close() - try { - val timestamps = Seq(100L, 200L, 300L) - timestamps.foreach { ts => - store.merge(keyAndTimestampToRow("key1", 1, ts), valueToRow(ts.toInt)) - } - store.merge(keyAndTimestampToRow("key2", 2, 150L), valueToRow(150)) + assert(results1 === Seq(100L, 150L, 200L, 300L)) - // Scan with endKey at timestamp 201 with dummyKey - should include - // everything up to timestamp 200 regardless of join key - val endKey = keyAndTimestampToRow("key1", 1, 201L) - val iter = store.scanWithMultiValues(None, Some(endKey)) - val results = iter.map { pair => + // Boundary safety: endKey at 201, includes everything up to 200 + // regardless of join key + val iter2 = store.scanWithMultiValues(None, + Some(keyAndTimestampToRow("key1", 1, 201L))) + val results2 = iter2.map { pair => (pair.key.getString(0), pair.key.getLong(2)) }.toList - iter.close() + iter2.close() - // Should include key1@100, key2@150, key1@200 - assert(results.length === 3) - assert(results.map(_._2) === Seq(100L, 150L, 200L)) + assert(results2.map(_._2) === Seq(100L, 150L, 200L)) } finally { store.abort() } @@ -740,50 +674,32 @@ class 
RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession val store = provider.getStore(0) try { - val timestamps = Seq(100L, 200L, 300L, 400L, 500L) - timestamps.foreach { ts => + diverseTimestamps.foreach { ts => store.put(keyAndTimestampToRow("key1", 1, ts), valueToRow(ts.toInt)) } - val startKey = keyAndTimestampToRow("key1", 1, 200L) - val endKey = keyAndTimestampToRow("key1", 1, 401L) - val iter = store.scan(Some(startKey), Some(endKey)) - val results = iter.map { pair => + // Bounded positive range [0, 100) + val posIter = store.scan( + Some(keyAndTimestampToRow("key1", 1, 0L)), + Some(keyAndTimestampToRow("key1", 1, 100L))) + val posResults = posIter.map { pair => (pair.key.getLong(2), pair.value.getInt(0)) }.toList - iter.close() - - assert(results.map(_._1) === Seq(200L, 300L, 400L)) - assert(results.map(_._2) === Seq(200, 300, 400)) - } finally { - store.abort() - } - } - } - - test(s"scan with diverse timestamps and bounded range (encoding = $encoding)") { - tryWithProviderResource( - newStoreProviderWithTimestampEncoder( - encoderType = "postfix", - useMultipleValuesPerKey = false, - dataEncoding = encoding) - ) { provider => - val store = provider.getStore(0) + posIter.close() - try { - diverseTimestamps.zipWithIndex.foreach { case (ts, idx) => - store.put(keyAndTimestampToRow("key1", 1, ts), valueToRow(idx)) - } + val expectedPosTs = diverseTimestamps.filter(ts => ts >= 0 && ts < 100).sorted + assert(posResults.map(_._1) === expectedPosTs) + assert(posResults.map(_._2) === expectedPosTs.map(_.toInt)) - // Scan negative range only: [-300, 0) - val startKey = keyAndTimestampToRow("key1", 1, -300L) - val endKey = keyAndTimestampToRow("key1", 1, 0L) - val iter = store.scan(Some(startKey), Some(endKey)) - val results = iter.map(_.key.getLong(2)).toList - iter.close() + // Bounded negative range [-300, 0) + val negIter = store.scan( + Some(keyAndTimestampToRow("key1", 1, -300L)), + Some(keyAndTimestampToRow("key1", 1, 0L))) + val negResults = 
negIter.map(_.key.getLong(2)).toList + negIter.close() - val expected = diverseTimestamps.filter(ts => ts >= -300 && ts < 0).sorted - assert(results === expected) + val expectedNegTs = diverseTimestamps.filter(ts => ts >= -300 && ts < 0).sorted + assert(negResults === expectedNegTs) } finally { store.abort() } From 0aa4813239e63d5797221437cccf8f709c9a7070 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 7 Apr 2026 06:22:03 +0900 Subject: [PATCH 06/10] Second round of self review comments --- .../state/RocksDBStateStoreSuite.scala | 57 +++++++++++- ...cksDBTimestampEncoderOperationsSuite.scala | 88 +++++++++---------- 2 files changed, 94 insertions(+), 51 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala index 7e373a263ab74..e6dd70c55d737 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala @@ -1633,7 +1633,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid -230L, -14569L, -92L, -7434253L, 35L, 6L, 9L, -323L, 5L, -32L, -64L, -256L, 64L, 32L, 1024L, 4096L, 0L) - testWithColumnFamiliesAndEncodingTypes("rocksdb range scan", + testWithColumnFamiliesAndEncodingTypes("rocksdb range scan - scan", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => tryWithProviderResource(newStoreProvider(keySchemaWithRangeScan, @@ -1664,6 +1664,18 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid assert(boundedResults.map(_._1) === expectedBoundedTs) assert(boundedResults.map(_._2) === expectedBoundedTs.map(_.toInt)) + // Exact bound: startKey is inclusive, endKey is exclusive. + // 9 exists in diverseTimestamps, 90 exists in diverseTimestamps. 
+ // Scan [9, 90) should include 9 but exclude 90. + val exactIter = store.scan( + Some(dataToKeyRowWithRangeScan(9L, "a")), + Some(dataToKeyRowWithRangeScan(90L, "a")), cfName) + val exactResults = exactIter.map(_.key.getLong(0)).toList + exactIter.close() + assert(exactResults === diverseTimestamps.filter(ts => ts >= 9 && ts < 90).sorted) + assert(exactResults.contains(9L)) + assert(!exactResults.contains(90L)) + // None startKey scans from beginning to 0 val noneStartIter = store.scan( None, Some(dataToKeyRowWithRangeScan(0L, "a")), cfName) @@ -1728,8 +1740,10 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid }.toList iter.close() - assert(results.map(_._1).distinct.sorted === Seq(100L, 200L)) - assert(results.length === 6) // 3 key2 values x 2 key1 values + val expectedResults = Seq( + (100L, "a"), (100L, "b"), (100L, "c"), + (200L, "a"), (200L, "b"), (200L, "c")) + assert(results === expectedResults) } finally { if (!store.hasCommitted) store.abort() } @@ -1777,13 +1791,48 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid .flatMap { case (_, idx) => Seq(idx * 10, idx * 10 + 1) } assert(boundedResults.map(_._2) === expectedValues) + // Exact bound: startKey is inclusive, endKey is exclusive. + // 9 exists in diverseTimestamps, 90 exists in diverseTimestamps. 
+ val exactIter = store.scanWithMultiValues( + Some(dataToKeyRowWithRangeScan(9L, "a")), + Some(dataToKeyRowWithRangeScan(90L, "a")), cfName) + val exactResults = exactIter.map(_.key.getLong(0)).toList + exactIter.close() + val exactResultsDistinct = exactResults.distinct + assert(exactResultsDistinct === diverseTimestamps + .filter(ts => ts >= 9 && ts < 90).sorted) + assert(exactResultsDistinct.contains(9L)) + assert(!exactResultsDistinct.contains(90L)) + // None startKey scans from beginning to 0 val noneStartIter = store.scanWithMultiValues( None, Some(dataToKeyRowWithRangeScan(0L, "a")), cfName) val noneStartResults = noneStartIter.map(_.key.getLong(0)).toList noneStartIter.close() - assert(noneStartResults.distinct === diverseTimestamps.filter(_ < 0).sorted) + + // None endKey scans from 1000 to end + val noneEndIter = store.scanWithMultiValues( + Some(dataToKeyRowWithRangeScan(1000L, "a")), None, cfName) + val noneEndResults = noneEndIter.map(_.key.getLong(0)).toList + noneEndIter.close() + assert(noneEndResults.distinct === diverseTimestamps.filter(_ >= 1000).sorted) + + // Empty range [10, 31) - no entries between 9 and 32 + val emptyIter = store.scanWithMultiValues( + Some(dataToKeyRowWithRangeScan(10L, "a")), + Some(dataToKeyRowWithRangeScan(31L, "a")), cfName) + assert(!emptyIter.hasNext) + emptyIter.close() + + // Bounded negative range [-300, 0) + val negIter = store.scanWithMultiValues( + Some(dataToKeyRowWithRangeScan(-300L, "a")), + Some(dataToKeyRowWithRangeScan(0L, "a")), cfName) + val negResults = negIter.map(_.key.getLong(0)).toList + negIter.close() + assert(negResults.distinct === diverseTimestamps + .filter(ts => ts >= -300 && ts < 0).sorted) } finally { if (!store.hasCommitted) store.abort() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala index 
580d6e4fd12c4..414a9a39872a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala @@ -591,17 +591,36 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession Some(keyAndTimestampToRow("key1", 1, 0L)), Some(keyAndTimestampToRow("key1", 1, 1001L))) val boundedResults = boundedIter.map { pair => - (pair.key.getLong(2), pair.value.getInt(0)) + (pair.key.getString(0), pair.key.getLong(2), pair.value.getInt(0)) }.toList boundedIter.close() val expectedTimestamps = diverseTimestamps.filter(ts => ts >= 0 && ts <= 1000).sorted - assert(boundedResults.map(_._1).distinct === expectedTimestamps) + assert(boundedResults.map(_._2).distinct === expectedTimestamps) val expectedValues = diverseTimestamps.zipWithIndex .filter { case (ts, _) => ts >= 0 && ts <= 1000 } .sortBy(_._1) .flatMap { case (_, idx) => Seq(idx * 10, idx * 10 + 1) } - assert(boundedResults.map(_._2) === expectedValues) + assert(boundedResults.map(_._3) === expectedValues) + assert(boundedResults.forall(_._1 == "key1")) + + // Exact bound: startKey is inclusive, endKey is exclusive. + // 9 exists in diverseTimestamps, 90 exists in diverseTimestamps. + val exactIter = store.scanWithMultiValues( + Some(keyAndTimestampToRow("key1", 1, 9L)), + Some(keyAndTimestampToRow("key1", 1, 90L))) + val exactResults = exactIter.map(_.key.getLong(2)).toList + exactIter.close() + val exactResultsDistinct = exactResults.distinct + assert(exactResultsDistinct === diverseTimestamps + .filter(ts => ts >= 9 && ts < 90).sorted) + assert(exactResultsDistinct.contains(9L)) + assert(!exactResultsDistinct.contains(90L)) + + // Postfix timestamp encoder places the timestamp after the key prefix. 
+ // With different key prefixes, None in startKey or endKey would scan across + // key boundaries, which is not meaningful for postfix encoding. Hence we only + // test bounded ranges with explicit keys here. // Full range [MinValue, MaxValue) val fullIter = store.scanWithMultiValues( @@ -612,6 +631,15 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession assert(fullResults.distinct === diverseTimestamps.sorted) + // Bounded negative range [-300, 0) + val negIter = store.scanWithMultiValues( + Some(keyAndTimestampToRow("key1", 1, -300L)), + Some(keyAndTimestampToRow("key1", 1, 0L))) + val negResults = negIter.map(_.key.getLong(2)).toList + negIter.close() + assert(negResults.distinct === diverseTimestamps + .filter(ts => ts >= -300 && ts < 0).sorted) + // Empty range [10, 31) - no diverseTimestamps entries between 9 and 32 val emptyIter = store.scanWithMultiValues( Some(keyAndTimestampToRow("key1", 1, 10L)), @@ -624,6 +652,9 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession } } + // Sanity test for prefix encoder scan. Full scan coverage is in RocksDBStateStoreSuite's + // "rocksdb range scan - scan" and "rocksdb range scan - scanWithMultiValues" tests. + // This test verifies the timestamp prefix encoder integration works correctly. 
test(s"scan with prefix encoder (encoding = $encoding)") { tryWithProviderResource( newStoreProviderWithTimestampEncoder( @@ -643,10 +674,13 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession // None startKey scans from beginning up to 301 (exclusive) val iter1 = store.scanWithMultiValues(None, Some(keyAndTimestampToRow("key1", 1, 301L))) - val results1 = iter1.map(_.key.getLong(2)).toList + val results1 = iter1.map { pair => + (pair.key.getString(0), pair.key.getLong(2)) + }.toList iter1.close() - assert(results1 === Seq(100L, 150L, 200L, 300L)) + assert(results1 === Seq( + ("key1", 100L), ("key2", 150L), ("key1", 200L), ("key1", 300L))) // Boundary safety: endKey at 201, includes everything up to 200 // regardless of join key @@ -657,54 +691,14 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession }.toList iter2.close() - assert(results2.map(_._2) === Seq(100L, 150L, 200L)) + assert(results2 === Seq( + ("key1", 100L), ("key2", 150L), ("key1", 200L))) } finally { store.abort() } } } - test(s"scan single-value variant (encoding = $encoding)") { - tryWithProviderResource( - newStoreProviderWithTimestampEncoder( - encoderType = "postfix", - useMultipleValuesPerKey = false, - dataEncoding = encoding) - ) { provider => - val store = provider.getStore(0) - - try { - diverseTimestamps.foreach { ts => - store.put(keyAndTimestampToRow("key1", 1, ts), valueToRow(ts.toInt)) - } - - // Bounded positive range [0, 100) - val posIter = store.scan( - Some(keyAndTimestampToRow("key1", 1, 0L)), - Some(keyAndTimestampToRow("key1", 1, 100L))) - val posResults = posIter.map { pair => - (pair.key.getLong(2), pair.value.getInt(0)) - }.toList - posIter.close() - - val expectedPosTs = diverseTimestamps.filter(ts => ts >= 0 && ts < 100).sorted - assert(posResults.map(_._1) === expectedPosTs) - assert(posResults.map(_._2) === expectedPosTs.map(_.toInt)) - - // Bounded negative range [-300, 0) - val negIter = store.scan( - 
Some(keyAndTimestampToRow("key1", 1, -300L)), - Some(keyAndTimestampToRow("key1", 1, 0L))) - val negResults = negIter.map(_.key.getLong(2)).toList - negIter.close() - - val expectedNegTs = diverseTimestamps.filter(ts => ts >= -300 && ts < 0).sorted - assert(negResults === expectedNegTs) - } finally { - store.abort() - } - } - } } // Helper methods to create test data From 710466b9d4c8169c8cffb6338430656e82e5696c Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 7 Apr 2026 15:15:07 +0900 Subject: [PATCH 07/10] Change the API name to rangeScan --- .../streaming/state/RocksDBStateEncoder.scala | 11 +++++++ .../state/RocksDBStateStoreProvider.scala | 13 +++++--- .../streaming/state/StateStore.scala | 16 +++++----- ...sDBStateStoreCheckpointFormatV2Suite.scala | 8 ++--- .../state/RocksDBStateStoreSuite.scala | 30 +++++++++---------- ...cksDBTimestampEncoderOperationsSuite.scala | 20 ++++++------- 6 files changed, 57 insertions(+), 41 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala index 5acf7cdc9b975..4d9c77348b493 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala @@ -46,6 +46,7 @@ import org.apache.spark.unsafe.Platform sealed trait RocksDBKeyStateEncoder { def supportPrefixKeyScan: Boolean def supportsDeleteRange: Boolean + def supportsRangeScan: Boolean def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte] def encodeKey(row: UnsafeRow): Array[Byte] def decodeKey(keyBytes: Array[Byte]): UnsafeRow @@ -1500,6 +1501,8 @@ class PrefixKeyScanStateEncoder( override def supportPrefixKeyScan: Boolean = true override def supportsDeleteRange: Boolean = false + + override def supportsRangeScan: Boolean = false } /** @@ -1699,6 +1702,8 @@ 
class RangeKeyScanStateEncoder( override def supportPrefixKeyScan: Boolean = true override def supportsDeleteRange: Boolean = true + + override def supportsRangeScan: Boolean = true } /** @@ -1731,6 +1736,8 @@ class NoPrefixKeyStateEncoder( override def supportsDeleteRange: Boolean = false + override def supportsRangeScan: Boolean = false + override def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte] = { throw new IllegalStateException("This encoder doesn't support prefix key!") } @@ -1884,6 +1891,8 @@ class TimestampAsPrefixKeyStateEncoder( // TODO: [SPARK-55491] Revisit this to support delete range if needed. override def supportsDeleteRange: Boolean = false + + override def supportsRangeScan: Boolean = true } /** @@ -1932,6 +1941,8 @@ class TimestampAsPostfixKeyStateEncoder( } override def supportsDeleteRange: Boolean = false + + override def supportsRangeScan: Boolean = true } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 7629ffe7d4325..e1490c71bc69b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -549,14 +549,17 @@ private[sql] class RocksDBStateStoreProvider new StateStoreIterator(iter, rocksDbIter.closeIfNeeded) } - override def scan( + override def rangeScan( startKey: Option[UnsafeRow], endKey: Option[UnsafeRow], colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { validateAndTransitionState(UPDATE) - verifyColFamilyOperations("scan", colFamilyName) + verifyColFamilyOperations("rangeScan", colFamilyName) val kvEncoder = keyValueEncoderMap.get(colFamilyName) + require(kvEncoder._1.supportsRangeScan, + "Range scan requires an encoder that supports range scanning!") + val encodedStartKey = 
startKey.map(kvEncoder._1.encodeKey) val encodedEndKey = endKey.map(kvEncoder._1.encodeKey) @@ -571,14 +574,16 @@ private[sql] class RocksDBStateStoreProvider new StateStoreIterator(iter, rocksDbIter.closeIfNeeded) } - override def scanWithMultiValues( + override def rangeScanWithMultiValues( startKey: Option[UnsafeRow], endKey: Option[UnsafeRow], colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { validateAndTransitionState(UPDATE) - verifyColFamilyOperations("scanWithMultiValues", colFamilyName) + verifyColFamilyOperations("rangeScanWithMultiValues", colFamilyName) val kvEncoder = keyValueEncoderMap.get(colFamilyName) + require(kvEncoder._1.supportsRangeScan, + "Range scan requires an encoder that supports range scanning!") verify( kvEncoder._2.supportsMultipleValuesPerKey, "Multi-value iterator operation requires an encoder" + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 9c4e2f06587a2..e3601f1ef2246 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -196,12 +196,12 @@ trait ReadStateStore { * bytes for the scan range to be meaningful (e.g., timestamp-based encoders or * RangeKeyScanStateEncoder). */ - def scan( + def rangeScan( startKey: Option[UnsafeRow], endKey: Option[UnsafeRow], colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME) : StateStoreIterator[UnsafeRowPair] = { - throw StateStoreErrors.unsupportedOperationException("scan", "") + throw StateStoreErrors.unsupportedOperationException("rangeScan", "") } /** @@ -220,12 +220,12 @@ trait ReadStateStore { * It is expected to throw exception if Spark calls this method without setting * multipleValuesPerKey as true for the column family. 
*/ - def scanWithMultiValues( + def rangeScanWithMultiValues( startKey: Option[UnsafeRow], endKey: Option[UnsafeRow], colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME) : StateStoreIterator[UnsafeRowPair] = { - throw StateStoreErrors.unsupportedOperationException("scanWithMultiValues", "") + throw StateStoreErrors.unsupportedOperationException("rangeScanWithMultiValues", "") } /** Return an iterator containing all the key-value pairs in the StateStore. */ @@ -456,18 +456,18 @@ class WrappedReadStateStore(store: StateStore) extends ReadStateStore { store.prefixScanWithMultiValues(prefixKey, colFamilyName) } - override def scan( + override def rangeScan( startKey: Option[UnsafeRow], endKey: Option[UnsafeRow], colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { - store.scan(startKey, endKey, colFamilyName) + store.rangeScan(startKey, endKey, colFamilyName) } - override def scanWithMultiValues( + override def rangeScanWithMultiValues( startKey: Option[UnsafeRow], endKey: Option[UnsafeRow], colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { - store.scanWithMultiValues(startKey, endKey, colFamilyName) + store.rangeScanWithMultiValues(startKey, endKey, colFamilyName) } override def iteratorWithMultiValues( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala index dca9677a6f7bc..fe09506023ccd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala @@ -172,18 +172,18 @@ case class CkptIdCollectingStateStoreWrapper(innerStore: StateStore) extends Sta innerStore.prefixScanWithMultiValues(prefixKey, colFamilyName) } - override def scan( + override def 
rangeScan( startKey: Option[UnsafeRow], endKey: Option[UnsafeRow], colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { - innerStore.scan(startKey, endKey, colFamilyName) + innerStore.rangeScan(startKey, endKey, colFamilyName) } - override def scanWithMultiValues( + override def rangeScanWithMultiValues( startKey: Option[UnsafeRow], endKey: Option[UnsafeRow], colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { - innerStore.scanWithMultiValues(startKey, endKey, colFamilyName) + innerStore.rangeScanWithMultiValues(startKey, endKey, colFamilyName) } override def iteratorWithMultiValues( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala index e6dd70c55d737..0c300192dd898 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala @@ -1633,7 +1633,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid -230L, -14569L, -92L, -7434253L, 35L, 6L, 9L, -323L, 5L, -32L, -64L, -256L, 64L, 32L, 1024L, 4096L, 0L) - testWithColumnFamiliesAndEncodingTypes("rocksdb range scan - scan", + testWithColumnFamiliesAndEncodingTypes("rocksdb range scan - rangeScan", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => tryWithProviderResource(newStoreProvider(keySchemaWithRangeScan, @@ -1653,7 +1653,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } // Bounded positive range [0, 100) - val boundedIter = store.scan( + val boundedIter = store.rangeScan( Some(dataToKeyRowWithRangeScan(0L, "a")), Some(dataToKeyRowWithRangeScan(100L, "a")), cfName) val boundedResults = boundedIter.map { pair => @@ -1667,7 +1667,7 @@ class RocksDBStateStoreSuite extends 
StateStoreSuiteBase[RocksDBStateStoreProvid // Exact bound: startKey is inclusive, endKey is exclusive. // 9 exists in diverseTimestamps, 90 exists in diverseTimestamps. // Scan [9, 90) should include 9 but exclude 90. - val exactIter = store.scan( + val exactIter = store.rangeScan( Some(dataToKeyRowWithRangeScan(9L, "a")), Some(dataToKeyRowWithRangeScan(90L, "a")), cfName) val exactResults = exactIter.map(_.key.getLong(0)).toList @@ -1677,28 +1677,28 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid assert(!exactResults.contains(90L)) // None startKey scans from beginning to 0 - val noneStartIter = store.scan( + val noneStartIter = store.rangeScan( None, Some(dataToKeyRowWithRangeScan(0L, "a")), cfName) val noneStartResults = noneStartIter.map(_.key.getLong(0)).toList noneStartIter.close() assert(noneStartResults === diverseTimestamps.filter(_ < 0).sorted) // None endKey scans from 1000 to end - val noneEndIter = store.scan( + val noneEndIter = store.rangeScan( Some(dataToKeyRowWithRangeScan(1000L, "a")), None, cfName) val noneEndResults = noneEndIter.map(_.key.getLong(0)).toList noneEndIter.close() assert(noneEndResults === diverseTimestamps.filter(_ >= 1000).sorted) // Empty range [10, 31) - no entries between 9 and 32 - val emptyIter = store.scan( + val emptyIter = store.rangeScan( Some(dataToKeyRowWithRangeScan(10L, "a")), Some(dataToKeyRowWithRangeScan(31L, "a")), cfName) assert(!emptyIter.hasNext) emptyIter.close() // Bounded negative range [-300, 0) - val negIter = store.scan( + val negIter = store.rangeScan( Some(dataToKeyRowWithRangeScan(-300L, "a")), Some(dataToKeyRowWithRangeScan(0L, "a")), cfName) val negResults = negIter.map(_.key.getLong(0)).toList @@ -1734,7 +1734,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid val startKey = dataToKeyRowWithRangeScan(100L, "a") val endKey = dataToKeyRowWithRangeScan(201L, "a") - val iter = store.scan(Some(startKey), Some(endKey), cfName) + val 
iter = store.rangeScan(Some(startKey), Some(endKey), cfName) val results = iter.map { pair => (pair.key.getLong(0), pair.key.getUTF8String(1).toString) }.toList @@ -1751,7 +1751,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } testWithColumnFamiliesAndEncodingTypes( - "rocksdb range scan - scanWithMultiValues", + "rocksdb range scan - rangeScanWithMultiValues", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => if (colFamiliesEnabled) { @@ -1775,7 +1775,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } // Bounded range [0, 1001) - val boundedIter = store.scanWithMultiValues( + val boundedIter = store.rangeScanWithMultiValues( Some(dataToKeyRowWithRangeScan(0L, "a")), Some(dataToKeyRowWithRangeScan(1001L, "a")), cfName) val boundedResults = boundedIter.map { pair => @@ -1793,7 +1793,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid // Exact bound: startKey is inclusive, endKey is exclusive. // 9 exists in diverseTimestamps, 90 exists in diverseTimestamps. 
- val exactIter = store.scanWithMultiValues( + val exactIter = store.rangeScanWithMultiValues( Some(dataToKeyRowWithRangeScan(9L, "a")), Some(dataToKeyRowWithRangeScan(90L, "a")), cfName) val exactResults = exactIter.map(_.key.getLong(0)).toList @@ -1805,28 +1805,28 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid assert(!exactResultsDistinct.contains(90L)) // None startKey scans from beginning to 0 - val noneStartIter = store.scanWithMultiValues( + val noneStartIter = store.rangeScanWithMultiValues( None, Some(dataToKeyRowWithRangeScan(0L, "a")), cfName) val noneStartResults = noneStartIter.map(_.key.getLong(0)).toList noneStartIter.close() assert(noneStartResults.distinct === diverseTimestamps.filter(_ < 0).sorted) // None endKey scans from 1000 to end - val noneEndIter = store.scanWithMultiValues( + val noneEndIter = store.rangeScanWithMultiValues( Some(dataToKeyRowWithRangeScan(1000L, "a")), None, cfName) val noneEndResults = noneEndIter.map(_.key.getLong(0)).toList noneEndIter.close() assert(noneEndResults.distinct === diverseTimestamps.filter(_ >= 1000).sorted) // Empty range [10, 31) - no entries between 9 and 32 - val emptyIter = store.scanWithMultiValues( + val emptyIter = store.rangeScanWithMultiValues( Some(dataToKeyRowWithRangeScan(10L, "a")), Some(dataToKeyRowWithRangeScan(31L, "a")), cfName) assert(!emptyIter.hasNext) emptyIter.close() // Bounded negative range [-300, 0) - val negIter = store.scanWithMultiValues( + val negIter = store.rangeScanWithMultiValues( Some(dataToKeyRowWithRangeScan(-300L, "a")), Some(dataToKeyRowWithRangeScan(0L, "a")), cfName) val negResults = negIter.map(_.key.getLong(0)).toList diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala index 414a9a39872a0..5fcdfb12ba354 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala @@ -567,7 +567,7 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession } Seq("unsaferow", "avro").foreach { encoding => - test(s"scan with postfix encoder (encoding = $encoding)") { + test(s"rangeScan with postfix encoder (encoding = $encoding)") { tryWithProviderResource( newStoreProviderWithTimestampEncoder( encoderType = "postfix", @@ -587,7 +587,7 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession Array(valueToRow(999))) // Bounded range [0, 1001) - val boundedIter = store.scanWithMultiValues( + val boundedIter = store.rangeScanWithMultiValues( Some(keyAndTimestampToRow("key1", 1, 0L)), Some(keyAndTimestampToRow("key1", 1, 1001L))) val boundedResults = boundedIter.map { pair => @@ -606,7 +606,7 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession // Exact bound: startKey is inclusive, endKey is exclusive. // 9 exists in diverseTimestamps, 90 exists in diverseTimestamps. - val exactIter = store.scanWithMultiValues( + val exactIter = store.rangeScanWithMultiValues( Some(keyAndTimestampToRow("key1", 1, 9L)), Some(keyAndTimestampToRow("key1", 1, 90L))) val exactResults = exactIter.map(_.key.getLong(2)).toList @@ -623,7 +623,7 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession // test bounded ranges with explicit keys here. 
// Full range [MinValue, MaxValue) - val fullIter = store.scanWithMultiValues( + val fullIter = store.rangeScanWithMultiValues( Some(keyAndTimestampToRow("key1", 1, Long.MinValue)), Some(keyAndTimestampToRow("key1", 1, Long.MaxValue))) val fullResults = fullIter.map(_.key.getLong(2)).toList @@ -632,7 +632,7 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession assert(fullResults.distinct === diverseTimestamps.sorted) // Bounded negative range [-300, 0) - val negIter = store.scanWithMultiValues( + val negIter = store.rangeScanWithMultiValues( Some(keyAndTimestampToRow("key1", 1, -300L)), Some(keyAndTimestampToRow("key1", 1, 0L))) val negResults = negIter.map(_.key.getLong(2)).toList @@ -641,7 +641,7 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession .filter(ts => ts >= -300 && ts < 0).sorted) // Empty range [10, 31) - no diverseTimestamps entries between 9 and 32 - val emptyIter = store.scanWithMultiValues( + val emptyIter = store.rangeScanWithMultiValues( Some(keyAndTimestampToRow("key1", 1, 10L)), Some(keyAndTimestampToRow("key1", 1, 31L))) assert(!emptyIter.hasNext) @@ -653,9 +653,9 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession } // Sanity test for prefix encoder scan. Full scan coverage is in RocksDBStateStoreSuite's - // "rocksdb range scan - scan" and "rocksdb range scan - scanWithMultiValues" tests. + // "rocksdb range scan - rangeScan" and "rocksdb range scan - rangeScanWithMultiValues" tests. // This test verifies the timestamp prefix encoder integration works correctly. 
- test(s"scan with prefix encoder (encoding = $encoding)") { + test(s"rangeScan with prefix encoder (encoding = $encoding)") { tryWithProviderResource( newStoreProviderWithTimestampEncoder( encoderType = "prefix", @@ -672,7 +672,7 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession store.merge(keyAndTimestampToRow("key2", 2, 150L), valueToRow(150)) // None startKey scans from beginning up to 301 (exclusive) - val iter1 = store.scanWithMultiValues(None, + val iter1 = store.rangeScanWithMultiValues(None, Some(keyAndTimestampToRow("key1", 1, 301L))) val results1 = iter1.map { pair => (pair.key.getString(0), pair.key.getLong(2)) @@ -684,7 +684,7 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession // Boundary safety: endKey at 201, includes everything up to 200 // regardless of join key - val iter2 = store.scanWithMultiValues(None, + val iter2 = store.rangeScanWithMultiValues(None, Some(keyAndTimestampToRow("key1", 1, 201L))) val results2 = iter2.map { pair => (pair.key.getString(0), pair.key.getLong(2)) From d750b045418b5d87a0a2a6f166ac460b2e9039b4 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 7 Apr 2026 11:01:23 +0900 Subject: [PATCH 08/10] Apply scan to transformWithState operators (TTL and timers) Use bounded scan ranges in transformWithState TTL eviction and timer expiry to narrow the iteration scope: - TTLState.ttlEvictionIterator: use store.scan with startKey from prevBatchTimestampMs+1 and endKey from batchTimestampMs+1 to skip entries already cleaned up in the previous batch. - TimerStateImpl.getExpiredTimers: use store.scan with startKey from prevExpiryTimestampMs+1 and endKey from expiryTimestampMs+1. Processing-time timers use prevBatchTimestampMs; event-time timers use eventTimeWatermarkForLateEvents. Thread prevBatchTimestampMs from IncrementalExecution (via prevOffsetSeqMetadata) through TransformWithStateExec -> StatefulProcessorHandleImpl -> TTLState / TimerStateImpl. 
Copy UnsafeRow results from encodeTTLRow/UnsafeProjection to avoid the mutable-row-reuse bug where startKey and endKey alias the same internal buffer. --- .../spark/sql/execution/SparkStrategies.scala | 2 + .../TransformWithStateInPySparkExec.scala | 5 ++- .../TransformWithStateExec.scala | 16 ++++++-- .../StatefulProcessorHandleImpl.scala | 24 +++++++---- .../timers/TimerStateImpl.scala | 40 ++++++++++++++++--- .../ttl/ListStateImplWithTTL.scala | 8 +++- .../ttl/MapStateImplWithTTL.scala | 9 ++++- .../transformwithstate/ttl/TTLState.scala | 28 ++++++++++++- .../ttl/ValueStateImplWithTTL.scala | 8 +++- .../runtime/IncrementalExecution.scala | 2 + 10 files changed, 119 insertions(+), 23 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index f11dcbd1e7c1e..6c2ce58b884d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -787,6 +787,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { outputAttr, stateInfo = None, batchTimestampMs = None, + prevBatchTimestampMs = None, eventTimeWatermarkForLateEvents = None, eventTimeWatermarkForEviction = None, planLater(child), @@ -815,6 +816,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { func, t.leftAttributes, outputAttrs, outputMode, timeMode, stateInfo = None, batchTimestampMs = None, + prevBatchTimestampMs = None, eventTimeWatermarkForLateEvents = None, eventTimeWatermarkForEviction = None, userFacingDataType, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala index 1ceaf6c4bf81f..45f2af5c1dfe8 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala @@ -74,6 +74,7 @@ case class TransformWithStateInPySparkExec( timeMode: TimeMode, stateInfo: Option[StatefulOperatorStateInfo], batchTimestampMs: Option[Long], + prevBatchTimestampMs: Option[Long] = None, eventTimeWatermarkForLateEvents: Option[Long], eventTimeWatermarkForEviction: Option[Long], userFacingDataType: TransformWithStateInPySpark.UserFacingDataType.Value, @@ -314,7 +315,8 @@ case class TransformWithStateInPySparkExec( val data = groupAndProject(filteredIter, groupingAttributes, child.output, dedupAttributes) val processorHandle = new StatefulProcessorHandleImpl(store, getStateInfo.queryRunId, - groupingKeyExprEncoder, timeMode, isStreaming, batchTimestampMs, metrics) + groupingKeyExprEncoder, timeMode, isStreaming, batchTimestampMs, + prevBatchTimestampMs, metrics) val evalType = { if (userFacingDataType == TransformWithStateInPySpark.UserFacingDataType.PANDAS) { @@ -442,6 +444,7 @@ object TransformWithStateInPySparkExec { Some(System.currentTimeMillis), None, None, + None, userFacingDataType, child, isStreaming = false, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/TransformWithStateExec.scala index cc1c3263ad743..b11f6d93a642b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/TransformWithStateExec.scala @@ -67,6 +67,7 @@ case class TransformWithStateExec( outputObjAttr: Attribute, stateInfo: Option[StatefulOperatorStateInfo], 
batchTimestampMs: Option[Long], + prevBatchTimestampMs: Option[Long] = None, eventTimeWatermarkForLateEvents: Option[Long], eventTimeWatermarkForEviction: Option[Long], child: SparkPlan, @@ -251,7 +252,7 @@ case class TransformWithStateExec( case ProcessingTime => assert(batchTimestampMs.isDefined) val batchTimestamp = batchTimestampMs.get - processorHandle.getExpiredTimers(batchTimestamp) + processorHandle.getExpiredTimers(batchTimestamp, prevBatchTimestampMs) .flatMap { case (keyObj, expiryTimestampMs) => numExpiredTimers += 1 handleTimerRows(keyObj, expiryTimestampMs, processorHandle) @@ -260,7 +261,13 @@ case class TransformWithStateExec( case EventTime => assert(eventTimeWatermarkForEviction.isDefined) val watermark = eventTimeWatermarkForEviction.get - processorHandle.getExpiredTimers(watermark) + // Only use the late-events watermark as the scan lower bound when a previous batch + // actually existed (prevBatchTimestampMs is set). In the very first batch the + // watermark propagation yields Some(0) for late events even though no timers have + // been processed yet, which would incorrectly skip timers registered at timestamp 0. 
+ val prevWatermark = + if (prevBatchTimestampMs.isDefined) eventTimeWatermarkForLateEvents else None + processorHandle.getExpiredTimers(watermark, prevWatermark) .flatMap { case (keyObj, expiryTimestampMs) => numExpiredTimers += 1 handleTimerRows(keyObj, expiryTimestampMs, processorHandle) @@ -493,7 +500,7 @@ case class TransformWithStateExec( CompletionIterator[InternalRow, Iterator[InternalRow]] = { val processorHandle = new StatefulProcessorHandleImpl( store, getStateInfo.queryRunId, keyEncoder, timeMode, - isStreaming, batchTimestampMs, metrics) + isStreaming, batchTimestampMs, prevBatchTimestampMs, metrics) assert(processorHandle.getHandleState == StatefulProcessorHandleState.CREATED) statefulProcessor.setHandle(processorHandle) withStatefulProcessorErrorHandling("init") { @@ -509,7 +516,7 @@ case class TransformWithStateExec( initStateIterator: Iterator[InternalRow]): CompletionIterator[InternalRow, Iterator[InternalRow]] = { val processorHandle = new StatefulProcessorHandleImpl(store, getStateInfo.queryRunId, - keyEncoder, timeMode, isStreaming, batchTimestampMs, metrics) + keyEncoder, timeMode, isStreaming, batchTimestampMs, prevBatchTimestampMs, metrics) assert(processorHandle.getHandleState == StatefulProcessorHandleState.CREATED) statefulProcessor.setHandle(processorHandle) withStatefulProcessorErrorHandling("init") { @@ -581,6 +588,7 @@ object TransformWithStateExec { Some(System.currentTimeMillis), None, None, + None, child, isStreaming = false, hasInitialState, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/statefulprocessor/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/statefulprocessor/StatefulProcessorHandleImpl.scala index dfba0e1f12146..291cc02ea989b 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/statefulprocessor/StatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/statefulprocessor/StatefulProcessorHandleImpl.scala @@ -114,6 +114,7 @@ class StatefulProcessorHandleImpl( timeMode: TimeMode, isStreaming: Boolean = true, batchTimestampMs: Option[Long] = None, + prevBatchTimestampMs: Option[Long] = None, metrics: Map[String, SQLMetric] = Map.empty) extends StatefulProcessorHandleImplBase(timeMode, keyEncoder) with Logging { import StatefulProcessorHandleState._ @@ -171,13 +172,19 @@ class StatefulProcessorHandleImpl( /** * Function to retrieve all expired registered timers for all grouping keys - * @param expiryTimestampMs Threshold for expired timestamp in milliseconds, this function - * will return all timers that have timestamp less than passed threshold + * @param expiryTimestampMs Threshold for expired timestamp in milliseconds (inclusive), + * this function will return all timers that have timestamp + * less than or equal to the passed threshold. + * @param prevExpiryTimestampMs If provided, the lower bound (exclusive) of the scan range. + * Timers at or below this timestamp are assumed to have been + * already processed in the previous batch and will be skipped. 
* @return - iterator of registered timers for all grouping keys */ - def getExpiredTimers(expiryTimestampMs: Long): Iterator[(Any, Long)] = { + def getExpiredTimers( + expiryTimestampMs: Long, + prevExpiryTimestampMs: Option[Long] = None): Iterator[(Any, Long)] = { verifyTimerOperations("get_expired_timers") - timerState.getExpiredTimers(expiryTimestampMs) + timerState.getExpiredTimers(expiryTimestampMs, prevExpiryTimestampMs) } /** @@ -237,7 +244,8 @@ class StatefulProcessorHandleImpl( validateTTLConfig(ttlConfig, stateName) assert(batchTimestampMs.isDefined) val valueStateWithTTL = new ValueStateImplWithTTL[T](store, stateName, - keyEncoder, stateEncoder, ttlConfig, batchTimestampMs.get, metrics) + keyEncoder, stateEncoder, ttlConfig, batchTimestampMs.get, + prevBatchTimestampMs, metrics) ttlStates.add(valueStateWithTTL) TWSMetricsUtils.incrementMetric(metrics, "numValueStateWithTTLVars") valueStateWithTTL @@ -286,7 +294,8 @@ class StatefulProcessorHandleImpl( validateTTLConfig(ttlConfig, stateName) assert(batchTimestampMs.isDefined) val listStateWithTTL = new ListStateImplWithTTL[T](store, stateName, - keyEncoder, stateEncoder, ttlConfig, batchTimestampMs.get, metrics) + keyEncoder, stateEncoder, ttlConfig, batchTimestampMs.get, + prevBatchTimestampMs, metrics) TWSMetricsUtils.incrementMetric(metrics, "numListStateWithTTLVars") ttlStates.add(listStateWithTTL) listStateWithTTL @@ -324,7 +333,8 @@ class StatefulProcessorHandleImpl( validateTTLConfig(ttlConfig, stateName) assert(batchTimestampMs.isDefined) val mapStateWithTTL = new MapStateImplWithTTL[K, V](store, stateName, keyEncoder, userKeyEnc, - valEncoder, ttlConfig, batchTimestampMs.get, metrics) + valEncoder, ttlConfig, batchTimestampMs.get, + prevBatchTimestampMs, metrics) TWSMetricsUtils.incrementMetric(metrics, "numMapStateWithTTLVars") ttlStates.add(mapStateWithTTL) mapStateWithTTL diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/timers/TimerStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/timers/TimerStateImpl.scala index 101265fd8d83b..f4a1a06974aa8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/timers/TimerStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/timers/TimerStateImpl.scala @@ -112,6 +112,27 @@ class TimerStateImpl( schemaForValueRow, RangeKeyScanStateEncoderSpec(keySchemaForSecIndex, Seq(0)), useMultipleValuesPerKey = false, isInternal = true) + private val secIndexProjection = UnsafeProjection.create(keySchemaForSecIndex) + + /** + * Encodes a timestamp into an UnsafeRow key for the secondary index. + * The timestamp is incremented by 1 so that the encoded key serves as an exclusive + * lower / upper bound in range scans. Returns None if tsMs is Long.MaxValue + * (overflow guard). + * + * The returned UnsafeRow is always a fresh copy, safe to hold alongside other + * rows produced by the same projection. + */ + private def encodeTimestampAsKey(tsMs: Long): Option[UnsafeRow] = { + if (tsMs < Long.MaxValue) { + val row = new GenericInternalRow(keySchemaForSecIndex.length) + row.setLong(0, tsMs + 1) + Some(secIndexProjection.apply(row).copy()) + } else { + None + } + } + private def getGroupingKey(cfName: String): Any = { val keyOption = ImplicitGroupingKeyTracker.getImplicitKeyOption if (keyOption.isEmpty) { @@ -189,15 +210,22 @@ class TimerStateImpl( /** * Function to get all the expired registered timers for all grouping keys. 
- * Perform a range scan on timestamp and will stop iterating once the key row timestamp equals or + * Perform a range scan on timestamp and will stop iterating once the key row timestamp * exceeds the limit (as timestamp key is increasingly sorted). - * @param expiryTimestampMs Threshold for expired timestamp in milliseconds, this function - * will return all timers that have timestamp less than passed threshold. + * @param expiryTimestampMs Threshold for expired timestamp in milliseconds (inclusive), + * this function will return all timers that have timestamp + * less than or equal to the passed threshold. + * @param prevExpiryTimestampMs If provided, the lower bound (exclusive) of the scan range. + * Timers at or below this timestamp are assumed to have been + * already processed in the previous batch and will be skipped. * @return - iterator of all the registered timers for all grouping keys */ - def getExpiredTimers(expiryTimestampMs: Long): Iterator[(Any, Long)] = { - // this iter is increasingly sorted on timestamp - val iter = store.iterator(tsToKeyCFName) + def getExpiredTimers( + expiryTimestampMs: Long, + prevExpiryTimestampMs: Option[Long] = None): Iterator[(Any, Long)] = { + val startKey = prevExpiryTimestampMs.flatMap(encodeTimestampAsKey) + val endKey = encodeTimestampAsKey(expiryTimestampMs) + val iter = store.rangeScan(startKey, endKey, tsToKeyCFName) new NextIterator[(Any, Long)] { override protected def getNext(): (Any, Long) = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/ttl/ListStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/ttl/ListStateImplWithTTL.scala index 08f97e38bd086..10ec3a58500af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/ttl/ListStateImplWithTTL.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/ttl/ListStateImplWithTTL.scala @@ -35,6 +35,10 @@ import org.apache.spark.util.NextIterator * @param valEncoder - Spark SQL encoder for value * @param ttlConfig - TTL configuration for values stored in this state * @param batchTimestampMs - current batch processing timestamp. + * @param prevBatchTimestampMs - batch timestamp from the previous micro-batch (exclusive). + * Entries with expiration at or below this timestamp are assumed + * to have been already cleaned up and will be skipped during + * TTL eviction scans. * @param metrics - metrics to be updated as part of stateful processing * @tparam S - data type of object that will be stored */ @@ -45,9 +49,11 @@ class ListStateImplWithTTL[S]( valEncoder: ExpressionEncoder[Any], ttlConfig: TTLConfig, batchTimestampMs: Long, + prevBatchTimestampMs: Option[Long] = None, metrics: Map[String, SQLMetric]) extends OneToManyTTLState( - stateName, store, keyExprEnc.schema, ttlConfig, batchTimestampMs, metrics) with ListState[S] { + stateName, store, keyExprEnc.schema, ttlConfig, batchTimestampMs, + prevBatchTimestampMs, metrics) with ListState[S] { private lazy val stateTypesEncoder = StateTypesEncoder(keyExprEnc, valEncoder, stateName, hasTtl = true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/ttl/MapStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/ttl/MapStateImplWithTTL.scala index f063354bc8c8c..03aa8aaa6ace2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/ttl/MapStateImplWithTTL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/ttl/MapStateImplWithTTL.scala @@ -36,6 +36,10 @@ import org.apache.spark.util.NextIterator * @param valEncoder 
- SQL encoder for state variable * @param ttlConfig - the ttl configuration (time to live duration etc.) * @param batchTimestampMs - current batch processing timestamp. + * @param prevBatchTimestampMs - batch timestamp from the previous micro-batch (exclusive). + * Entries with expiration at or below this timestamp are assumed + * to have been already cleaned up and will be skipped during + * TTL eviction scans. * @param metrics - metrics to be updated as part of stateful processing * @tparam K - type of key for map state variable * @tparam V - type of value for map state variable @@ -49,10 +53,11 @@ class MapStateImplWithTTL[K, V]( valEncoder: ExpressionEncoder[Any], ttlConfig: TTLConfig, batchTimestampMs: Long, -metrics: Map[String, SQLMetric]) + prevBatchTimestampMs: Option[Long] = None, + metrics: Map[String, SQLMetric]) extends OneToOneTTLState( stateName, store, getCompositeKeySchema(keyExprEnc.schema, userKeyEnc.schema), ttlConfig, - batchTimestampMs, metrics) with MapState[K, V] with Logging { + batchTimestampMs, prevBatchTimestampMs, metrics) with MapState[K, V] with Logging { private val stateTypesEncoder = new CompositeKeyStateEncoder( keyExprEnc, userKeyEnc, valEncoder, stateName, hasTtl = true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/ttl/TTLState.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/ttl/TTLState.scala index 548a47ea75e13..e7f6213b185e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/ttl/TTLState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/ttl/TTLState.scala @@ -88,6 +88,11 @@ trait TTLState { // an expiration at or before this timestamp must be cleaned up. 
private[sql] def batchTimestampMs: Long + // The batch timestamp from the previous micro-batch, used to derive the startKey + // for scan-based TTL eviction. Entries at or below prevBatchTimestampMs were already + // cleaned up in the previous batch. + private[sql] def prevBatchTimestampMs: Option[Long] + // The configuration for this run of the streaming query. It may change between runs // (e.g. user sets ttlConfig1, stops their query, updates to ttlConfig2, and then // resumes their query). @@ -105,6 +110,8 @@ trait TTLState { private final val TTL_ENCODER = new TTLEncoder(elementKeySchema) + private final val ELEMENT_KEY_PROJECTION = UnsafeProjection.create(elementKeySchema) + // Empty row used for values private final val TTL_EMPTY_VALUE_ROW = UnsafeProjection.create(Array[DataType](NullType)).apply(InternalRow.apply(null)) @@ -161,10 +168,25 @@ trait TTLState { // // The schema of the UnsafeRow returned by this iterator is (expirationMs, elementKey). private[sql] def ttlEvictionIterator(): Iterator[UnsafeRow] = { - val ttlIterator = store.iterator(TTL_INDEX) + val dummyElementKey = ELEMENT_KEY_PROJECTION + .apply(new GenericInternalRow(elementKeySchema.length)) + val startKey = prevBatchTimestampMs.flatMap { prevTs => + if (prevTs < Long.MaxValue) { + Some(TTL_ENCODER.encodeTTLRow(prevTs + 1, dummyElementKey).copy()) + } else { + None + } + } + val endKey = if (batchTimestampMs < Long.MaxValue) { + Some(TTL_ENCODER.encodeTTLRow(batchTimestampMs + 1, dummyElementKey).copy()) + } else { + None + } + val ttlIterator = store.rangeScan(startKey, endKey, TTL_INDEX) // Recall that the format is (expirationMs, elementKey) -> TTL_EMPTY_VALUE_ROW, so // kv.value doesn't ever need to be used. 
+ // Safety check: takeWhile stops at the first non-expired key (TTL index keys are ordered by expirationMs) ttlIterator.takeWhile { kv => val expirationMs = kv.key.getLong(0) StateTTL.isExpired(expirationMs, batchTimestampMs) } @@ -223,12 +245,14 @@ abstract class OneToOneTTLState( elementKeySchemaArg: StructType, ttlConfigArg: TTLConfig, batchTimestampMsArg: Long, + prevBatchTimestampMsArg: Option[Long], metricsArg: Map[String, SQLMetric]) extends TTLState { override private[sql] def stateName: String = stateNameArg override private[sql] def store: StateStore = storeArg override private[sql] def elementKeySchema: StructType = elementKeySchemaArg override private[sql] def ttlConfig: TTLConfig = ttlConfigArg override private[sql] def batchTimestampMs: Long = batchTimestampMsArg + override private[sql] def prevBatchTimestampMs: Option[Long] = prevBatchTimestampMsArg override private[sql] def metrics: Map[String, SQLMetric] = metricsArg /** @@ -340,12 +364,14 @@ abstract class OneToManyTTLState( elementKeySchemaArg: StructType, ttlConfigArg: TTLConfig, batchTimestampMsArg: Long, + prevBatchTimestampMsArg: Option[Long], metricsArg: Map[String, SQLMetric]) extends TTLState { override private[sql] def stateName: String = stateNameArg override private[sql] def store: StateStore = storeArg override private[sql] def elementKeySchema: StructType = elementKeySchemaArg override private[sql] def ttlConfig: TTLConfig = ttlConfigArg override private[sql] def batchTimestampMs: Long = batchTimestampMsArg + override private[sql] def prevBatchTimestampMs: Option[Long] = prevBatchTimestampMsArg override private[sql] def metrics: Map[String, SQLMetric] = metricsArg // Schema of the min-expiry index: elementKey -> minExpirationMs diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/ttl/ValueStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/ttl/ValueStateImplWithTTL.scala index 587da75993610..1559acf7222cf
100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/ttl/ValueStateImplWithTTL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/ttl/ValueStateImplWithTTL.scala @@ -33,6 +33,10 @@ import org.apache.spark.sql.streaming.{TTLConfig, ValueState} * @param valEncoder - Spark SQL encoder for value * @param ttlConfig - TTL configuration for values stored in this state * @param batchTimestampMs - current batch processing timestamp. + * @param prevBatchTimestampMs - batch timestamp from the previous micro-batch (exclusive). + * Entries with expiration at or below this timestamp are assumed + * to have been already cleaned up and will be skipped during + * TTL eviction scans. * @param metrics - metrics to be updated as part of stateful processing * @tparam S - data type of object that will be stored */ @@ -43,9 +47,11 @@ class ValueStateImplWithTTL[S]( valEncoder: ExpressionEncoder[Any], ttlConfig: TTLConfig, batchTimestampMs: Long, + prevBatchTimestampMs: Option[Long] = None, metrics: Map[String, SQLMetric] = Map.empty) extends OneToOneTTLState( - stateName, store, keyExprEnc.schema, ttlConfig, batchTimestampMs, metrics) with ValueState[S] { + stateName, store, keyExprEnc.schema, ttlConfig, batchTimestampMs, + prevBatchTimestampMs, metrics) with ValueState[S] { private val stateTypesEncoder = StateTypesEncoder(keyExprEnc, valEncoder, stateName, hasTtl = true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala index 169ab6f606dae..0d5d89db9334f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala @@ -384,6 +384,7 @@ class 
IncrementalExecution( t.copy( stateInfo = Some(nextStatefulOperationStateInfo()), batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs), + prevBatchTimestampMs = prevOffsetSeqMetadata.map(_.batchTimestampMs), eventTimeWatermarkForLateEvents = None, eventTimeWatermarkForEviction = None, hasInitialState = hasInitialState @@ -394,6 +395,7 @@ class IncrementalExecution( t.copy( stateInfo = Some(nextStatefulOperationStateInfo()), batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs), + prevBatchTimestampMs = prevOffsetSeqMetadata.map(_.batchTimestampMs), eventTimeWatermarkForLateEvents = None, eventTimeWatermarkForEviction = None, hasInitialState = hasInitialState From d09017eb7204f7fddf51a16b4886cac44ab3dd0f Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Fri, 10 Apr 2026 12:09:16 +0900 Subject: [PATCH 09/10] Reflect review comment --- .../stateful/transformwithstate/ttl/TTLState.scala | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/ttl/TTLState.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/ttl/TTLState.scala index e7f6213b185e8..6219313b7e027 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/ttl/TTLState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/ttl/TTLState.scala @@ -162,9 +162,12 @@ trait TTLState { store.iterator(TTL_INDEX).map(kv => toTTLRow(kv.key)) } - // Returns an Iterator over all the keys in the TTL index that have expired. This method - // does not delete the keys from the TTL index; it is the responsibility of the caller - // to do so. + // Returns an Iterator over the keys in the TTL index that have expired. 
Uses a bounded + // range scan over [prevBatchTimestampMs+1, batchTimestampMs+1) to skip entries that + // were already evicted in previous batches. + // + // This method does not delete the keys from the TTL index; it is the responsibility of + // the caller to do so. // // The schema of the UnsafeRow returned by this iterator is (expirationMs, elementKey). private[sql] def ttlEvictionIterator(): Iterator[UnsafeRow] = { From 3c99f65e1301e0c5248d1738f5ab1b8a7d37a27e Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 14 Apr 2026 14:29:49 +0900 Subject: [PATCH 10/10] empty commit to retrigger CI