From 137c25b6bf38a5da7ba0bd2d34449df6a6def96c Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Thu, 2 Apr 2026 11:08:25 +0900 Subject: [PATCH 01/12] WIP introduce scan to scope the iteration in State Store --- .../join/SymmetricHashJoinStateManager.scala | 47 ++-- .../execution/streaming/state/RocksDB.scala | 80 +++++++ .../state/RocksDBStateStoreProvider.scala | 57 +++++ .../streaming/state/StateStore.scala | 57 +++++ ...sDBStateStoreCheckpointFormatV2Suite.scala | 14 ++ ...cksDBTimestampEncoderOperationsSuite.scala | 225 ++++++++++++++++++ 6 files changed, 457 insertions(+), 23 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala index fc2a69312fe79..b0b019bac5860 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala @@ -27,7 +27,7 @@ import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.internal.LogKeys.{END_INDEX, START_INDEX, STATE_STORE_ID} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, JoinedRow, Literal, SafeProjection, SpecificInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, GenericInternalRow, JoinedRow, Literal, SafeProjection, SpecificInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOperatorStateInfo, StatefulOpStateStoreCheckpointInfo, WatermarkSupport} @@ -647,17 +647,25 @@ class SymmetricHashJoinStateManagerV4( /** * Returns entries where minTs <= timestamp <= maxTs (both inclusive), grouped by timestamp. - * Skips entries before minTs and stops iterating past maxTs (timestamps are sorted). + * + * When a bounded range is provided, leverages RocksDB's native seek and upper bound via + * [[StateStore.scanWithMultiValues]] to avoid reading entries outside the range. + * Falls back to [[StateStore.prefixScanWithMultiValues]] when the full range is requested. */ def getValuesInRange( key: UnsafeRow, minTs: Long, maxTs: Long): Iterator[GetValuesResult] = { val reusableGetValuesResult = new GetValuesResult() new NextIterator[GetValuesResult] { - private val iter = stateStore.prefixScanWithMultiValues(key, colFamilyName) + private val iter = if (minTs == Long.MinValue && maxTs == Long.MaxValue) { + stateStore.prefixScanWithMultiValues(key, colFamilyName) + } else { + val startKeyRow = createKeyRow(key, minTs) + val endKeyRow = createKeyRow(key, maxTs + 1) + stateStore.scanWithMultiValues(Some(startKeyRow), endKeyRow, colFamilyName) + } private var currentTs = -1L - private var pastUpperBound = false private val valueAndMatchPairs = scala.collection.mutable.ArrayBuffer[ValueAndMatchPair]() private def flushAccumulated(): GetValuesResult = { @@ -675,18 +683,13 @@ class SymmetricHashJoinStateManagerV4( @tailrec override protected def getNext(): GetValuesResult = { - if (pastUpperBound || !iter.hasNext) { + if (!iter.hasNext) { flushAccumulated() } else { val unsafeRowPair = iter.next() val ts = TimestampKeyStateEncoder.extractTimestamp(unsafeRowPair.key) - if (ts > maxTs) { - pastUpperBound = true - getNext() - } else if (ts < minTs) { - getNext() - } else if (currentTs == -1L || currentTs == ts) { + if (currentTs == -1L || currentTs == ts) { currentTs = ts valueAndMatchPairs += valueRowConverter.convertValue(unsafeRowPair.value) getNext() @@ -770,11 +773,17 @@ class SymmetricHashJoinStateManagerV4( stateStore.remove(createKeyRow(key, timestamp), colFamilyName) } + private lazy val dummyKeyRow: UnsafeRow = { + val projection = UnsafeProjection.create(keySchema) + projection(new GenericInternalRow(keySchema.length)).copy() + } + case class EvictedKeysResult(key: UnsafeRow, timestamp: Long, numValues: Int) // NOTE: This assumes we consume the whole iterator to trigger completion. def scanEvictedKeys(endTimestamp: Long): Iterator[EvictedKeysResult] = { - val evictIterator = stateStore.iteratorWithMultiValues(colFamilyName) + val endKeyRow = createKeyRow(dummyKeyRow, endTimestamp + 1) + val evictIterator = stateStore.scanWithMultiValues(None, endKeyRow, colFamilyName) new NextIterator[EvictedKeysResult]() { var currentKeyRow: UnsafeRow = null var currentEventTime: Long = -1L @@ -789,25 +798,19 @@ class SymmetricHashJoinStateManagerV4( val ts = TimestampKeyStateEncoder.extractTimestamp(kv.key) if (keyRow == currentKeyRow && ts == currentEventTime) { - // new value with same (key, ts) count += 1 } else if (ts > endTimestamp) { - // we found the timestamp beyond the range - we shouldn't continue further + // Safety check for boundary edge case: a small number of entries at exactly + // endTimestamp + 1 may leak through the upper bound because the encoded end key + // includes a join key suffix. isBeyondUpperBound = true - - // We don't need to construct the last (key, ts) into EvictedKeysResult - the code - // after loop will handle that if there is leftover. That said, we do not reset the - // current (key, ts) info here. } else if (currentKeyRow == null && currentEventTime == -1L) { - // first value to process currentKeyRow = keyRow.copy() currentEventTime = ts count = 1 } else { - // construct the last (key, ts) into EvictedKeysResult ret = EvictedKeysResult(currentKeyRow, currentEventTime, count) - // register the next (key, ts) to process currentKeyRow = keyRow.copy() currentEventTime = ts count = 1 @@ -817,10 +820,8 @@ class SymmetricHashJoinStateManagerV4( if (ret != null) { ret } else if (count > 0) { - // there is a final leftover (key, ts) to return ret = EvictedKeysResult(currentKeyRow, currentEventTime, count) - // we shouldn't continue further currentKeyRow = null currentEventTime = -1L count = 0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index a24a76269828f..ec879a44f213e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -1662,6 +1662,86 @@ class RocksDB( } } + /** + * Scan key-value pairs in the range [startKey, endKey). + * + * @param startKey None to seek to the beginning of the column family, + * or Some(key) to seek to the given start position (inclusive). + * @param endKey The exclusive upper bound for the scan (encoded key bytes). + * @param cfName The column family name. + * @return An iterator of ByteArrayPairs in the given range. + */ + def scan( + startKey: Option[Array[Byte]], + endKey: Array[Byte], + cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME): NextIterator[ByteArrayPair] = { + updateMemoryUsageIfNeeded() + + val updatedEndKey = if (useColumnFamilies) { + encodeStateRowWithPrefix(endKey, cfName) + } else { + endKey + } + + val seekTarget = startKey match { + case Some(key) => + if (useColumnFamilies) encodeStateRowWithPrefix(key, cfName) else key + case None => + if (useColumnFamilies) encodeStateRowWithPrefix(Array.emptyByteArray, cfName) + else null + } + + val upperBoundSlice = new Slice(updatedEndKey) + val scanReadOptions = new ReadOptions() + scanReadOptions.setIterateUpperBound(upperBoundSlice) + + val iter = db.newIterator(scanReadOptions) + if (seekTarget != null) { + iter.seek(seekTarget) + } else { + iter.seekToFirst() + } + + def closeResources(): Unit = { + iter.close() + scanReadOptions.close() + upperBoundSlice.close() + } + + Option(TaskContext.get()).foreach { tc => + tc.addTaskCompletionListener[Unit] { _ => closeResources() } + } + + new NextIterator[ByteArrayPair] { + override protected def getNext(): ByteArrayPair = { + if (iter.isValid) { + val key = if (useColumnFamilies) { + decodeStateRowWithPrefix(iter.key)._1 + } else { + iter.key + } + + val value = if (conf.rowChecksumEnabled) { + KeyValueChecksumEncoder.decodeAndVerifyValueRowWithChecksum( + readVerifier, iter.key, iter.value, delimiterSize) + } else { + iter.value + } + + byteArrayPair.set(key, value) + iter.next() + byteArrayPair + } else { + finished = true + closeResources() + null + } + } + + override protected def close(): Unit = closeResources() + } + } + def release(): Unit = {} /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 8b023a0e9f9fc..7485d679d2333 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -549,6 +549,63 @@ private[sql] class RocksDBStateStoreProvider new StateStoreIterator(iter, rocksDbIter.closeIfNeeded) } + override def scan( + startKey: Option[UnsafeRow], + endKey: UnsafeRow, + colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { + validateAndTransitionState(UPDATE) + verifyColFamilyOperations("scan", colFamilyName) + + val kvEncoder = keyValueEncoderMap.get(colFamilyName) + val encodedStartKey = startKey.map(kvEncoder._1.encodeKey) + val encodedEndKey = kvEncoder._1.encodeKey(endKey) + + val rowPair = new UnsafeRowPair() + val rocksDbIter = rocksDB.scan(encodedStartKey, encodedEndKey, colFamilyName) + val iter = rocksDbIter.map { kv => + rowPair.withRows(kvEncoder._1.decodeKey(kv.key), + kvEncoder._2.decodeValue(kv.value)) + rowPair + } + + new StateStoreIterator(iter, rocksDbIter.closeIfNeeded) + } + + override def scanWithMultiValues( + startKey: Option[UnsafeRow], + endKey: UnsafeRow, + colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { + validateAndTransitionState(UPDATE) + verifyColFamilyOperations("scanWithMultiValues", colFamilyName) + + val kvEncoder = keyValueEncoderMap.get(colFamilyName) + verify( + kvEncoder._2.supportsMultipleValuesPerKey, + "Multi-value iterator operation requires an encoder" + + " which supports multiple values for a single key") + + val encodedStartKey = startKey.map(kvEncoder._1.encodeKey) + val encodedEndKey = kvEncoder._1.encodeKey(endKey) + val rocksDbIter = rocksDB.scan(encodedStartKey, encodedEndKey, colFamilyName) + + val rowPair = new UnsafeRowPair() + val iter = rocksDbIter.flatMap { kv => + val keyRow = kvEncoder._1.decodeKey(kv.key) + val valueRows = kvEncoder._2.decodeValues(kv.value) + valueRows.iterator.map { valueRow => + rowPair.withRows(keyRow, valueRow) + if (!isValidated && rowPair.value != null && !useColumnFamilies) { + StateStoreProvider.validateStateRowFormat( + rowPair.key, keySchema, rowPair.value, valueSchema, stateStoreId, storeConf) + isValidated = true + } + rowPair + } + } + + new StateStoreIterator(iter, rocksDbIter.closeIfNeeded) + } + var checkpointInfo: Option[StateStoreCheckpointInfo] = None private var storedMetrics: Option[RocksDBMetrics] = None diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 6e08c10476ce7..89b540ee1d178 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -183,6 +183,49 @@ trait ReadStateStore { prefixKey: UnsafeRow, colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): StateStoreIterator[UnsafeRowPair] + /** + * Scan key-value pairs in the range [startKey, endKey). + * + * @param startKey None to scan from the beginning of the column family, + * or Some(key) to seek to the given start position (inclusive). + * @param endKey The exclusive upper bound for the scan. + * @param colFamilyName The column family name. + * + * Callers must ensure the column family's key encoder produces lexicographically ordered + * bytes for the scan range to be meaningful (e.g., timestamp-based encoders or + * RangeKeyScanStateEncoder). + */ + def scan( + startKey: Option[UnsafeRow], + endKey: UnsafeRow, + colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME) + : StateStoreIterator[UnsafeRowPair] = { + throw StateStoreErrors.unsupportedOperationException("scan", "") + } + + /** + * Scan key-value pairs in the range [startKey, endKey), expanding multi-valued entries. + * + * @param startKey None to scan from the beginning of the column family, + * or Some(key) to seek to the given start position (inclusive). + * @param endKey The exclusive upper bound for the scan. + * @param colFamilyName The column family name. + * + * Callers must ensure the column family's key encoder produces lexicographically ordered + * bytes for the scan range to be meaningful (e.g., timestamp-based encoders or + * RangeKeyScanStateEncoder). + * + * It is expected to throw exception if Spark calls this method without setting + * multipleValuesPerKey as true for the column family. + */ + def scanWithMultiValues( + startKey: Option[UnsafeRow], + endKey: UnsafeRow, + colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME) + : StateStoreIterator[UnsafeRowPair] = { + throw StateStoreErrors.unsupportedOperationException("scanWithMultiValues", "") + } + /** Return an iterator containing all the key-value pairs in the StateStore. */ def iterator( colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): StateStoreIterator[UnsafeRowPair] @@ -411,6 +454,20 @@ class WrappedReadStateStore(store: StateStore) extends ReadStateStore { store.prefixScanWithMultiValues(prefixKey, colFamilyName) } + override def scan( + startKey: Option[UnsafeRow], + endKey: UnsafeRow, + colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { + store.scan(startKey, endKey, colFamilyName) + } + + override def scanWithMultiValues( + startKey: Option[UnsafeRow], + endKey: UnsafeRow, + colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { + store.scanWithMultiValues(startKey, endKey, colFamilyName) + } + override def iteratorWithMultiValues( colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { store.iteratorWithMultiValues(colFamilyName) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala index 14aa43d3234f7..71984a380f9b4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala @@ -172,6 +172,20 @@ case class CkptIdCollectingStateStoreWrapper(innerStore: StateStore) extends Sta innerStore.prefixScanWithMultiValues(prefixKey, colFamilyName) } + override def scan( + startKey: Option[UnsafeRow], + endKey: UnsafeRow, + colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { + innerStore.scan(startKey, endKey, colFamilyName) + } + + override def scanWithMultiValues( + startKey: Option[UnsafeRow], + endKey: UnsafeRow, + colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { + innerStore.scanWithMultiValues(startKey, endKey, colFamilyName) + } + override def iteratorWithMultiValues( colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { innerStore.iteratorWithMultiValues(colFamilyName) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala index a9540a4ad623e..0e001cfedf58c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala @@ -566,6 +566,231 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession } } + Seq("unsaferow", "avro").foreach { encoding => + test(s"scan with postfix encoder: bounded range (encoding = $encoding)") { + tryWithProviderResource( + newStoreProviderWithTimestampEncoder( + encoderType = "postfix", + useMultipleValuesPerKey = true, + dataEncoding = encoding) + ) { provider => + val store = provider.getStore(0) + + try { + // Insert entries for key1 at various timestamps + diverseTimestamps.zipWithIndex.foreach { case (ts, idx) => + val keyRow = keyAndTimestampToRow("key1", 1, ts) + store.putList(keyRow, Array(valueToRow(idx * 10), valueToRow(idx * 10 + 1))) + } + + // Insert entries for key2 to verify isolation + store.putList(keyAndTimestampToRow("key2", 1, 500L), + Array(valueToRow(999))) + + // Scan key1 in range [0, 1000] (inclusive via endKey = 1001) + val prefixKey = keyToRow("key1", 1) + val startKey = keyAndTimestampToRow("key1", 1, 0L) + val endKey = keyAndTimestampToRow("key1", 1, 1001L) + val iter = store.scanWithMultiValues(Some(startKey), endKey) + + val results = iter.map { pair => + (pair.key.getLong(2), pair.value.getInt(0)) + }.toList + iter.close() + + // Timestamps in [0, 1000]: 0, 1, 2, 5, 6, 8, 9, 32, 35, 64, 90, 931 + val expectedTimestamps = diverseTimestamps.filter(ts => ts >= 0 && ts <= 1000).sorted + val resultTimestamps = results.map(_._1).distinct + assert(resultTimestamps === expectedTimestamps) + assert(results.length === expectedTimestamps.length * 2) // 2 values per timestamp + } finally { + store.abort() + } + } + } + + test(s"scan with postfix encoder: full range falls back correctly (encoding = $encoding)") { + tryWithProviderResource( + newStoreProviderWithTimestampEncoder( + encoderType = "postfix", + useMultipleValuesPerKey = true, + dataEncoding = encoding) + ) { provider => + val store = provider.getStore(0) + + try { + val timestamps = Seq(100L, 200L, 300L) + timestamps.foreach { ts => + store.putList(keyAndTimestampToRow("key1", 1, ts), + Array(valueToRow(ts.toInt))) + } + + // Scan with Some(startKey) covering full range + val startKey = keyAndTimestampToRow("key1", 1, Long.MinValue) + val endKey = keyAndTimestampToRow("key1", 1, Long.MaxValue) + val iter = store.scanWithMultiValues(Some(startKey), endKey) + val results = iter.map(_.key.getLong(2)).toList + iter.close() + + assert(results === timestamps) + } finally { + store.abort() + } + } + } + + test(s"scan with postfix encoder: empty range (encoding = $encoding)") { + tryWithProviderResource( + newStoreProviderWithTimestampEncoder( + encoderType = "postfix", + useMultipleValuesPerKey = true, + dataEncoding = encoding) + ) { provider => + val store = provider.getStore(0) + + try { + store.putList(keyAndTimestampToRow("key1", 1, 1000L), + Array(valueToRow(100))) + store.putList(keyAndTimestampToRow("key1", 1, 2000L), + Array(valueToRow(200))) + + // Scan range that contains no entries + val startKey = keyAndTimestampToRow("key1", 1, 1500L) + val endKey = keyAndTimestampToRow("key1", 1, 1600L) + val iter = store.scanWithMultiValues(Some(startKey), endKey) + assert(!iter.hasNext) + iter.close() + } finally { + store.abort() + } + } + } + + test(s"scan with prefix encoder: None startKey scans from beginning " + + s"(encoding = $encoding)") { + tryWithProviderResource( + newStoreProviderWithTimestampEncoder( + encoderType = "prefix", + useMultipleValuesPerKey = true, + dataEncoding = encoding) + ) { provider => + val store = provider.getStore(0) + + try { + // Insert entries at various timestamps + val timestamps = Seq(100L, 200L, 300L, 400L, 500L) + timestamps.foreach { ts => + store.merge(keyAndTimestampToRow("key1", 1, ts), valueToRow(ts.toInt)) + } + + // Scan from beginning (None) up to 301 (exclusive), covering [..300] + val endKey = keyAndTimestampToRow("key1", 1, 301L) + val iter = store.scanWithMultiValues(None, endKey) + val results = iter.map(_.key.getLong(2)).toList + iter.close() + + assert(results === Seq(100L, 200L, 300L)) + } finally { + store.abort() + } + } + } + + test(s"scan with prefix encoder: boundary safety check (encoding = $encoding)") { + tryWithProviderResource( + newStoreProviderWithTimestampEncoder( + encoderType = "prefix", + useMultipleValuesPerKey = true, + dataEncoding = encoding) + ) { provider => + val store = provider.getStore(0) + + try { + val timestamps = Seq(100L, 200L, 300L) + timestamps.foreach { ts => + store.merge(keyAndTimestampToRow("key1", 1, ts), valueToRow(ts.toInt)) + } + store.merge(keyAndTimestampToRow("key2", 2, 150L), valueToRow(150)) + + // Scan with endKey at timestamp 201 with dummyKey - should include + // everything up to timestamp 200 regardless of join key + val endKey = keyAndTimestampToRow("key1", 1, 201L) + val iter = store.scanWithMultiValues(None, endKey) + val results = iter.map { pair => + (pair.key.getString(0), pair.key.getLong(2)) + }.toList + iter.close() + + // Should include key1@100, key2@150, key1@200 + assert(results.length === 3) + assert(results.map(_._2) === Seq(100L, 150L, 200L)) + } finally { + store.abort() + } + } + } + + test(s"scan single-value variant (encoding = $encoding)") { + tryWithProviderResource( + newStoreProviderWithTimestampEncoder( + encoderType = "postfix", + useMultipleValuesPerKey = false, + dataEncoding = encoding) + ) { provider => + val store = provider.getStore(0) + + try { + val timestamps = Seq(100L, 200L, 300L, 400L, 500L) + timestamps.foreach { ts => + store.put(keyAndTimestampToRow("key1", 1, ts), valueToRow(ts.toInt)) + } + + val startKey = keyAndTimestampToRow("key1", 1, 200L) + val endKey = keyAndTimestampToRow("key1", 1, 401L) + val iter = store.scan(Some(startKey), endKey) + val results = iter.map { pair => + (pair.key.getLong(2), pair.value.getInt(0)) + }.toList + iter.close() + + assert(results.map(_._1) === Seq(200L, 300L, 400L)) + assert(results.map(_._2) === Seq(200, 300, 400)) + } finally { + store.abort() + } + } + } + + test(s"scan with diverse timestamps and bounded range (encoding = $encoding)") { + tryWithProviderResource( + newStoreProviderWithTimestampEncoder( + encoderType = "postfix", + useMultipleValuesPerKey = false, + dataEncoding = encoding) + ) { provider => + val store = provider.getStore(0) + + try { + diverseTimestamps.zipWithIndex.foreach { case (ts, idx) => + store.put(keyAndTimestampToRow("key1", 1, ts), valueToRow(idx)) + } + + // Scan negative range only: [-300, 0) + val startKey = keyAndTimestampToRow("key1", 1, -300L) + val endKey = keyAndTimestampToRow("key1", 1, 0L) + val iter = store.scan(Some(startKey), endKey) + val results = iter.map(_.key.getLong(2)).toList + iter.close() + + val expected = diverseTimestamps.filter(ts => ts >= -300 && ts < 0).sorted + assert(results === expected) + } finally { + store.abort() + } + } + } + } + // Helper methods to create test data private val keyProjection = UnsafeProjection.create(keySchema) private val keyAndTimestampProjection = UnsafeProjection.create( From 07e241d082dcf2d55a28945f7c8cee1829d9e0eb Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Fri, 3 Apr 2026 16:32:04 +0900 Subject: [PATCH 02/12] fix --- .../stateful/join/SymmetricHashJoinStateManager.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala index b0b019bac5860..71f29d7d69623 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala @@ -660,7 +660,7 @@ class SymmetricHashJoinStateManagerV4( private val iter = if (minTs == Long.MinValue && maxTs == Long.MaxValue) { stateStore.prefixScanWithMultiValues(key, colFamilyName) } else { - val startKeyRow = createKeyRow(key, minTs) + val startKeyRow = createKeyRow(key, minTs).copy() val endKeyRow = createKeyRow(key, maxTs + 1) stateStore.scanWithMultiValues(Some(startKeyRow), endKeyRow, colFamilyName) } @@ -774,8 +774,9 @@ class SymmetricHashJoinStateManagerV4( } private lazy val dummyKeyRow: UnsafeRow = { + val defaultValues = keySchema.fields.map(f => Literal.default(f.dataType).eval()) val projection = UnsafeProjection.create(keySchema) - projection(new GenericInternalRow(keySchema.length)).copy() + projection(new GenericInternalRow(defaultValues)).copy() } case class EvictedKeysResult(key: UnsafeRow, timestamp: Long, numValues: Int) From ef84c49ef84850b8a4170f2cabd6a3801e98794f Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Sat, 4 Apr 2026 08:26:26 +0900 Subject: [PATCH 03/12] change the scan's endKey to option --- .../join/SymmetricHashJoinStateManager.scala | 4 +-- .../execution/streaming/state/RocksDB.scala | 25 ++++++++++++------- .../state/RocksDBStateStoreProvider.scala | 8 +++--- .../streaming/state/StateStore.scala | 14 ++++++----- ...sDBStateStoreCheckpointFormatV2Suite.scala | 4 +-- ...cksDBTimestampEncoderOperationsSuite.scala | 14 +++++------ 6 files changed, 39 insertions(+), 30 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala index 71f29d7d69623..bcb341b519caf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala @@ -662,7 +662,7 @@ class SymmetricHashJoinStateManagerV4( } else { val startKeyRow = createKeyRow(key, minTs).copy() val endKeyRow = createKeyRow(key, maxTs + 1) - stateStore.scanWithMultiValues(Some(startKeyRow), endKeyRow, colFamilyName) + stateStore.scanWithMultiValues(Some(startKeyRow), Some(endKeyRow), colFamilyName) } private var currentTs = -1L @@ -784,7 +784,7 @@ class SymmetricHashJoinStateManagerV4( // NOTE: This assumes we consume the whole iterator to trigger completion. def scanEvictedKeys(endTimestamp: Long): Iterator[EvictedKeysResult] = { val endKeyRow = createKeyRow(dummyKeyRow, endTimestamp + 1) - val evictIterator = stateStore.scanWithMultiValues(None, endKeyRow, colFamilyName) + val evictIterator = stateStore.scanWithMultiValues(None, Some(endKeyRow), colFamilyName) new NextIterator[EvictedKeysResult]() { var currentKeyRow: UnsafeRow = null var currentEventTime: Long = -1L diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index ec879a44f213e..4ad3a662b4d51 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -1667,20 +1667,27 @@ class RocksDB( * * @param startKey None to seek to the beginning of the column family, * or Some(key) to seek to the given start position (inclusive). - * @param endKey The exclusive upper bound for the scan (encoded key bytes). + * @param endKey None to scan to the end of the column family, + * or Some(key) as the exclusive upper bound for the scan (encoded key bytes). * @param cfName The column family name. * @return An iterator of ByteArrayPairs in the given range. */ def scan( startKey: Option[Array[Byte]], - endKey: Array[Byte], + endKey: Option[Array[Byte]], cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME): NextIterator[ByteArrayPair] = { updateMemoryUsageIfNeeded() - val updatedEndKey = if (useColumnFamilies) { - encodeStateRowWithPrefix(endKey, cfName) - } else { - endKey + val upperBoundBytes: Option[Array[Byte]] = endKey match { + case Some(key) => + Some(if (useColumnFamilies) encodeStateRowWithPrefix(key, cfName) else key) + case None => + if (useColumnFamilies) { + val cfPrefix = encodeStateRowWithPrefix(Array.emptyByteArray, cfName) + RocksDB.prefixUpperBound(cfPrefix) + } else { + None + } } val seekTarget = startKey match { @@ -1691,9 +1698,9 @@ class RocksDB( else null } - val upperBoundSlice = new Slice(updatedEndKey) + val upperBoundSlice = upperBoundBytes.map(new Slice(_)) val scanReadOptions = new ReadOptions() - scanReadOptions.setIterateUpperBound(upperBoundSlice) + upperBoundSlice.foreach(scanReadOptions.setIterateUpperBound) val iter = db.newIterator(scanReadOptions) if (seekTarget != null) { @@ -1705,7 +1712,7 @@ class RocksDB( def closeResources(): Unit = { iter.close() scanReadOptions.close() - upperBoundSlice.close() + upperBoundSlice.foreach(_.close()) } Option(TaskContext.get()).foreach { tc => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 7485d679d2333..7629ffe7d4325 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -551,14 +551,14 @@ private[sql] class RocksDBStateStoreProvider override def scan( startKey: Option[UnsafeRow], - endKey: UnsafeRow, + endKey: Option[UnsafeRow], colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { validateAndTransitionState(UPDATE) verifyColFamilyOperations("scan", colFamilyName) val kvEncoder = keyValueEncoderMap.get(colFamilyName) val encodedStartKey = startKey.map(kvEncoder._1.encodeKey) - val encodedEndKey = kvEncoder._1.encodeKey(endKey) + val encodedEndKey = endKey.map(kvEncoder._1.encodeKey) val rowPair = new UnsafeRowPair() val rocksDbIter = rocksDB.scan(encodedStartKey, encodedEndKey, colFamilyName) @@ -573,7 +573,7 @@ private[sql] class RocksDBStateStoreProvider override def scanWithMultiValues( startKey: Option[UnsafeRow], - endKey: UnsafeRow, + endKey: Option[UnsafeRow], colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { validateAndTransitionState(UPDATE) verifyColFamilyOperations("scanWithMultiValues", colFamilyName) @@ -585,7 +585,7 @@ private[sql] class RocksDBStateStoreProvider " which supports multiple values for a single key") val encodedStartKey = startKey.map(kvEncoder._1.encodeKey) - val encodedEndKey = kvEncoder._1.encodeKey(endKey) + val encodedEndKey = endKey.map(kvEncoder._1.encodeKey) val rocksDbIter = rocksDB.scan(encodedStartKey, encodedEndKey, colFamilyName) val rowPair = new UnsafeRowPair() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 89b540ee1d178..9c4e2f06587a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -188,7 +188,8 @@ trait ReadStateStore { * * @param startKey None to scan from the beginning of the column family, * or Some(key) to seek to the given start position (inclusive). - * @param endKey The exclusive upper bound for the scan. + * @param endKey None to scan to the end of the column family, + * or Some(key) as the exclusive upper bound for the scan. * @param colFamilyName The column family name. * * Callers must ensure the column family's key encoder produces lexicographically ordered @@ -197,7 +198,7 @@ trait ReadStateStore { */ def scan( startKey: Option[UnsafeRow], - endKey: UnsafeRow, + endKey: Option[UnsafeRow], colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME) : StateStoreIterator[UnsafeRowPair] = { throw StateStoreErrors.unsupportedOperationException("scan", "") @@ -208,7 +209,8 @@ trait ReadStateStore { * * @param startKey None to scan from the beginning of the column family, * or Some(key) to seek to the given start position (inclusive). - * @param endKey The exclusive upper bound for the scan. + * @param endKey None to scan to the end of the column family, + * or Some(key) as the exclusive upper bound for the scan. * @param colFamilyName The column family name. * * Callers must ensure the column family's key encoder produces lexicographically ordered @@ -220,7 +222,7 @@ trait ReadStateStore { */ def scanWithMultiValues( startKey: Option[UnsafeRow], - endKey: UnsafeRow, + endKey: Option[UnsafeRow], colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME) : StateStoreIterator[UnsafeRowPair] = { throw StateStoreErrors.unsupportedOperationException("scanWithMultiValues", "") @@ -456,14 +458,14 @@ class WrappedReadStateStore(store: StateStore) extends ReadStateStore { override def scan( startKey: Option[UnsafeRow], - endKey: UnsafeRow, + endKey: Option[UnsafeRow], colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { store.scan(startKey, endKey, colFamilyName) } override def scanWithMultiValues( startKey: Option[UnsafeRow], - endKey: UnsafeRow, + endKey: Option[UnsafeRow], colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { store.scanWithMultiValues(startKey, endKey, colFamilyName) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala index 71984a380f9b4..dca9677a6f7bc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala @@ -174,14 +174,14 @@ case class CkptIdCollectingStateStoreWrapper(innerStore: StateStore) extends Sta override def scan( startKey: Option[UnsafeRow], - endKey: UnsafeRow, + endKey: Option[UnsafeRow], colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { innerStore.scan(startKey, endKey, colFamilyName) } override def scanWithMultiValues( startKey: Option[UnsafeRow], - endKey: UnsafeRow, + endKey: Option[UnsafeRow], colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { innerStore.scanWithMultiValues(startKey, endKey, colFamilyName) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala index 0e001cfedf58c..392c71eee432b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala @@ -591,7 +591,7 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession val prefixKey = keyToRow("key1", 1) val startKey = keyAndTimestampToRow("key1", 1, 0L) val endKey = keyAndTimestampToRow("key1", 1, 1001L) - val iter = store.scanWithMultiValues(Some(startKey), endKey) + val iter = store.scanWithMultiValues(Some(startKey), Some(endKey)) val results = iter.map { pair => (pair.key.getLong(2), pair.value.getInt(0)) @@ -628,7 +628,7 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession // Scan with Some(startKey) covering full range val startKey = keyAndTimestampToRow("key1", 1, Long.MinValue) val endKey = keyAndTimestampToRow("key1", 1, Long.MaxValue) - val iter = store.scanWithMultiValues(Some(startKey), endKey) + val iter = store.scanWithMultiValues(Some(startKey), Some(endKey)) val results = iter.map(_.key.getLong(2)).toList iter.close() @@ -657,7 +657,7 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession // Scan range that contains no entries val startKey = keyAndTimestampToRow("key1", 1, 1500L) val endKey = keyAndTimestampToRow("key1", 1, 1600L) - val iter = store.scanWithMultiValues(Some(startKey), endKey) + val iter = store.scanWithMultiValues(Some(startKey), Some(endKey)) assert(!iter.hasNext) iter.close() } finally { @@ -685,7 +685,7 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession // Scan from beginning (None) up to 301 (exclusive), covering [..300] val endKey = keyAndTimestampToRow("key1", 1, 301L) - val iter = store.scanWithMultiValues(None, endKey) + val iter = store.scanWithMultiValues(None, Some(endKey)) val results = iter.map(_.key.getLong(2)).toList iter.close() @@ -715,7 +715,7 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession // Scan with endKey at timestamp 201 with dummyKey - should include // everything up to timestamp 200 regardless of join key val endKey = keyAndTimestampToRow("key1", 1, 201L) - val iter = store.scanWithMultiValues(None, endKey) + val iter = store.scanWithMultiValues(None, Some(endKey)) val results = iter.map { pair => (pair.key.getString(0), pair.key.getLong(2)) }.toList @@ -747,7 +747,7 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession val startKey = keyAndTimestampToRow("key1", 1, 200L) val endKey = keyAndTimestampToRow("key1", 1, 401L) - val iter = store.scan(Some(startKey), endKey) + val iter = store.scan(Some(startKey), Some(endKey)) val results = iter.map { pair => (pair.key.getLong(2), pair.value.getInt(0)) }.toList @@ -778,7 +778,7 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession // Scan negative range only: [-300, 0) val startKey = keyAndTimestampToRow("key1", 1, -300L) val endKey = keyAndTimestampToRow("key1", 1, 0L) - val iter = store.scan(Some(startKey), endKey) + val iter = store.scan(Some(startKey), Some(endKey)) val results = iter.map(_.key.getLong(2)).toList iter.close() From 158dff8d6dce01323fe5e4304454055155127de4 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Mon, 6 Apr 2026 14:43:24 +0900 Subject: [PATCH 04/12] Roll back the usage on stream-stream join, add test for range scan --- .../join/SymmetricHashJoinStateManager.scala | 48 ++- .../state/RocksDBStateStoreSuite.scala | 333 ++++++++++++++++++ 2 files changed, 356 insertions(+), 25 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala index bcb341b519caf..fc2a69312fe79 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala @@ -27,7 +27,7 @@ import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.internal.LogKeys.{END_INDEX, START_INDEX, STATE_STORE_ID} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, GenericInternalRow, JoinedRow, Literal, SafeProjection, SpecificInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, JoinedRow, Literal, SafeProjection, SpecificInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOperatorStateInfo, StatefulOpStateStoreCheckpointInfo, WatermarkSupport} @@ -647,25 +647,17 @@ class SymmetricHashJoinStateManagerV4( /** * Returns entries where minTs <= timestamp <= maxTs (both inclusive), grouped by timestamp. - * - * When a bounded range is provided, leverages RocksDB's native seek and upper bound via - * [[StateStore.scanWithMultiValues]] to avoid reading entries outside the range. - * Falls back to [[StateStore.prefixScanWithMultiValues]] when the full range is requested. + * Skips entries before minTs and stops iterating past maxTs (timestamps are sorted). */ def getValuesInRange( key: UnsafeRow, minTs: Long, maxTs: Long): Iterator[GetValuesResult] = { val reusableGetValuesResult = new GetValuesResult() new NextIterator[GetValuesResult] { - private val iter = if (minTs == Long.MinValue && maxTs == Long.MaxValue) { - stateStore.prefixScanWithMultiValues(key, colFamilyName) - } else { - val startKeyRow = createKeyRow(key, minTs).copy() - val endKeyRow = createKeyRow(key, maxTs + 1) - stateStore.scanWithMultiValues(Some(startKeyRow), Some(endKeyRow), colFamilyName) - } + private val iter = stateStore.prefixScanWithMultiValues(key, colFamilyName) private var currentTs = -1L + private var pastUpperBound = false private val valueAndMatchPairs = scala.collection.mutable.ArrayBuffer[ValueAndMatchPair]() private def flushAccumulated(): GetValuesResult = { @@ -683,13 +675,18 @@ class SymmetricHashJoinStateManagerV4( @tailrec override protected def getNext(): GetValuesResult = { - if (!iter.hasNext) { + if (pastUpperBound || !iter.hasNext) { flushAccumulated() } else { val unsafeRowPair = iter.next() val ts = TimestampKeyStateEncoder.extractTimestamp(unsafeRowPair.key) - if (currentTs == -1L || currentTs == ts) { + if (ts > maxTs) { + pastUpperBound = true + getNext() + } else if (ts < minTs) { + getNext() + } else if (currentTs == -1L || currentTs == ts) { currentTs = ts valueAndMatchPairs += valueRowConverter.convertValue(unsafeRowPair.value) getNext() @@ -773,18 +770,11 @@ class SymmetricHashJoinStateManagerV4( stateStore.remove(createKeyRow(key, timestamp), colFamilyName) } - private lazy val dummyKeyRow: UnsafeRow = { - val defaultValues = keySchema.fields.map(f => Literal.default(f.dataType).eval()) - val projection = UnsafeProjection.create(keySchema) - projection(new GenericInternalRow(defaultValues)).copy() - } - case class EvictedKeysResult(key: UnsafeRow, timestamp: Long, numValues: Int) // NOTE: This assumes we consume the whole iterator to trigger completion. def scanEvictedKeys(endTimestamp: Long): Iterator[EvictedKeysResult] = { - val endKeyRow = createKeyRow(dummyKeyRow, endTimestamp + 1) - val evictIterator = stateStore.scanWithMultiValues(None, Some(endKeyRow), colFamilyName) + val evictIterator = stateStore.iteratorWithMultiValues(colFamilyName) new NextIterator[EvictedKeysResult]() { var currentKeyRow: UnsafeRow = null var currentEventTime: Long = -1L @@ -799,19 +789,25 @@ class SymmetricHashJoinStateManagerV4( val ts = TimestampKeyStateEncoder.extractTimestamp(kv.key) if (keyRow == currentKeyRow && ts == currentEventTime) { + // new value with same (key, ts) count += 1 } else if (ts > endTimestamp) { - // Safety check for boundary edge case: a small number of entries at exactly - // endTimestamp + 1 may leak through the upper bound because the encoded end key - // includes a join key suffix. + // we found the timestamp beyond the range - we shouldn't continue further isBeyondUpperBound = true + + // We don't need to construct the last (key, ts) into EvictedKeysResult - the code + // after loop will handle that if there is leftover. That said, we do not reset the + // current (key, ts) info here. } else if (currentKeyRow == null && currentEventTime == -1L) { + // first value to process currentKeyRow = keyRow.copy() currentEventTime = ts count = 1 } else { + // construct the last (key, ts) into EvictedKeysResult ret = EvictedKeysResult(currentKeyRow, currentEventTime, count) + // register the next (key, ts) to process currentKeyRow = keyRow.copy() currentEventTime = ts count = 1 @@ -821,8 +817,10 @@ class SymmetricHashJoinStateManagerV4( if (ret != null) { ret } else if (count > 0) { + // there is a final leftover (key, ts) to return ret = EvictedKeysResult(currentKeyRow, currentEventTime, count) + // we shouldn't continue further currentKeyRow = null currentEventTime = -1L count = 0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala index 3e4b4b7320f53..386e283ccc67f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala @@ -1629,6 +1629,339 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } } + private val diverseTimestamps = Seq(931L, 8000L, 452300L, 4200L, -1L, 90L, 1L, 2L, 8L, + -230L, -14569L, -92L, -7434253L, 35L, 6L, 9L, -323L, 5L, + -32L, -64L, -256L, 64L, 32L, 1024L, 4096L, 0L) + + testWithColumnFamiliesAndEncodingTypes("rocksdb range scan - scan bounded range", + TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => + + tryWithProviderResource(newStoreProvider(keySchemaWithRangeScan, + RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)), + colFamiliesEnabled)) { provider => + val store = provider.getStore(0) + try { + val cfName = if (colFamiliesEnabled) "testColFamily" else "default" + if (colFamiliesEnabled) { + store.createColFamilyIfAbsent(cfName, + keySchemaWithRangeScan, valueSchema, + RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0))) + } + + val timestamps = Seq(100L, 200L, 300L, 400L, 500L) + timestamps.foreach { ts => + store.put(dataToKeyRowWithRangeScan(ts, "a"), dataToValueRow(ts.toInt), cfName) + } + + val startKey = dataToKeyRowWithRangeScan(200L, "a") + val endKey = dataToKeyRowWithRangeScan(401L, "a") + val iter = store.scan(Some(startKey), Some(endKey), cfName) + val results = iter.map { pair => + (pair.key.getLong(0), pair.value.getInt(0)) + }.toList + iter.close() + + assert(results.map(_._1) === Seq(200L, 300L, 400L)) + assert(results.map(_._2) === Seq(200, 300, 400)) + } finally { + if (!store.hasCommitted) store.abort() + } + } + } + + testWithColumnFamiliesAndEncodingTypes( + "rocksdb range scan - scan with None startKey scans from beginning", + TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => + + tryWithProviderResource(newStoreProvider(keySchemaWithRangeScan, + RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)), + colFamiliesEnabled)) { provider => + val store = provider.getStore(0) + try { + val cfName = if (colFamiliesEnabled) "testColFamily" else "default" + if (colFamiliesEnabled) { + store.createColFamilyIfAbsent(cfName, + keySchemaWithRangeScan, valueSchema, + RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0))) + } + + val timestamps = Seq(100L, 200L, 300L, 400L, 500L) + timestamps.foreach { ts => + store.put(dataToKeyRowWithRangeScan(ts, "a"), dataToValueRow(ts.toInt), cfName) + } + + val endKey = dataToKeyRowWithRangeScan(301L, "a") + val iter = store.scan(None, Some(endKey), cfName) + val results = iter.map(_.key.getLong(0)).toList + iter.close() + + assert(results === Seq(100L, 200L, 300L)) + } finally { + if (!store.hasCommitted) store.abort() + } + } + } + + testWithColumnFamiliesAndEncodingTypes( + "rocksdb range scan - scan with None endKey scans to end", + TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => + + tryWithProviderResource(newStoreProvider(keySchemaWithRangeScan, + RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)), + colFamiliesEnabled)) { provider => + val store = provider.getStore(0) + try { + val cfName = if (colFamiliesEnabled) "testColFamily" else "default" + if (colFamiliesEnabled) { + store.createColFamilyIfAbsent(cfName, + keySchemaWithRangeScan, valueSchema, + RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0))) + } + + val timestamps = Seq(100L, 200L, 300L, 400L, 500L) + timestamps.foreach { ts => + store.put(dataToKeyRowWithRangeScan(ts, "a"), dataToValueRow(ts.toInt), cfName) + } + + val startKey = dataToKeyRowWithRangeScan(300L, "a") + val iter = store.scan(Some(startKey), None, cfName) + val results = iter.map(_.key.getLong(0)).toList + iter.close() + + assert(results === Seq(300L, 400L, 500L)) + } finally { + if (!store.hasCommitted) store.abort() + } + } + } + + testWithColumnFamiliesAndEncodingTypes("rocksdb range scan - scan empty range", + TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => + + tryWithProviderResource(newStoreProvider(keySchemaWithRangeScan, + RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)), + colFamiliesEnabled)) { provider => + val store = provider.getStore(0) + try { + val cfName = if (colFamiliesEnabled) "testColFamily" else "default" + if (colFamiliesEnabled) { + store.createColFamilyIfAbsent(cfName, + keySchemaWithRangeScan, valueSchema, + RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0))) + } + + store.put(dataToKeyRowWithRangeScan(100L, "a"), dataToValueRow(100), cfName) + store.put(dataToKeyRowWithRangeScan(500L, "a"), dataToValueRow(500), cfName) + + val startKey = dataToKeyRowWithRangeScan(200L, "a") + val endKey = dataToKeyRowWithRangeScan(300L, "a") + val iter = store.scan(Some(startKey), Some(endKey), cfName) + assert(!iter.hasNext) + iter.close() + } finally { + if (!store.hasCommitted) store.abort() + } + } + } + + testWithColumnFamiliesAndEncodingTypes( + "rocksdb range scan - scan with diverse timestamps bounded range", + TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => + + tryWithProviderResource(newStoreProvider(keySchemaWithRangeScan, + RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)), + colFamiliesEnabled)) { provider => + val store = provider.getStore(0) + try { + val cfName = if (colFamiliesEnabled) "testColFamily" else "default" + if (colFamiliesEnabled) { + store.createColFamilyIfAbsent(cfName, + keySchemaWithRangeScan, valueSchema, + RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0))) + } + + diverseTimestamps.zipWithIndex.foreach { case (ts, idx) => + store.put(dataToKeyRowWithRangeScan(ts, "a"), dataToValueRow(idx), cfName) + } + + // Scan negative range: [-300, 0) + val startKey = dataToKeyRowWithRangeScan(-300L, "a") + val endKey = dataToKeyRowWithRangeScan(0L, "a") + val iter = store.scan(Some(startKey), Some(endKey), cfName) + val results = iter.map(_.key.getLong(0)).toList + iter.close() + + val expected = diverseTimestamps.filter(ts => ts >= -300 && ts < 0).sorted + assert(results === expected) + } finally { + if (!store.hasCommitted) store.abort() + } + } + } + + testWithColumnFamiliesAndEncodingTypes( + "rocksdb range scan - scan with multiple key2 values within same key1 range", + TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => + + tryWithProviderResource(newStoreProvider(keySchemaWithRangeScan, + RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)), + colFamiliesEnabled)) { provider => + val store = provider.getStore(0) + try { + val cfName = if (colFamiliesEnabled) "testColFamily" else "default" + if (colFamiliesEnabled) { + store.createColFamilyIfAbsent(cfName, + keySchemaWithRangeScan, valueSchema, + RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0))) + } + + Seq("a", "b", "c").foreach { key2 => + Seq(100L, 200L, 300L).foreach { ts => + store.put(dataToKeyRowWithRangeScan(ts, key2), dataToValueRow(ts.toInt), cfName) + } + } + + val startKey = dataToKeyRowWithRangeScan(100L, "a") + val endKey = dataToKeyRowWithRangeScan(201L, "a") + val iter = store.scan(Some(startKey), Some(endKey), cfName) + val results = iter.map { pair => + (pair.key.getLong(0), pair.key.getUTF8String(1).toString) + }.toList + iter.close() + + assert(results.map(_._1).distinct.sorted === Seq(100L, 200L)) + assert(results.length === 6) // 3 key2 values x 2 key1 values + } finally { + if (!store.hasCommitted) store.abort() + } + } + } + + testWithColumnFamiliesAndEncodingTypes( + "rocksdb range scan - scanWithMultiValues bounded range", + TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => + + // Multiple values per key requires column families + if (colFamiliesEnabled) { + tryWithProviderResource(newStoreProvider( + StateStoreId(newDir(), Random.nextInt(), 0), + RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)), + keySchema = keySchemaWithRangeScan, + useColumnFamilies = colFamiliesEnabled, + useMultipleValuesPerKey = true)) { provider => + val store = provider.getStore(0) + try { + val cfName = "testColFamily" + store.createColFamilyIfAbsent(cfName, + keySchemaWithRangeScan, valueSchema, + RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)), + useMultipleValuesPerKey = true) + + val timestamps = Seq(100L, 200L, 300L, 400L, 500L) + timestamps.foreach { ts => + store.putList(dataToKeyRowWithRangeScan(ts, "a"), + Array(dataToValueRow(ts.toInt), dataToValueRow(ts.toInt + 1)), cfName) + } + + val startKey = dataToKeyRowWithRangeScan(200L, "a") + val endKey = dataToKeyRowWithRangeScan(401L, "a") + val iter = store.scanWithMultiValues(Some(startKey), Some(endKey), cfName) + val results = iter.map { pair => + (pair.key.getLong(0), pair.value.getInt(0)) + }.toList + iter.close() + + val resultTimestamps = results.map(_._1).distinct + assert(resultTimestamps === Seq(200L, 300L, 400L)) + assert(results.length === 6) // 3 timestamps x 2 values each + } finally { + if (!store.hasCommitted) store.abort() + } + } + } + } + + testWithColumnFamiliesAndEncodingTypes( + "rocksdb range scan - scanWithMultiValues with None startKey", + TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => + + if (colFamiliesEnabled) { + tryWithProviderResource(newStoreProvider( + StateStoreId(newDir(), Random.nextInt(), 0), + RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)), + keySchema = keySchemaWithRangeScan, + useColumnFamilies = colFamiliesEnabled, + useMultipleValuesPerKey = true)) { provider => + val store = provider.getStore(0) + try { + val cfName = "testColFamily" + store.createColFamilyIfAbsent(cfName, + keySchemaWithRangeScan, valueSchema, + RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)), + useMultipleValuesPerKey = true) + + val timestamps = Seq(100L, 200L, 300L, 400L, 500L) + timestamps.foreach { ts => + store.merge(dataToKeyRowWithRangeScan(ts, "a"), dataToValueRow(ts.toInt), cfName) + } + + val endKey = dataToKeyRowWithRangeScan(301L, "a") + val iter = store.scanWithMultiValues(None, Some(endKey), cfName) + val results = iter.map(_.key.getLong(0)).toList + iter.close() + + assert(results === Seq(100L, 200L, 300L)) + } finally { + if (!store.hasCommitted) store.abort() + } + } + } + } + + testWithColumnFamiliesAndEncodingTypes( + "rocksdb range scan - scanWithMultiValues with diverse timestamps", + TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => + + if (colFamiliesEnabled) { + tryWithProviderResource(newStoreProvider( + StateStoreId(newDir(), Random.nextInt(), 0), + RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)), + keySchema = keySchemaWithRangeScan, + useColumnFamilies = colFamiliesEnabled, + useMultipleValuesPerKey = true)) { provider => + val store = provider.getStore(0) + try { + val cfName = "testColFamily" + store.createColFamilyIfAbsent(cfName, + keySchemaWithRangeScan, valueSchema, + RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)), + useMultipleValuesPerKey = true) + + diverseTimestamps.zipWithIndex.foreach { case (ts, idx) => + store.putList(dataToKeyRowWithRangeScan(ts, "a"), + Array(dataToValueRow(idx * 10), dataToValueRow(idx * 10 + 1)), cfName) + } + + // Scan [0, 1000] (inclusive via endKey = 1001) + val startKey = dataToKeyRowWithRangeScan(0L, "a") + val endKey = dataToKeyRowWithRangeScan(1001L, "a") + val iter = store.scanWithMultiValues(Some(startKey), Some(endKey), cfName) + val results = iter.map { pair => + (pair.key.getLong(0), pair.value.getInt(0)) + }.toList + iter.close() + + val expectedTimestamps = diverseTimestamps.filter(ts => ts >= 0 && ts <= 1000).sorted + val resultTimestamps = results.map(_._1).distinct + assert(resultTimestamps === expectedTimestamps) + assert(results.length === expectedTimestamps.length * 2) + } finally { + if (!store.hasCommitted) store.abort() + } + } + } + } + testWithColumnFamiliesAndEncodingTypes( "rocksdb key and value schema encoders for column families", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => From a494b318a1942c9f3a64316fbfdebd20bdce1fd6 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Mon, 6 Apr 2026 20:43:40 +0900 Subject: [PATCH 05/12] Reflect self-review comments --- .../state/RocksDBStateStoreSuite.scala | 293 ++++-------------- ...cksDBTimestampEncoderOperationsSuite.scala | 210 ++++--------- 2 files changed, 124 insertions(+), 379 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala index 386e283ccc67f..7e373a263ab74 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala @@ -1633,7 +1633,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid -230L, -14569L, -92L, -7434253L, 35L, 6L, 9L, -323L, 5L, -32L, -64L, -256L, 64L, 32L, 1024L, 4096L, 0L) - testWithColumnFamiliesAndEncodingTypes("rocksdb range scan - scan bounded range", + testWithColumnFamiliesAndEncodingTypes("rocksdb range scan", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => tryWithProviderResource(newStoreProvider(keySchemaWithRangeScan, @@ -1648,151 +1648,50 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0))) } - val timestamps = Seq(100L, 200L, 300L, 400L, 500L) - timestamps.foreach { ts => + diverseTimestamps.foreach { ts => store.put(dataToKeyRowWithRangeScan(ts, "a"), dataToValueRow(ts.toInt), cfName) } - val startKey = dataToKeyRowWithRangeScan(200L, "a") - val endKey = dataToKeyRowWithRangeScan(401L, "a") - val iter = store.scan(Some(startKey), Some(endKey), cfName) - val results = iter.map { pair => + // Bounded positive range [0, 100) + val boundedIter = store.scan( + Some(dataToKeyRowWithRangeScan(0L, "a")), + Some(dataToKeyRowWithRangeScan(100L, "a")), cfName) + val boundedResults = boundedIter.map { pair => (pair.key.getLong(0), pair.value.getInt(0)) }.toList - iter.close() - - assert(results.map(_._1) === Seq(200L, 300L, 400L)) - assert(results.map(_._2) === Seq(200, 300, 400)) - } finally { - if (!store.hasCommitted) store.abort() - } - } - } - - testWithColumnFamiliesAndEncodingTypes( - "rocksdb range scan - scan with None startKey scans from beginning", - TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => - - tryWithProviderResource(newStoreProvider(keySchemaWithRangeScan, - RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)), - colFamiliesEnabled)) { provider => - val store = provider.getStore(0) - try { - val cfName = if (colFamiliesEnabled) "testColFamily" else "default" - if (colFamiliesEnabled) { - store.createColFamilyIfAbsent(cfName, - keySchemaWithRangeScan, valueSchema, - RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0))) - } - - val timestamps = Seq(100L, 200L, 300L, 400L, 500L) - timestamps.foreach { ts => - store.put(dataToKeyRowWithRangeScan(ts, "a"), dataToValueRow(ts.toInt), cfName) - } - - val endKey = dataToKeyRowWithRangeScan(301L, "a") - val iter = store.scan(None, Some(endKey), cfName) - val results = iter.map(_.key.getLong(0)).toList - iter.close() - - assert(results === Seq(100L, 200L, 300L)) - } finally { - if (!store.hasCommitted) store.abort() - } - } - } - - testWithColumnFamiliesAndEncodingTypes( - "rocksdb range scan - scan with None endKey scans to end", - TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => - - tryWithProviderResource(newStoreProvider(keySchemaWithRangeScan, - RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)), - colFamiliesEnabled)) { provider => - val store = provider.getStore(0) - try { - val cfName = if (colFamiliesEnabled) "testColFamily" else "default" - if (colFamiliesEnabled) { - store.createColFamilyIfAbsent(cfName, - keySchemaWithRangeScan, valueSchema, - RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0))) - } - - val timestamps = Seq(100L, 200L, 300L, 400L, 500L) - timestamps.foreach { ts => - store.put(dataToKeyRowWithRangeScan(ts, "a"), dataToValueRow(ts.toInt), cfName) - } - - val startKey = dataToKeyRowWithRangeScan(300L, "a") - val iter = store.scan(Some(startKey), None, cfName) - val results = iter.map(_.key.getLong(0)).toList - iter.close() - - assert(results === Seq(300L, 400L, 500L)) - } finally { - if (!store.hasCommitted) store.abort() - } - } - } - - testWithColumnFamiliesAndEncodingTypes("rocksdb range scan - scan empty range", - TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => - - tryWithProviderResource(newStoreProvider(keySchemaWithRangeScan, - RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)), - colFamiliesEnabled)) { provider => - val store = provider.getStore(0) - try { - val cfName = if (colFamiliesEnabled) "testColFamily" else "default" - if (colFamiliesEnabled) { - store.createColFamilyIfAbsent(cfName, - keySchemaWithRangeScan, valueSchema, - RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0))) - } - - store.put(dataToKeyRowWithRangeScan(100L, "a"), dataToValueRow(100), cfName) - store.put(dataToKeyRowWithRangeScan(500L, "a"), dataToValueRow(500), cfName) - - val startKey = dataToKeyRowWithRangeScan(200L, "a") - val endKey = dataToKeyRowWithRangeScan(300L, "a") - val iter = store.scan(Some(startKey), Some(endKey), cfName) - assert(!iter.hasNext) - iter.close() - } finally { - if (!store.hasCommitted) store.abort() - } - } - } - - testWithColumnFamiliesAndEncodingTypes( - "rocksdb range scan - scan with diverse timestamps bounded range", - TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => - - tryWithProviderResource(newStoreProvider(keySchemaWithRangeScan, - RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)), - colFamiliesEnabled)) { provider => - val store = provider.getStore(0) - try { - val cfName = if (colFamiliesEnabled) "testColFamily" else "default" - if (colFamiliesEnabled) { - store.createColFamilyIfAbsent(cfName, - keySchemaWithRangeScan, valueSchema, - RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0))) - } - - diverseTimestamps.zipWithIndex.foreach { case (ts, idx) => - store.put(dataToKeyRowWithRangeScan(ts, "a"), dataToValueRow(idx), cfName) - } - - // Scan negative range: [-300, 0) - val startKey = dataToKeyRowWithRangeScan(-300L, "a") - val endKey = dataToKeyRowWithRangeScan(0L, "a") - val iter = store.scan(Some(startKey), Some(endKey), cfName) - val results = iter.map(_.key.getLong(0)).toList - iter.close() - - val expected = diverseTimestamps.filter(ts => ts >= -300 && ts < 0).sorted - assert(results === expected) + boundedIter.close() + val expectedBoundedTs = diverseTimestamps.filter(ts => ts >= 0 && ts < 100).sorted + assert(boundedResults.map(_._1) === expectedBoundedTs) + assert(boundedResults.map(_._2) === expectedBoundedTs.map(_.toInt)) + + // None startKey scans from beginning to 0 + val noneStartIter = store.scan( + None, Some(dataToKeyRowWithRangeScan(0L, "a")), cfName) + val noneStartResults = noneStartIter.map(_.key.getLong(0)).toList + noneStartIter.close() + assert(noneStartResults === diverseTimestamps.filter(_ < 0).sorted) + + // None endKey scans from 1000 to end + val noneEndIter = store.scan( + Some(dataToKeyRowWithRangeScan(1000L, "a")), None, cfName) + val noneEndResults = noneEndIter.map(_.key.getLong(0)).toList + noneEndIter.close() + assert(noneEndResults === diverseTimestamps.filter(_ >= 1000).sorted) + + // Empty range [10, 31) - no entries between 9 and 32 + val emptyIter = store.scan( + Some(dataToKeyRowWithRangeScan(10L, "a")), + Some(dataToKeyRowWithRangeScan(31L, "a")), cfName) + assert(!emptyIter.hasNext) + emptyIter.close() + + // Bounded negative range [-300, 0) + val negIter = store.scan( + Some(dataToKeyRowWithRangeScan(-300L, "a")), + Some(dataToKeyRowWithRangeScan(0L, "a")), cfName) + val negResults = negIter.map(_.key.getLong(0)).toList + negIter.close() + assert(negResults === diverseTimestamps.filter(ts => ts >= -300 && ts < 0).sorted) } finally { if (!store.hasCommitted) store.abort() } @@ -1838,88 +1737,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } testWithColumnFamiliesAndEncodingTypes( - "rocksdb range scan - scanWithMultiValues bounded range", - TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => - - // Multiple values per key requires column families - if (colFamiliesEnabled) { - tryWithProviderResource(newStoreProvider( - StateStoreId(newDir(), Random.nextInt(), 0), - RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)), - keySchema = keySchemaWithRangeScan, - useColumnFamilies = colFamiliesEnabled, - useMultipleValuesPerKey = true)) { provider => - val store = provider.getStore(0) - try { - val cfName = "testColFamily" - store.createColFamilyIfAbsent(cfName, - keySchemaWithRangeScan, valueSchema, - RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)), - useMultipleValuesPerKey = true) - - val timestamps = Seq(100L, 200L, 300L, 400L, 500L) - timestamps.foreach { ts => - store.putList(dataToKeyRowWithRangeScan(ts, "a"), - Array(dataToValueRow(ts.toInt), dataToValueRow(ts.toInt + 1)), cfName) - } - - val startKey = dataToKeyRowWithRangeScan(200L, "a") - val endKey = dataToKeyRowWithRangeScan(401L, "a") - val iter = store.scanWithMultiValues(Some(startKey), Some(endKey), cfName) - val results = iter.map { pair => - (pair.key.getLong(0), pair.value.getInt(0)) - }.toList - iter.close() - - val resultTimestamps = results.map(_._1).distinct - assert(resultTimestamps === Seq(200L, 300L, 400L)) - assert(results.length === 6) // 3 timestamps x 2 values each - } finally { - if (!store.hasCommitted) store.abort() - } - } - } - } - - testWithColumnFamiliesAndEncodingTypes( - "rocksdb range scan - scanWithMultiValues with None startKey", - TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => - - if (colFamiliesEnabled) { - tryWithProviderResource(newStoreProvider( - StateStoreId(newDir(), Random.nextInt(), 0), - RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)), - keySchema = keySchemaWithRangeScan, - useColumnFamilies = colFamiliesEnabled, - useMultipleValuesPerKey = true)) { provider => - val store = provider.getStore(0) - try { - val cfName = "testColFamily" - store.createColFamilyIfAbsent(cfName, - keySchemaWithRangeScan, valueSchema, - RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)), - useMultipleValuesPerKey = true) - - val timestamps = Seq(100L, 200L, 300L, 400L, 500L) - timestamps.foreach { ts => - store.merge(dataToKeyRowWithRangeScan(ts, "a"), dataToValueRow(ts.toInt), cfName) - } - - val endKey = dataToKeyRowWithRangeScan(301L, "a") - val iter = store.scanWithMultiValues(None, Some(endKey), cfName) - val results = iter.map(_.key.getLong(0)).toList - iter.close() - - assert(results === Seq(100L, 200L, 300L)) - } finally { - if (!store.hasCommitted) store.abort() - } - } - } - } - - testWithColumnFamiliesAndEncodingTypes( - "rocksdb range scan - scanWithMultiValues with diverse timestamps", + "rocksdb range scan - scanWithMultiValues", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => if (colFamiliesEnabled) { @@ -1942,19 +1760,30 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid Array(dataToValueRow(idx * 10), dataToValueRow(idx * 10 + 1)), cfName) } - // Scan [0, 1000] (inclusive via endKey = 1001) - val startKey = dataToKeyRowWithRangeScan(0L, "a") - val endKey = dataToKeyRowWithRangeScan(1001L, "a") - val iter = store.scanWithMultiValues(Some(startKey), Some(endKey), cfName) - val results = iter.map { pair => + // Bounded range [0, 1001) + val boundedIter = store.scanWithMultiValues( + Some(dataToKeyRowWithRangeScan(0L, "a")), + Some(dataToKeyRowWithRangeScan(1001L, "a")), cfName) + val boundedResults = boundedIter.map { pair => (pair.key.getLong(0), pair.value.getInt(0)) }.toList - iter.close() + boundedIter.close() val expectedTimestamps = diverseTimestamps.filter(ts => ts >= 0 && ts <= 1000).sorted - val resultTimestamps = results.map(_._1).distinct - assert(resultTimestamps === expectedTimestamps) - assert(results.length === expectedTimestamps.length * 2) + assert(boundedResults.map(_._1).distinct === expectedTimestamps) + val expectedValues = diverseTimestamps.zipWithIndex + .filter { case (ts, _) => ts >= 0 && ts <= 1000 } + .sortBy(_._1) + .flatMap { case (_, idx) => Seq(idx * 10, idx * 10 + 1) } + assert(boundedResults.map(_._2) === expectedValues) + + // None startKey scans from beginning to 0 + val noneStartIter = store.scanWithMultiValues( + None, Some(dataToKeyRowWithRangeScan(0L, "a")), cfName) + val noneStartResults = noneStartIter.map(_.key.getLong(0)).toList + noneStartIter.close() + + assert(noneStartResults.distinct === diverseTimestamps.filter(_ < 0).sorted) } finally { if (!store.hasCommitted) store.abort() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala index 392c71eee432b..580d6e4fd12c4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala @@ -567,7 +567,7 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession } Seq("unsaferow", "avro").foreach { encoding => - test(s"scan with postfix encoder: bounded range (encoding = $encoding)") { + test(s"scan with postfix encoder (encoding = $encoding)") { tryWithProviderResource( newStoreProviderWithTimestampEncoder( encoderType = "postfix", @@ -577,97 +577,54 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession val store = provider.getStore(0) try { - // Insert entries for key1 at various timestamps diverseTimestamps.zipWithIndex.foreach { case (ts, idx) => val keyRow = keyAndTimestampToRow("key1", 1, ts) store.putList(keyRow, Array(valueToRow(idx * 10), valueToRow(idx * 10 + 1))) } - // Insert entries for key2 to verify isolation + // key2 entry to verify prefix isolation store.putList(keyAndTimestampToRow("key2", 1, 500L), Array(valueToRow(999))) - // Scan key1 in range [0, 1000] (inclusive via endKey = 1001) - val prefixKey = keyToRow("key1", 1) - val startKey = keyAndTimestampToRow("key1", 1, 0L) - val endKey = keyAndTimestampToRow("key1", 1, 1001L) - val iter = store.scanWithMultiValues(Some(startKey), Some(endKey)) - - val results = iter.map { pair => + // Bounded range [0, 1001) + val boundedIter = store.scanWithMultiValues( + Some(keyAndTimestampToRow("key1", 1, 0L)), + Some(keyAndTimestampToRow("key1", 1, 1001L))) + val boundedResults = boundedIter.map { pair => (pair.key.getLong(2), pair.value.getInt(0)) }.toList - iter.close() + boundedIter.close() - // Timestamps in [0, 1000]: 0, 1, 2, 5, 6, 8, 9, 32, 35, 64, 90, 931 val expectedTimestamps = diverseTimestamps.filter(ts => ts >= 0 && ts <= 1000).sorted - val resultTimestamps = results.map(_._1).distinct - assert(resultTimestamps === expectedTimestamps) - assert(results.length === expectedTimestamps.length * 2) // 2 values per timestamp - } finally { - store.abort() - } - } - } - - test(s"scan with postfix encoder: full range falls back correctly (encoding = $encoding)") { - tryWithProviderResource( - newStoreProviderWithTimestampEncoder( - encoderType = "postfix", - useMultipleValuesPerKey = true, - dataEncoding = encoding) - ) { provider => - val store = provider.getStore(0) - - try { - val timestamps = Seq(100L, 200L, 300L) - timestamps.foreach { ts => - store.putList(keyAndTimestampToRow("key1", 1, ts), - Array(valueToRow(ts.toInt))) - } - - // Scan with Some(startKey) covering full range - val startKey = keyAndTimestampToRow("key1", 1, Long.MinValue) - val endKey = keyAndTimestampToRow("key1", 1, Long.MaxValue) - val iter = store.scanWithMultiValues(Some(startKey), Some(endKey)) - val results = iter.map(_.key.getLong(2)).toList - iter.close() - - assert(results === timestamps) + assert(boundedResults.map(_._1).distinct === expectedTimestamps) + val expectedValues = diverseTimestamps.zipWithIndex + .filter { case (ts, _) => ts >= 0 && ts <= 1000 } + .sortBy(_._1) + .flatMap { case (_, idx) => Seq(idx * 10, idx * 10 + 1) } + assert(boundedResults.map(_._2) === expectedValues) + + // Full range [MinValue, MaxValue) + val fullIter = store.scanWithMultiValues( + Some(keyAndTimestampToRow("key1", 1, Long.MinValue)), + Some(keyAndTimestampToRow("key1", 1, Long.MaxValue))) + val fullResults = fullIter.map(_.key.getLong(2)).toList + fullIter.close() + + assert(fullResults.distinct === diverseTimestamps.sorted) + + // Empty range [10, 31) - no diverseTimestamps entries between 9 and 32 + val emptyIter = store.scanWithMultiValues( + Some(keyAndTimestampToRow("key1", 1, 10L)), + Some(keyAndTimestampToRow("key1", 1, 31L))) + assert(!emptyIter.hasNext) + emptyIter.close() } finally { store.abort() } } } - test(s"scan with postfix encoder: empty range (encoding = $encoding)") { - tryWithProviderResource( - newStoreProviderWithTimestampEncoder( - encoderType = "postfix", - useMultipleValuesPerKey = true, - dataEncoding = encoding) - ) { provider => - val store = provider.getStore(0) - - try { - store.putList(keyAndTimestampToRow("key1", 1, 1000L), - Array(valueToRow(100))) - store.putList(keyAndTimestampToRow("key1", 1, 2000L), - Array(valueToRow(200))) - - // Scan range that contains no entries - val startKey = keyAndTimestampToRow("key1", 1, 1500L) - val endKey = keyAndTimestampToRow("key1", 1, 1600L) - val iter = store.scanWithMultiValues(Some(startKey), Some(endKey)) - assert(!iter.hasNext) - iter.close() - } finally { - store.abort() - } - } - } - - test(s"scan with prefix encoder: None startKey scans from beginning " + - s"(encoding = $encoding)") { + test(s"scan with prefix encoder (encoding = $encoding)") { tryWithProviderResource( newStoreProviderWithTimestampEncoder( encoderType = "prefix", @@ -677,53 +634,30 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession val store = provider.getStore(0) try { - // Insert entries at various timestamps val timestamps = Seq(100L, 200L, 300L, 400L, 500L) timestamps.foreach { ts => store.merge(keyAndTimestampToRow("key1", 1, ts), valueToRow(ts.toInt)) } + store.merge(keyAndTimestampToRow("key2", 2, 150L), valueToRow(150)) - // Scan from beginning (None) up to 301 (exclusive), covering [..300] - val endKey = keyAndTimestampToRow("key1", 1, 301L) - val iter = store.scanWithMultiValues(None, Some(endKey)) - val results = iter.map(_.key.getLong(2)).toList - iter.close() - - assert(results === Seq(100L, 200L, 300L)) - } finally { - store.abort() - } - } - } - - test(s"scan with prefix encoder: boundary safety check (encoding = $encoding)") { - tryWithProviderResource( - newStoreProviderWithTimestampEncoder( - encoderType = "prefix", - useMultipleValuesPerKey = true, - dataEncoding = encoding) - ) { provider => - val store = provider.getStore(0) + // None startKey scans from beginning up to 301 (exclusive) + val iter1 = store.scanWithMultiValues(None, + Some(keyAndTimestampToRow("key1", 1, 301L))) + val results1 = iter1.map(_.key.getLong(2)).toList + iter1.close() - try { - val timestamps = Seq(100L, 200L, 300L) - timestamps.foreach { ts => - store.merge(keyAndTimestampToRow("key1", 1, ts), valueToRow(ts.toInt)) - } - store.merge(keyAndTimestampToRow("key2", 2, 150L), valueToRow(150)) + assert(results1 === Seq(100L, 150L, 200L, 300L)) - // Scan with endKey at timestamp 201 with dummyKey - should include - // everything up to timestamp 200 regardless of join key - val endKey = keyAndTimestampToRow("key1", 1, 201L) - val iter = store.scanWithMultiValues(None, Some(endKey)) - val results = iter.map { pair => + // Boundary safety: endKey at 201, includes everything up to 200 + // regardless of join key + val iter2 = store.scanWithMultiValues(None, + Some(keyAndTimestampToRow("key1", 1, 201L))) + val results2 = iter2.map { pair => (pair.key.getString(0), pair.key.getLong(2)) }.toList - iter.close() + iter2.close() - // Should include key1@100, key2@150, key1@200 - assert(results.length === 3) - assert(results.map(_._2) === Seq(100L, 150L, 200L)) + assert(results2.map(_._2) === Seq(100L, 150L, 200L)) } finally { store.abort() } @@ -740,50 +674,32 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession val store = provider.getStore(0) try { - val timestamps = Seq(100L, 200L, 300L, 400L, 500L) - timestamps.foreach { ts => + diverseTimestamps.foreach { ts => store.put(keyAndTimestampToRow("key1", 1, ts), valueToRow(ts.toInt)) } - val startKey = keyAndTimestampToRow("key1", 1, 200L) - val endKey = keyAndTimestampToRow("key1", 1, 401L) - val iter = store.scan(Some(startKey), Some(endKey)) - val results = iter.map { pair => + // Bounded positive range [0, 100) + val posIter = store.scan( + Some(keyAndTimestampToRow("key1", 1, 0L)), + Some(keyAndTimestampToRow("key1", 1, 100L))) + val posResults = posIter.map { pair => (pair.key.getLong(2), pair.value.getInt(0)) }.toList - iter.close() - - assert(results.map(_._1) === Seq(200L, 300L, 400L)) - assert(results.map(_._2) === Seq(200, 300, 400)) - } finally { - store.abort() - } - } - } - - test(s"scan with diverse timestamps and bounded range (encoding = $encoding)") { - tryWithProviderResource( - newStoreProviderWithTimestampEncoder( - encoderType = "postfix", - useMultipleValuesPerKey = false, - dataEncoding = encoding) - ) { provider => - val store = provider.getStore(0) + posIter.close() - try { - diverseTimestamps.zipWithIndex.foreach { case (ts, idx) => - store.put(keyAndTimestampToRow("key1", 1, ts), valueToRow(idx)) - } + val expectedPosTs = diverseTimestamps.filter(ts => ts >= 0 && ts < 100).sorted + assert(posResults.map(_._1) === expectedPosTs) + assert(posResults.map(_._2) === expectedPosTs.map(_.toInt)) - // Scan negative range only: [-300, 0) - val startKey = keyAndTimestampToRow("key1", 1, -300L) - val endKey = keyAndTimestampToRow("key1", 1, 0L) - val iter = store.scan(Some(startKey), Some(endKey)) - val results = iter.map(_.key.getLong(2)).toList - iter.close() + // Bounded negative range [-300, 0) + val negIter = store.scan( + Some(keyAndTimestampToRow("key1", 1, -300L)), + Some(keyAndTimestampToRow("key1", 1, 0L))) + val negResults = negIter.map(_.key.getLong(2)).toList + negIter.close() - val expected = diverseTimestamps.filter(ts => ts >= -300 && ts < 0).sorted - assert(results === expected) + val expectedNegTs = diverseTimestamps.filter(ts => ts >= -300 && ts < 0).sorted + assert(negResults === expectedNegTs) } finally { store.abort() } From 0aa4813239e63d5797221437cccf8f709c9a7070 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 7 Apr 2026 06:22:03 +0900 Subject: [PATCH 06/12] Second round of self review comments --- .../state/RocksDBStateStoreSuite.scala | 57 +++++++++++- ...cksDBTimestampEncoderOperationsSuite.scala | 88 +++++++++---------- 2 files changed, 94 insertions(+), 51 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala index 7e373a263ab74..e6dd70c55d737 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala @@ -1633,7 +1633,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid -230L, -14569L, -92L, -7434253L, 35L, 6L, 9L, -323L, 5L, -32L, -64L, -256L, 64L, 32L, 1024L, 4096L, 0L) - testWithColumnFamiliesAndEncodingTypes("rocksdb range scan", + testWithColumnFamiliesAndEncodingTypes("rocksdb range scan - scan", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => tryWithProviderResource(newStoreProvider(keySchemaWithRangeScan, @@ -1664,6 +1664,18 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid assert(boundedResults.map(_._1) === expectedBoundedTs) assert(boundedResults.map(_._2) === expectedBoundedTs.map(_.toInt)) + // Exact bound: startKey is inclusive, endKey is exclusive. + // 9 exists in diverseTimestamps, 90 exists in diverseTimestamps. + // Scan [9, 90) should include 9 but exclude 90. + val exactIter = store.scan( + Some(dataToKeyRowWithRangeScan(9L, "a")), + Some(dataToKeyRowWithRangeScan(90L, "a")), cfName) + val exactResults = exactIter.map(_.key.getLong(0)).toList + exactIter.close() + assert(exactResults === diverseTimestamps.filter(ts => ts >= 9 && ts < 90).sorted) + assert(exactResults.contains(9L)) + assert(!exactResults.contains(90L)) + // None startKey scans from beginning to 0 val noneStartIter = store.scan( None, Some(dataToKeyRowWithRangeScan(0L, "a")), cfName) @@ -1728,8 +1740,10 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid }.toList iter.close() - assert(results.map(_._1).distinct.sorted === Seq(100L, 200L)) - assert(results.length === 6) // 3 key2 values x 2 key1 values + val expectedResults = Seq( + (100L, "a"), (100L, "b"), (100L, "c"), + (200L, "a"), (200L, "b"), (200L, "c")) + assert(results === expectedResults) } finally { if (!store.hasCommitted) store.abort() } @@ -1777,13 +1791,48 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid .flatMap { case (_, idx) => Seq(idx * 10, idx * 10 + 1) } assert(boundedResults.map(_._2) === expectedValues) + // Exact bound: startKey is inclusive, endKey is exclusive. + // 9 exists in diverseTimestamps, 90 exists in diverseTimestamps. + val exactIter = store.scanWithMultiValues( + Some(dataToKeyRowWithRangeScan(9L, "a")), + Some(dataToKeyRowWithRangeScan(90L, "a")), cfName) + val exactResults = exactIter.map(_.key.getLong(0)).toList + exactIter.close() + val exactResultsDistinct = exactResults.distinct + assert(exactResultsDistinct === diverseTimestamps + .filter(ts => ts >= 9 && ts < 90).sorted) + assert(exactResultsDistinct.contains(9L)) + assert(!exactResultsDistinct.contains(90L)) + // None startKey scans from beginning to 0 val noneStartIter = store.scanWithMultiValues( None, Some(dataToKeyRowWithRangeScan(0L, "a")), cfName) val noneStartResults = noneStartIter.map(_.key.getLong(0)).toList noneStartIter.close() - assert(noneStartResults.distinct === diverseTimestamps.filter(_ < 0).sorted) + + // None endKey scans from 1000 to end + val noneEndIter = store.scanWithMultiValues( + Some(dataToKeyRowWithRangeScan(1000L, "a")), None, cfName) + val noneEndResults = noneEndIter.map(_.key.getLong(0)).toList + noneEndIter.close() + assert(noneEndResults.distinct === diverseTimestamps.filter(_ >= 1000).sorted) + + // Empty range [10, 31) - no entries between 9 and 32 + val emptyIter = store.scanWithMultiValues( + Some(dataToKeyRowWithRangeScan(10L, "a")), + Some(dataToKeyRowWithRangeScan(31L, "a")), cfName) + assert(!emptyIter.hasNext) + emptyIter.close() + + // Bounded negative range [-300, 0) + val negIter = store.scanWithMultiValues( + Some(dataToKeyRowWithRangeScan(-300L, "a")), + Some(dataToKeyRowWithRangeScan(0L, "a")), cfName) + val negResults = negIter.map(_.key.getLong(0)).toList + negIter.close() + assert(negResults.distinct === diverseTimestamps + .filter(ts => ts >= -300 && ts < 0).sorted) } finally { if (!store.hasCommitted) store.abort() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala index 580d6e4fd12c4..414a9a39872a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala @@ -591,17 +591,36 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession Some(keyAndTimestampToRow("key1", 1, 0L)), Some(keyAndTimestampToRow("key1", 1, 1001L))) val boundedResults = boundedIter.map { pair => - (pair.key.getLong(2), pair.value.getInt(0)) + (pair.key.getString(0), pair.key.getLong(2), pair.value.getInt(0)) }.toList boundedIter.close() val expectedTimestamps = diverseTimestamps.filter(ts => ts >= 0 && ts <= 1000).sorted - assert(boundedResults.map(_._1).distinct === expectedTimestamps) + assert(boundedResults.map(_._2).distinct === expectedTimestamps) val expectedValues = diverseTimestamps.zipWithIndex .filter { case (ts, _) => ts >= 0 && ts <= 1000 } .sortBy(_._1) .flatMap { case (_, idx) => Seq(idx * 10, idx * 10 + 1) } - assert(boundedResults.map(_._2) === expectedValues) + assert(boundedResults.map(_._3) === expectedValues) + assert(boundedResults.forall(_._1 == "key1")) + + // Exact bound: startKey is inclusive, endKey is exclusive. + // 9 exists in diverseTimestamps, 90 exists in diverseTimestamps. + val exactIter = store.scanWithMultiValues( + Some(keyAndTimestampToRow("key1", 1, 9L)), + Some(keyAndTimestampToRow("key1", 1, 90L))) + val exactResults = exactIter.map(_.key.getLong(2)).toList + exactIter.close() + val exactResultsDistinct = exactResults.distinct + assert(exactResultsDistinct === diverseTimestamps + .filter(ts => ts >= 9 && ts < 90).sorted) + assert(exactResultsDistinct.contains(9L)) + assert(!exactResultsDistinct.contains(90L)) + + // Postfix timestamp encoder places the timestamp after the key prefix. + // With different key prefixes, None in startKey or endKey would scan across + // key boundaries, which is not meaningful for postfix encoding. Hence we only + // test bounded ranges with explicit keys here. // Full range [MinValue, MaxValue) val fullIter = store.scanWithMultiValues( @@ -612,6 +631,15 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession assert(fullResults.distinct === diverseTimestamps.sorted) + // Bounded negative range [-300, 0) + val negIter = store.scanWithMultiValues( + Some(keyAndTimestampToRow("key1", 1, -300L)), + Some(keyAndTimestampToRow("key1", 1, 0L))) + val negResults = negIter.map(_.key.getLong(2)).toList + negIter.close() + assert(negResults.distinct === diverseTimestamps + .filter(ts => ts >= -300 && ts < 0).sorted) + // Empty range [10, 31) - no diverseTimestamps entries between 9 and 32 val emptyIter = store.scanWithMultiValues( Some(keyAndTimestampToRow("key1", 1, 10L)), @@ -624,6 +652,9 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession } } + // Sanity test for prefix encoder scan. Full scan coverage is in RocksDBStateStoreSuite's + // "rocksdb range scan - scan" and "rocksdb range scan - scanWithMultiValues" tests. + // This test verifies the timestamp prefix encoder integration works correctly. test(s"scan with prefix encoder (encoding = $encoding)") { tryWithProviderResource( newStoreProviderWithTimestampEncoder( @@ -643,10 +674,13 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession // None startKey scans from beginning up to 301 (exclusive) val iter1 = store.scanWithMultiValues(None, Some(keyAndTimestampToRow("key1", 1, 301L))) - val results1 = iter1.map(_.key.getLong(2)).toList + val results1 = iter1.map { pair => + (pair.key.getString(0), pair.key.getLong(2)) + }.toList iter1.close() - assert(results1 === Seq(100L, 150L, 200L, 300L)) + assert(results1 === Seq( + ("key1", 100L), ("key2", 150L), ("key1", 200L), ("key1", 300L))) // Boundary safety: endKey at 201, includes everything up to 200 // regardless of join key @@ -657,54 +691,14 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession }.toList iter2.close() - assert(results2.map(_._2) === Seq(100L, 150L, 200L)) + assert(results2 === Seq( + ("key1", 100L), ("key2", 150L), ("key1", 200L))) } finally { store.abort() } } } - test(s"scan single-value variant (encoding = $encoding)") { - tryWithProviderResource( - newStoreProviderWithTimestampEncoder( - encoderType = "postfix", - useMultipleValuesPerKey = false, - dataEncoding = encoding) - ) { provider => - val store = provider.getStore(0) - - try { - diverseTimestamps.foreach { ts => - store.put(keyAndTimestampToRow("key1", 1, ts), valueToRow(ts.toInt)) - } - - // Bounded positive range [0, 100) - val posIter = store.scan( - Some(keyAndTimestampToRow("key1", 1, 0L)), - Some(keyAndTimestampToRow("key1", 1, 100L))) - val posResults = posIter.map { pair => - (pair.key.getLong(2), pair.value.getInt(0)) - }.toList - posIter.close() - - val expectedPosTs = diverseTimestamps.filter(ts => ts >= 0 && ts < 100).sorted - assert(posResults.map(_._1) === expectedPosTs) - assert(posResults.map(_._2) === expectedPosTs.map(_.toInt)) - - // Bounded negative range [-300, 0) - val negIter = store.scan( - Some(keyAndTimestampToRow("key1", 1, -300L)), - Some(keyAndTimestampToRow("key1", 1, 0L))) - val negResults = negIter.map(_.key.getLong(2)).toList - negIter.close() - - val expectedNegTs = diverseTimestamps.filter(ts => ts >= -300 && ts < 0).sorted - assert(negResults === expectedNegTs) - } finally { - store.abort() - } - } - } } // Helper methods to create test data From 710466b9d4c8169c8cffb6338430656e82e5696c Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 7 Apr 2026 15:15:07 +0900 Subject: [PATCH 07/12] Change the API name to rangeScan --- .../streaming/state/RocksDBStateEncoder.scala | 11 +++++++ .../state/RocksDBStateStoreProvider.scala | 13 +++++--- .../streaming/state/StateStore.scala | 16 +++++----- ...sDBStateStoreCheckpointFormatV2Suite.scala | 8 ++--- .../state/RocksDBStateStoreSuite.scala | 30 +++++++++---------- ...cksDBTimestampEncoderOperationsSuite.scala | 20 ++++++------- 6 files changed, 57 insertions(+), 41 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala index 5acf7cdc9b975..4d9c77348b493 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala @@ -46,6 +46,7 @@ import org.apache.spark.unsafe.Platform sealed trait RocksDBKeyStateEncoder { def supportPrefixKeyScan: Boolean def supportsDeleteRange: Boolean + def supportsRangeScan: Boolean def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte] def encodeKey(row: UnsafeRow): Array[Byte] def decodeKey(keyBytes: Array[Byte]): UnsafeRow @@ -1500,6 +1501,8 @@ class PrefixKeyScanStateEncoder( override def supportPrefixKeyScan: Boolean = true override def supportsDeleteRange: Boolean = false + + override def supportsRangeScan: Boolean = false } /** @@ -1699,6 +1702,8 @@ class RangeKeyScanStateEncoder( override def supportPrefixKeyScan: Boolean = true override def supportsDeleteRange: Boolean = true + + override def supportsRangeScan: Boolean = true } /** @@ -1731,6 +1736,8 @@ class NoPrefixKeyStateEncoder( override def supportsDeleteRange: Boolean = false + override def supportsRangeScan: Boolean = false + override def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte] = { throw new IllegalStateException("This encoder doesn't support prefix key!") } @@ -1884,6 +1891,8 @@ class TimestampAsPrefixKeyStateEncoder( // TODO: [SPARK-55491] Revisit this to support delete range if needed. override def supportsDeleteRange: Boolean = false + + override def supportsRangeScan: Boolean = true } /** @@ -1932,6 +1941,8 @@ class TimestampAsPostfixKeyStateEncoder( } override def supportsDeleteRange: Boolean = false + + override def supportsRangeScan: Boolean = true } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 7629ffe7d4325..e1490c71bc69b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -549,14 +549,17 @@ private[sql] class RocksDBStateStoreProvider new StateStoreIterator(iter, rocksDbIter.closeIfNeeded) } - override def scan( + override def rangeScan( startKey: Option[UnsafeRow], endKey: Option[UnsafeRow], colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { validateAndTransitionState(UPDATE) - verifyColFamilyOperations("scan", colFamilyName) + verifyColFamilyOperations("rangeScan", colFamilyName) val kvEncoder = keyValueEncoderMap.get(colFamilyName) + require(kvEncoder._1.supportsRangeScan, + "Range scan requires an encoder that supports range scanning!") + val encodedStartKey = startKey.map(kvEncoder._1.encodeKey) val encodedEndKey = endKey.map(kvEncoder._1.encodeKey) @@ -571,14 +574,16 @@ private[sql] class RocksDBStateStoreProvider new StateStoreIterator(iter, rocksDbIter.closeIfNeeded) } - override def scanWithMultiValues( + override def rangeScanWithMultiValues( startKey: Option[UnsafeRow], endKey: Option[UnsafeRow], colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { validateAndTransitionState(UPDATE) - verifyColFamilyOperations("scanWithMultiValues", colFamilyName) + verifyColFamilyOperations("rangeScanWithMultiValues", colFamilyName) val kvEncoder = keyValueEncoderMap.get(colFamilyName) + require(kvEncoder._1.supportsRangeScan, + "Range scan requires an encoder that supports range scanning!") verify( kvEncoder._2.supportsMultipleValuesPerKey, "Multi-value iterator operation requires an encoder" + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 9c4e2f06587a2..e3601f1ef2246 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -196,12 +196,12 @@ trait ReadStateStore { * bytes for the scan range to be meaningful (e.g., timestamp-based encoders or * RangeKeyScanStateEncoder). */ - def scan( + def rangeScan( startKey: Option[UnsafeRow], endKey: Option[UnsafeRow], colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME) : StateStoreIterator[UnsafeRowPair] = { - throw StateStoreErrors.unsupportedOperationException("scan", "") + throw StateStoreErrors.unsupportedOperationException("rangeScan", "") } /** @@ -220,12 +220,12 @@ trait ReadStateStore { * It is expected to throw exception if Spark calls this method without setting * multipleValuesPerKey as true for the column family. */ - def scanWithMultiValues( + def rangeScanWithMultiValues( startKey: Option[UnsafeRow], endKey: Option[UnsafeRow], colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME) : StateStoreIterator[UnsafeRowPair] = { - throw StateStoreErrors.unsupportedOperationException("scanWithMultiValues", "") + throw StateStoreErrors.unsupportedOperationException("rangeScanWithMultiValues", "") } /** Return an iterator containing all the key-value pairs in the StateStore. */ @@ -456,18 +456,18 @@ class WrappedReadStateStore(store: StateStore) extends ReadStateStore { store.prefixScanWithMultiValues(prefixKey, colFamilyName) } - override def scan( + override def rangeScan( startKey: Option[UnsafeRow], endKey: Option[UnsafeRow], colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { - store.scan(startKey, endKey, colFamilyName) + store.rangeScan(startKey, endKey, colFamilyName) } - override def scanWithMultiValues( + override def rangeScanWithMultiValues( startKey: Option[UnsafeRow], endKey: Option[UnsafeRow], colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { - store.scanWithMultiValues(startKey, endKey, colFamilyName) + store.rangeScanWithMultiValues(startKey, endKey, colFamilyName) } override def iteratorWithMultiValues( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala index dca9677a6f7bc..fe09506023ccd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala @@ -172,18 +172,18 @@ case class CkptIdCollectingStateStoreWrapper(innerStore: StateStore) extends Sta innerStore.prefixScanWithMultiValues(prefixKey, colFamilyName) } - override def scan( + override def rangeScan( startKey: Option[UnsafeRow], endKey: Option[UnsafeRow], colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { - innerStore.scan(startKey, endKey, colFamilyName) + innerStore.rangeScan(startKey, endKey, colFamilyName) } - override def scanWithMultiValues( + override def rangeScanWithMultiValues( startKey: Option[UnsafeRow], endKey: Option[UnsafeRow], colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { - innerStore.scanWithMultiValues(startKey, endKey, colFamilyName) + innerStore.rangeScanWithMultiValues(startKey, endKey, colFamilyName) } override def iteratorWithMultiValues( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala index e6dd70c55d737..0c300192dd898 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala @@ -1633,7 +1633,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid -230L, -14569L, -92L, -7434253L, 35L, 6L, 9L, -323L, 5L, -32L, -64L, -256L, 64L, 32L, 1024L, 4096L, 0L) - testWithColumnFamiliesAndEncodingTypes("rocksdb range scan - scan", + testWithColumnFamiliesAndEncodingTypes("rocksdb range scan - rangeScan", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => tryWithProviderResource(newStoreProvider(keySchemaWithRangeScan, @@ -1653,7 +1653,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } // Bounded positive range [0, 100) - val boundedIter = store.scan( + val boundedIter = store.rangeScan( Some(dataToKeyRowWithRangeScan(0L, "a")), Some(dataToKeyRowWithRangeScan(100L, "a")), cfName) val boundedResults = boundedIter.map { pair => @@ -1667,7 +1667,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid // Exact bound: startKey is inclusive, endKey is exclusive. // 9 exists in diverseTimestamps, 90 exists in diverseTimestamps. // Scan [9, 90) should include 9 but exclude 90. - val exactIter = store.scan( + val exactIter = store.rangeScan( Some(dataToKeyRowWithRangeScan(9L, "a")), Some(dataToKeyRowWithRangeScan(90L, "a")), cfName) val exactResults = exactIter.map(_.key.getLong(0)).toList @@ -1677,28 +1677,28 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid assert(!exactResults.contains(90L)) // None startKey scans from beginning to 0 - val noneStartIter = store.scan( + val noneStartIter = store.rangeScan( None, Some(dataToKeyRowWithRangeScan(0L, "a")), cfName) val noneStartResults = noneStartIter.map(_.key.getLong(0)).toList noneStartIter.close() assert(noneStartResults === diverseTimestamps.filter(_ < 0).sorted) // None endKey scans from 1000 to end - val noneEndIter = store.scan( + val noneEndIter = store.rangeScan( Some(dataToKeyRowWithRangeScan(1000L, "a")), None, cfName) val noneEndResults = noneEndIter.map(_.key.getLong(0)).toList noneEndIter.close() assert(noneEndResults === diverseTimestamps.filter(_ >= 1000).sorted) // Empty range [10, 31) - no entries between 9 and 32 - val emptyIter = store.scan( + val emptyIter = store.rangeScan( Some(dataToKeyRowWithRangeScan(10L, "a")), Some(dataToKeyRowWithRangeScan(31L, "a")), cfName) assert(!emptyIter.hasNext) emptyIter.close() // Bounded negative range [-300, 0) - val negIter = store.scan( + val negIter = store.rangeScan( Some(dataToKeyRowWithRangeScan(-300L, "a")), Some(dataToKeyRowWithRangeScan(0L, "a")), cfName) val negResults = negIter.map(_.key.getLong(0)).toList @@ -1734,7 +1734,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid val startKey = dataToKeyRowWithRangeScan(100L, "a") val endKey = dataToKeyRowWithRangeScan(201L, "a") - val iter = store.scan(Some(startKey), Some(endKey), cfName) + val iter = store.rangeScan(Some(startKey), Some(endKey), cfName) val results = iter.map { pair => (pair.key.getLong(0), pair.key.getUTF8String(1).toString) }.toList @@ -1751,7 +1751,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } testWithColumnFamiliesAndEncodingTypes( - "rocksdb range scan - scanWithMultiValues", + "rocksdb range scan - rangeScanWithMultiValues", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => if (colFamiliesEnabled) { @@ -1775,7 +1775,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } // Bounded range [0, 1001) - val boundedIter = store.scanWithMultiValues( + val boundedIter = store.rangeScanWithMultiValues( Some(dataToKeyRowWithRangeScan(0L, "a")), Some(dataToKeyRowWithRangeScan(1001L, "a")), cfName) val boundedResults = boundedIter.map { pair => @@ -1793,7 +1793,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid // Exact bound: startKey is inclusive, endKey is exclusive. // 9 exists in diverseTimestamps, 90 exists in diverseTimestamps. - val exactIter = store.scanWithMultiValues( + val exactIter = store.rangeScanWithMultiValues( Some(dataToKeyRowWithRangeScan(9L, "a")), Some(dataToKeyRowWithRangeScan(90L, "a")), cfName) val exactResults = exactIter.map(_.key.getLong(0)).toList @@ -1805,28 +1805,28 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid assert(!exactResultsDistinct.contains(90L)) // None startKey scans from beginning to 0 - val noneStartIter = store.scanWithMultiValues( + val noneStartIter = store.rangeScanWithMultiValues( None, Some(dataToKeyRowWithRangeScan(0L, "a")), cfName) val noneStartResults = noneStartIter.map(_.key.getLong(0)).toList noneStartIter.close() assert(noneStartResults.distinct === diverseTimestamps.filter(_ < 0).sorted) // None endKey scans from 1000 to end - val noneEndIter = store.scanWithMultiValues( + val noneEndIter = store.rangeScanWithMultiValues( Some(dataToKeyRowWithRangeScan(1000L, "a")), None, cfName) val noneEndResults = noneEndIter.map(_.key.getLong(0)).toList noneEndIter.close() assert(noneEndResults.distinct === diverseTimestamps.filter(_ >= 1000).sorted) // Empty range [10, 31) - no entries between 9 and 32 - val emptyIter = store.scanWithMultiValues( + val emptyIter = store.rangeScanWithMultiValues( Some(dataToKeyRowWithRangeScan(10L, "a")), Some(dataToKeyRowWithRangeScan(31L, "a")), cfName) assert(!emptyIter.hasNext) emptyIter.close() // Bounded negative range [-300, 0) - val negIter = store.scanWithMultiValues( + val negIter = store.rangeScanWithMultiValues( Some(dataToKeyRowWithRangeScan(-300L, "a")), Some(dataToKeyRowWithRangeScan(0L, "a")), cfName) val negResults = negIter.map(_.key.getLong(0)).toList diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala index 414a9a39872a0..5fcdfb12ba354 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala @@ -567,7 +567,7 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession } Seq("unsaferow", "avro").foreach { encoding => - test(s"scan with postfix encoder (encoding = $encoding)") { + test(s"rangeScan with postfix encoder (encoding = $encoding)") { tryWithProviderResource( newStoreProviderWithTimestampEncoder( encoderType = "postfix", @@ -587,7 +587,7 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession Array(valueToRow(999))) // Bounded range [0, 1001) - val boundedIter = store.scanWithMultiValues( + val boundedIter = store.rangeScanWithMultiValues( Some(keyAndTimestampToRow("key1", 1, 0L)), Some(keyAndTimestampToRow("key1", 1, 1001L))) val boundedResults = boundedIter.map { pair => @@ -606,7 +606,7 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession // Exact bound: startKey is inclusive, endKey is exclusive. // 9 exists in diverseTimestamps, 90 exists in diverseTimestamps. - val exactIter = store.scanWithMultiValues( + val exactIter = store.rangeScanWithMultiValues( Some(keyAndTimestampToRow("key1", 1, 9L)), Some(keyAndTimestampToRow("key1", 1, 90L))) val exactResults = exactIter.map(_.key.getLong(2)).toList @@ -623,7 +623,7 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession // test bounded ranges with explicit keys here. // Full range [MinValue, MaxValue) - val fullIter = store.scanWithMultiValues( + val fullIter = store.rangeScanWithMultiValues( Some(keyAndTimestampToRow("key1", 1, Long.MinValue)), Some(keyAndTimestampToRow("key1", 1, Long.MaxValue))) val fullResults = fullIter.map(_.key.getLong(2)).toList @@ -632,7 +632,7 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession assert(fullResults.distinct === diverseTimestamps.sorted) // Bounded negative range [-300, 0) - val negIter = store.scanWithMultiValues( + val negIter = store.rangeScanWithMultiValues( Some(keyAndTimestampToRow("key1", 1, -300L)), Some(keyAndTimestampToRow("key1", 1, 0L))) val negResults = negIter.map(_.key.getLong(2)).toList @@ -641,7 +641,7 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession .filter(ts => ts >= -300 && ts < 0).sorted) // Empty range [10, 31) - no diverseTimestamps entries between 9 and 32 - val emptyIter = store.scanWithMultiValues( + val emptyIter = store.rangeScanWithMultiValues( Some(keyAndTimestampToRow("key1", 1, 10L)), Some(keyAndTimestampToRow("key1", 1, 31L))) assert(!emptyIter.hasNext) @@ -653,9 +653,9 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession } // Sanity test for prefix encoder scan. Full scan coverage is in RocksDBStateStoreSuite's - // "rocksdb range scan - scan" and "rocksdb range scan - scanWithMultiValues" tests. + // "rocksdb range scan - rangeScan" and "rocksdb range scan - rangeScanWithMultiValues" tests. // This test verifies the timestamp prefix encoder integration works correctly. - test(s"scan with prefix encoder (encoding = $encoding)") { + test(s"rangeScan with prefix encoder (encoding = $encoding)") { tryWithProviderResource( newStoreProviderWithTimestampEncoder( encoderType = "prefix", @@ -672,7 +672,7 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession store.merge(keyAndTimestampToRow("key2", 2, 150L), valueToRow(150)) // None startKey scans from beginning up to 301 (exclusive) - val iter1 = store.scanWithMultiValues(None, + val iter1 = store.rangeScanWithMultiValues(None, Some(keyAndTimestampToRow("key1", 1, 301L))) val results1 = iter1.map { pair => (pair.key.getString(0), pair.key.getLong(2)) @@ -684,7 +684,7 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession // Boundary safety: endKey at 201, includes everything up to 200 // regardless of join key - val iter2 = store.scanWithMultiValues(None, + val iter2 = store.rangeScanWithMultiValues(None, Some(keyAndTimestampToRow("key1", 1, 201L))) val results2 = iter2.map { pair => (pair.key.getString(0), pair.key.getLong(2)) From 4369d7858d5fccdb93861248fb299984ea7460f1 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 7 Apr 2026 11:00:39 +0900 Subject: [PATCH 08/12] Apply scan/scanWithMultiValues to stream-stream join V4 Use bounded scan ranges in stream-stream join V4 operators to narrow the iteration scope during eviction and value lookup: - scanEvictedKeys (TsWithKeyTypeStore): use scanWithMultiValues with startKey derived from the previous batch's state watermark and endKey from the current eviction threshold. Thread prevBatchStateWatermark through JoinStateWatermarkPredicate -> SupportsEvictByTimestamp. - getValuesInRange (KeyWithTsToValuesStore): use scanWithMultiValues for bounded timestamp ranges, falling back to prefixScan for full range. Create default-valued boundary rows to avoid NullPointerException when the join key schema contains non-nullable fields (e.g. window structs). --- .../join/StreamingSymmetricHashJoinExec.scala | 28 ++-- .../StreamingSymmetricHashJoinHelper.scala | 25 +++- .../join/SymmetricHashJoinStateManager.scala | 125 +++++++++++++++--- .../runtime/IncrementalExecution.scala | 8 +- .../SymmetricHashJoinStateManagerSuite.scala | 61 +++++++++ .../sql/streaming/StreamingJoinV4Suite.scala | 69 +++++++++- 6 files changed, 281 insertions(+), 35 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala index 1c50e6802c323..2b15c43bd3db6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala @@ -663,7 +663,7 @@ case class StreamingSymmetricHashJoinExec( private[this] val keyGenerator = UnsafeProjection.create(joinKeys, inputAttributes) private[this] val stateKeyWatermarkPredicateFunc = stateWatermarkPredicate match { - case Some(JoinStateKeyWatermarkPredicate(expr, _)) => + case Some(JoinStateKeyWatermarkPredicate(expr, _, _)) => // inputSchema can be empty as expr should only have BoundReferences and does not require // the schema to generated predicate. See [[StreamingSymmetricHashJoinHelper]]. Predicate.create(expr, Seq.empty).eval _ @@ -672,7 +672,7 @@ case class StreamingSymmetricHashJoinExec( } private[this] val stateValueWatermarkPredicateFunc = stateWatermarkPredicate match { - case Some(JoinStateValueWatermarkPredicate(expr, _)) => + case Some(JoinStateValueWatermarkPredicate(expr, _, _)) => Predicate.create(expr, inputAttributes).eval _ case _ => Predicate.create(Literal(false), Seq.empty).eval _ // false = do not remove if no predicate @@ -893,21 +893,25 @@ case class StreamingSymmetricHashJoinExec( */ def removeOldState(): Long = { stateWatermarkPredicate match { - case Some(JoinStateKeyWatermarkPredicate(_, stateWatermark)) => + case Some(JoinStateKeyWatermarkPredicate(_, stateWatermark, prevStateWatermark)) => joinStateManager match { case s: SupportsEvictByCondition => s.evictByKeyCondition(stateKeyWatermarkPredicateFunc) case s: SupportsEvictByTimestamp => - s.evictByTimestamp(watermarkMsToStateTimestamp(stateWatermark)) + s.evictByTimestamp( + watermarkMsToStateTimestamp(stateWatermark), + prevStateWatermark.map(watermarkMsToStateTimestamp)) } - case Some(JoinStateValueWatermarkPredicate(_, stateWatermark)) => + case Some(JoinStateValueWatermarkPredicate(_, stateWatermark, prevStateWatermark)) => joinStateManager match { case s: SupportsEvictByCondition => s.evictByValueCondition(stateValueWatermarkPredicateFunc) case s: SupportsEvictByTimestamp => - s.evictByTimestamp(watermarkMsToStateTimestamp(stateWatermark)) + s.evictByTimestamp( + watermarkMsToStateTimestamp(stateWatermark), + prevStateWatermark.map(watermarkMsToStateTimestamp)) } case _ => 0L } @@ -925,21 +929,25 @@ case class StreamingSymmetricHashJoinExec( */ def removeAndReturnOldState(): Iterator[KeyToValuePair] = { stateWatermarkPredicate match { - case Some(JoinStateKeyWatermarkPredicate(_, stateWatermark)) => + case Some(JoinStateKeyWatermarkPredicate(_, stateWatermark, prevStateWatermark)) => joinStateManager match { case s: SupportsEvictByCondition => s.evictAndReturnByKeyCondition(stateKeyWatermarkPredicateFunc) case s: SupportsEvictByTimestamp => - s.evictAndReturnByTimestamp(watermarkMsToStateTimestamp(stateWatermark)) + s.evictAndReturnByTimestamp( + watermarkMsToStateTimestamp(stateWatermark), + prevStateWatermark.map(watermarkMsToStateTimestamp)) } - case Some(JoinStateValueWatermarkPredicate(_, stateWatermark)) => + case Some(JoinStateValueWatermarkPredicate(_, stateWatermark, prevStateWatermark)) => joinStateManager match { case s: SupportsEvictByCondition => s.evictAndReturnByValueCondition(stateValueWatermarkPredicateFunc) case s: SupportsEvictByTimestamp => - s.evictAndReturnByTimestamp(watermarkMsToStateTimestamp(stateWatermark)) + s.evictAndReturnByTimestamp( + watermarkMsToStateTimestamp(stateWatermark), + prevStateWatermark.map(watermarkMsToStateTimestamp)) } case _ => Iterator.empty } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinHelper.scala index cea6398f4e501..80a299b4e6bb2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinHelper.scala @@ -46,12 +46,18 @@ object StreamingSymmetricHashJoinHelper extends Logging { override def toString: String = s"$desc: $expr" } /** Predicate for watermark on state keys */ - case class JoinStateKeyWatermarkPredicate(expr: Expression, stateWatermark: Long) + case class JoinStateKeyWatermarkPredicate( + expr: Expression, + stateWatermark: Long, + prevStateWatermark: Option[Long] = None) extends JoinStateWatermarkPredicate { def desc: String = "key predicate" } /** Predicate for watermark on state values */ - case class JoinStateValueWatermarkPredicate(expr: Expression, stateWatermark: Long) + case class JoinStateValueWatermarkPredicate( + expr: Expression, + stateWatermark: Long, + prevStateWatermark: Option[Long] = None) extends JoinStateWatermarkPredicate { def desc: String = "value predicate" } @@ -185,6 +191,7 @@ object StreamingSymmetricHashJoinHelper extends Logging { rightKeys: Seq[Expression], condition: Option[Expression], eventTimeWatermarkForEviction: Option[Long], + eventTimeWatermarkForLateEvents: Option[Long], useFirstEventTimeColumn: Boolean): JoinStateWatermarkPredicates = { // Perform assertions against multiple event time columns in the same DataFrame. This method @@ -215,7 +222,10 @@ object StreamingSymmetricHashJoinHelper extends Logging { expr.map { e => // watermarkExpression only provides the expression when eventTimeWatermarkForEviction // is defined - JoinStateKeyWatermarkPredicate(e, eventTimeWatermarkForEviction.get) + JoinStateKeyWatermarkPredicate( + e, + eventTimeWatermarkForEviction.get, + eventTimeWatermarkForLateEvents) } } else if (isWatermarkDefinedOnInput) { // case 2 in the StreamingSymmetricHashJoinExec docs val stateValueWatermark = StreamingJoinHelper.getStateValueWatermark( @@ -223,12 +233,19 @@ object StreamingSymmetricHashJoinHelper extends Logging { attributesWithEventWatermark = AttributeSet(otherSideInputAttributes), condition, eventTimeWatermarkForEviction) + val prevStateValueWatermark = eventTimeWatermarkForLateEvents.flatMap { _ => + StreamingJoinHelper.getStateValueWatermark( + attributesToFindStateWatermarkFor = AttributeSet(oneSideInputAttributes), + attributesWithEventWatermark = AttributeSet(otherSideInputAttributes), + condition, + eventTimeWatermarkForLateEvents) + } val inputAttributeWithWatermark = oneSideInputAttributes.find(_.metadata.contains(delayKey)) val expr = watermarkExpression(inputAttributeWithWatermark, stateValueWatermark) expr.map { e => // watermarkExpression only provides the expression when eventTimeWatermarkForEviction // is defined - JoinStateValueWatermarkPredicate(e, stateValueWatermark.get) + JoinStateValueWatermarkPredicate(e, stateValueWatermark.get, prevStateValueWatermark) } } else { None diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala index fc2a69312fe79..611f548f44b88 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala @@ -34,7 +34,8 @@ import org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOper import org.apache.spark.sql.execution.streaming.operators.stateful.join.StreamingSymmetricHashJoinHelper._ import org.apache.spark.sql.execution.streaming.state.{DropLastNFieldsStatePartitionKeyExtractor, KeyStateEncoderSpec, NoopStatePartitionKeyExtractor, NoPrefixKeyStateEncoderSpec, StatePartitionKeyExtractor, StateSchemaBroadcast, StateStore, StateStoreCheckpointInfo, StateStoreColFamilySchema, StateStoreConf, StateStoreErrors, StateStoreId, StateStoreMetrics, StateStoreProvider, StateStoreProviderId, SupportsFineGrainedReplay, TimestampAsPostfixKeyStateEncoderSpec, TimestampAsPrefixKeyStateEncoderSpec, TimestampKeyStateEncoder} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{BooleanType, DataType, LongType, NullType, StructField, StructType} +import org.apache.spark.sql.types.{BinaryType, BooleanType, ByteType, DataType, DateType, DoubleType, FloatType, IntegerType, LongType, NullType, ShortType, StringType, StructField, StructType, TimestampNTZType, TimestampType} +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.NextIterator /** @@ -184,15 +185,28 @@ trait SupportsEvictByCondition { self: SymmetricHashJoinStateManager => trait SupportsEvictByTimestamp { self: SymmetricHashJoinStateManager => import SymmetricHashJoinStateManager._ - /** Evict the state by timestamp. Returns the number of values evicted. */ - def evictByTimestamp(endTimestamp: Long): Long + /** + * Evict the state by timestamp. Returns the number of values evicted. + * + * @param endTimestamp Inclusive upper bound: evicts entries with timestamp <= endTimestamp. + * @param startTimestamp Exclusive lower bound: entries with timestamp <= startTimestamp are + * assumed to have been evicted already (e.g. from the previous batch). When provided, + * the scan starts from startTimestamp + 1. + */ + def evictByTimestamp(endTimestamp: Long, startTimestamp: Option[Long] = None): Long /** * Evict the state by timestamp and return the evicted key-value pairs. * * It is caller's responsibility to consume the whole iterator. + * + * @param endTimestamp Inclusive upper bound: evicts entries with timestamp <= endTimestamp. + * @param startTimestamp Exclusive lower bound: entries with timestamp <= startTimestamp are + * assumed to have been evicted already (e.g. from the previous batch). When provided, + * the scan starts from startTimestamp + 1. */ - def evictAndReturnByTimestamp(endTimestamp: Long): Iterator[KeyToValuePair] + def evictAndReturnByTimestamp( + endTimestamp: Long, startTimestamp: Option[Long] = None): Iterator[KeyToValuePair] } /** @@ -507,9 +521,9 @@ class SymmetricHashJoinStateManagerV4( } } - override def evictByTimestamp(endTimestamp: Long): Long = { + override def evictByTimestamp(endTimestamp: Long, startTimestamp: Option[Long] = None): Long = { var removed = 0L - tsWithKey.scanEvictedKeys(endTimestamp).foreach { evicted => + tsWithKey.scanEvictedKeys(endTimestamp, startTimestamp).foreach { evicted => val key = evicted.key val timestamp = evicted.timestamp val numValues = evicted.numValues @@ -523,10 +537,11 @@ class SymmetricHashJoinStateManagerV4( removed } - override def evictAndReturnByTimestamp(endTimestamp: Long): Iterator[KeyToValuePair] = { + override def evictAndReturnByTimestamp( + endTimestamp: Long, startTimestamp: Option[Long] = None): Iterator[KeyToValuePair] = { val reusableKeyToValuePair = KeyToValuePair() - tsWithKey.scanEvictedKeys(endTimestamp).flatMap { evicted => + tsWithKey.scanEvictedKeys(endTimestamp, startTimestamp).flatMap { evicted => val key = evicted.key val timestamp = evicted.timestamp val values = keyWithTsToValues.get(key, timestamp) @@ -647,17 +662,30 @@ class SymmetricHashJoinStateManagerV4( /** * Returns entries where minTs <= timestamp <= maxTs (both inclusive), grouped by timestamp. - * Skips entries before minTs and stops iterating past maxTs (timestamps are sorted). + * When maxTs is bounded (< Long.MaxValue), uses rangeScanWithMultiValues for efficient + * range access; falls back to prefixScan otherwise to stay within the key's scope. + * + * When prefixScan is used (maxTs == Long.MaxValue), entries outside [minTs, maxTs] are + * filtered out so both code paths produce identical results. */ def getValuesInRange( key: UnsafeRow, minTs: Long, maxTs: Long): Iterator[GetValuesResult] = { val reusableGetValuesResult = new GetValuesResult() + // Only use rangeScan when maxTs < Long.MaxValue, since rangeScan requires + // an exclusive end key (maxTs + 1) which would overflow at Long.MaxValue. + val useRangeScan = maxTs < Long.MaxValue new NextIterator[GetValuesResult] { - private val iter = stateStore.prefixScanWithMultiValues(key, colFamilyName) + private val iter = if (useRangeScan) { + val startKey = createKeyRow(key, minTs).copy() + // rangeScanWithMultiValues endKey is exclusive, so use maxTs + 1 + val endKey = Some(createKeyRow(key, maxTs + 1)) + stateStore.rangeScanWithMultiValues(Some(startKey), endKey, colFamilyName) + } else { + stateStore.prefixScanWithMultiValues(key, colFamilyName) + } private var currentTs = -1L - private var pastUpperBound = false private val valueAndMatchPairs = scala.collection.mutable.ArrayBuffer[ValueAndMatchPair]() private def flushAccumulated(): GetValuesResult = { @@ -675,16 +703,16 @@ class SymmetricHashJoinStateManagerV4( @tailrec override protected def getNext(): GetValuesResult = { - if (pastUpperBound || !iter.hasNext) { + if (!iter.hasNext) { flushAccumulated() } else { val unsafeRowPair = iter.next() val ts = TimestampKeyStateEncoder.extractTimestamp(unsafeRowPair.key) - if (ts > maxTs) { - pastUpperBound = true - getNext() - } else if (ts < minTs) { + // Filter out entries outside [minTs, maxTs]. This is essential when using + // prefixScan (which returns all timestamps for the key) and serves as a + // safety guard for rangeScan as well. + if (ts < minTs || ts > maxTs) { getNext() } else if (currentTs == -1L || currentTs == ts) { currentTs = ts @@ -757,6 +785,8 @@ class SymmetricHashJoinStateManagerV4( isInternal = true ) + // Returns an UnsafeRow backed by a reused projection buffer. Callers that need to + // hold the row beyond the immediate state store call must invoke copy() on the result. private def createKeyRow(key: UnsafeRow, timestamp: Long): UnsafeRow = { TimestampKeyStateEncoder.attachTimestamp( attachTimestampProjection, keySchemaWithTimestamp, key, timestamp) @@ -772,9 +802,66 @@ class SymmetricHashJoinStateManagerV4( case class EvictedKeysResult(key: UnsafeRow, timestamp: Long, numValues: Int) - // NOTE: This assumes we consume the whole iterator to trigger completion. - def scanEvictedKeys(endTimestamp: Long): Iterator[EvictedKeysResult] = { - val evictIterator = stateStore.iteratorWithMultiValues(colFamilyName) + private def defaultInternalRow(schema: StructType): InternalRow = { + InternalRow.fromSeq(schema.map(f => defaultValueForType(f.dataType))) + } + + private def defaultValueForType(dt: DataType): Any = dt match { + case BooleanType => false + case ByteType => 0.toByte + case ShortType => 0.toShort + case IntegerType | DateType => 0 + case LongType | TimestampType | TimestampNTZType => 0L + case FloatType => 0.0f + case DoubleType => 0.0 + case StringType => UTF8String.EMPTY_UTF8 + case BinaryType => Array.emptyByteArray + case st: StructType => defaultInternalRow(st) + case _ => null + } + + /** + * Build a scan boundary row for rangeScan. The TsWithKeyTypeStore uses + * TimestampAsPrefixKeyStateEncoder, which encodes the row as [timestamp][key_fields]. + * We need a full-schema row (not just the timestamp) because the encoder expects all + * key columns to be present. Default values are used for the key fields since only the + * timestamp matters for ordering in the prefix encoder. + */ + private def createScanBoundaryRow(timestamp: Long): UnsafeRow = { + val defaultKey = UnsafeProjection.create(keySchema) + .apply(defaultInternalRow(keySchema)) + createKeyRow(defaultKey, timestamp).copy() + } + + /** + * Scan keys eligible for eviction within the timestamp range. + * + * This assumes we consume the whole iterator to trigger completion. + * + * @param endTimestamp Inclusive upper bound: entries with timestamp <= endTimestamp are + * eligible for eviction. + * @param startTimestamp Exclusive lower bound: entries with timestamp <= startTimestamp + * are assumed to have been evicted already. The scan starts from startTimestamp + 1. + */ + def scanEvictedKeys( + endTimestamp: Long, + startTimestamp: Option[Long] = None): Iterator[EvictedKeysResult] = { + // rangeScanWithMultiValues: startKey is inclusive, endKey is exclusive. + // startTimestamp is exclusive (already evicted), so we seek from st + 1. + val startKeyRow = startTimestamp.flatMap { st => + if (st < Long.MaxValue) Some(createScanBoundaryRow(st + 1)) + else None + } + // endTimestamp is inclusive, so we use endTimestamp + 1 as the exclusive upper bound. + // When endTimestamp == Long.MaxValue we cannot add 1, so endKeyRow is None. This is + // safe because rangeScanWithMultiValues with no end key uses the column-family prefix + // as the upper bound, naturally scoping the scan within this column family. + val endKeyRow = if (endTimestamp < Long.MaxValue) { + Some(createScanBoundaryRow(endTimestamp + 1)) + } else { + None + } + val evictIterator = stateStore.rangeScanWithMultiValues(startKeyRow, endKeyRow, colFamilyName) new NextIterator[EvictedKeysResult]() { var currentKeyRow: UnsafeRow = null var currentEventTime: Long = -1L diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala index 169ab6f606dae..7e08c24e452f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala @@ -528,13 +528,19 @@ class IncrementalExecution( case j: StreamingSymmetricHashJoinExec => val iwLateEvents = inputWatermarkForLateEvents(j.stateInfo.get) val iwEviction = inputWatermarkForEviction(j.stateInfo.get) + // Only use the late-events watermark as the scan lower bound when a previous + // batch actually existed. In the very first batch the watermark propagation + // yields Some(0) even though no state has been evicted yet, which would + // incorrectly skip entries at timestamp 0. + val prevBatchLateEventsWm = + if (prevOffsetSeqMetadata.isDefined) iwLateEvents else None j.copy( eventTimeWatermarkForLateEvents = iwLateEvents, eventTimeWatermarkForEviction = iwEviction, stateWatermarkPredicates = StreamingSymmetricHashJoinHelper.getStateWatermarkPredicates( j.left.output, j.right.output, j.leftKeys, j.rightKeys, j.condition.full, - iwEviction, !allowMultipleStatefulOperators) + iwEviction, prevBatchLateEventsWm, !allowMultipleStatefulOperators) ) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala index 1042f01463b05..ae7dce78151a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala @@ -1105,6 +1105,67 @@ class SymmetricHashJoinStateManagerEventTimeInValueSuite } } + test("StreamingJoinStateManager V4 - getValuesInRange boundary edge cases") { + withJoinStateManager( + inputValueAttributes, joinKeyExpressions, stateFormatVersion = 4) { manager => + implicit val mgr = manager + + Seq(10, 20, 30, 40, 50).foreach(append(40, _)) + + // Exact boundary matches (both inclusive) + assert(getJoinedRowTimestamps(40, Some((10L, 10L))) === Seq(10)) + assert(getJoinedRowTimestamps(40, Some((50L, 50L))) === Seq(50)) + + // Range with Long.MinValue / Long.MaxValue + assert(getJoinedRowTimestamps(40, Some((Long.MinValue, 30L))) === Seq(10, 20, 30)) + assert(getJoinedRowTimestamps(40, Some((30L, Long.MaxValue))) === Seq(30, 40, 50)) + assert(getJoinedRowTimestamps(40, Some((Long.MinValue, Long.MaxValue))) === + Seq(10, 20, 30, 40, 50)) + + // Empty range (minTs > maxTs) + assert(getJoinedRowTimestamps(40, Some((50L, 10L))) === Seq.empty) + + // Range entirely outside stored timestamps + assert(getJoinedRowTimestamps(40, Some((100L, 200L))) === Seq.empty) + assert(getJoinedRowTimestamps(40, Some((1L, 5L))) === Seq.empty) + + // Full range via None (all entries) + assert(getJoinedRowTimestamps(40, None) === Seq(10, 20, 30, 40, 50)) + } + } + + test("StreamingJoinStateManager V4 - evictByTimestamp boundary edge cases") { + withJoinStateManager( + inputValueAttributes, joinKeyExpressions, stateFormatVersion = 4) { manager => + implicit val mgr = manager + val evictByTs = manager.asInstanceOf[SupportsEvictByTimestamp] + + // --- Range eviction with startTimestamp (exclusive) and endTimestamp (inclusive) --- + Seq(10, 20, 30, 40, 50).foreach(append(40, _)) + // startTimestamp=20 is exclusive, endTimestamp=40 is inclusive: evicts timestamps 30, 40 + assert(evictByTs.evictByTimestamp(40, Some(20)) === 2) + assert(get(40) === Seq(10, 20, 50)) + + // --- evictAndReturnByTimestamp returns evicted values --- + Seq(30, 40).foreach(append(40, _)) // restore evicted entries + val evictedValues = evictByTs.evictAndReturnByTimestamp(30, Some(10)) + .map(p => toValueInt(p.value)).toSeq.sorted + // startTimestamp=10 is exclusive, endTimestamp=30 is inclusive: timestamps 20 and 30 + assert(evictedValues === Seq(20, 30)) + assert(get(40) === Seq(10, 40, 50)) + + // --- start equals end: empty range (exclusive start = inclusive end) --- + // startTimestamp=40 (exclusive) and endTimestamp=40 (inclusive): range is empty + assert(evictByTs.evictByTimestamp(40, Some(40)) === 0) + assert(get(40) === Seq(10, 40, 50)) + + // --- start just below entry: evicts exactly that entry --- + // startTimestamp=39 (exclusive) means entries >= 40 are scanned; endTimestamp=40 inclusive + assert(evictByTs.evictByTimestamp(40, Some(39)) === 1) + assert(get(40) === Seq(10, 50)) + } + } + // V1 excluded: V1 converter does not persist matched flags (SPARK-26154) versionsInTest.filter(_ >= 2).foreach { ver => test(s"StreamingJoinStateManager V$ver - skipUpdatingMatchedFlag skips matched flag update") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinV4Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinV4Suite.scala index ef4615c1254f3..66406cc2afa7f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinV4Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinV4Suite.scala @@ -22,7 +22,8 @@ import org.scalatest.Tag import org.apache.spark.sql.execution.streaming.checkpointing.CheckpointFileManager import org.apache.spark.sql.execution.streaming.operators.stateful.join.StreamingSymmetricHashJoinExec -import org.apache.spark.sql.execution.streaming.runtime.MemoryStream +import org.apache.spark.sql.execution.streaming.operators.stateful.join.StreamingSymmetricHashJoinHelper.{JoinStateKeyWatermarkPredicate, JoinStateValueWatermarkPredicate} +import org.apache.spark.sql.execution.streaming.runtime.{MemoryStream, StreamExecution} import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -184,6 +185,72 @@ class StreamingInnerJoinV4Suite ) } } + + test("prevStateWatermark must be None in the first batch") { + // Regression test for the IncrementalExecution guard: in the first batch + // prevOffsetSeqMetadata is None, so eventTimeWatermarkForLateEvents must NOT + // be passed to getStateWatermarkPredicates. Without the guard the watermark + // propagation framework yields Some(0) even in batch 0, which would cause + // scanEvictedKeys to skip state entries at timestamp 0. + val input1 = MemoryStream[(Int, Int)] + val input2 = MemoryStream[(Int, Int)] + + val df1 = input1.toDF().toDF("key", "time") + .select($"key", timestamp_seconds($"time") as "leftTime", + ($"key" * 2) as "leftValue") + .withWatermark("leftTime", "10 seconds") + val df2 = input2.toDF().toDF("key", "time") + .select($"key", timestamp_seconds($"time") as "rightTime", + ($"key" * 3) as "rightValue") + .withWatermark("rightTime", "10 seconds") + + val joined = df1.join(df2, + df1("key") === df2("key") && + expr("leftTime >= rightTime - interval 5 seconds " + + "AND leftTime <= rightTime + interval 5 seconds"), + "inner") + .select(df1("key"), $"leftTime".cast("long"), $"leftValue", $"rightValue") + + def extractPrevWatermarks(q: StreamExecution): (Option[Long], Option[Long]) = { + val joinExec = q.lastExecution.executedPlan.collect { + case j: StreamingSymmetricHashJoinExec => j + }.head + val leftPrev = joinExec.stateWatermarkPredicates.left.flatMap { + case p: JoinStateKeyWatermarkPredicate => p.prevStateWatermark + case p: JoinStateValueWatermarkPredicate => p.prevStateWatermark + } + val rightPrev = joinExec.stateWatermarkPredicates.right.flatMap { + case p: JoinStateKeyWatermarkPredicate => p.prevStateWatermark + case p: JoinStateValueWatermarkPredicate => p.prevStateWatermark + } + (leftPrev, rightPrev) + } + + testStream(joined)( + // First batch: prevStateWatermark must be None on both sides. + MultiAddData(input1, (1, 5))(input2, (1, 5)), + CheckNewAnswer((1, 5, 2, 3)), + Execute { q => + val (leftPrev, rightPrev) = extractPrevWatermarks(q) + assert(leftPrev.isEmpty, + s"Left prevStateWatermark should be None in the first batch, got $leftPrev") + assert(rightPrev.isEmpty, + s"Right prevStateWatermark should be None in the first batch, got $rightPrev") + }, + + // Second batch: after watermark advances, prevStateWatermark should be set. + MultiAddData(input1, (2, 30))(input2, (2, 30)), + CheckNewAnswer((2, 30, 4, 6)), + Execute { q => + val (leftPrev, rightPrev) = extractPrevWatermarks(q) + assert(leftPrev.isDefined, + "Left prevStateWatermark should be defined after the first batch") + assert(rightPrev.isDefined, + "Right prevStateWatermark should be defined after the first batch") + }, + StopStream + ) + } } @SlowSQLTest From 44d5a4e833eb75456216ac8e2fec254bde5d7e72 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 14 Apr 2026 10:09:26 +0900 Subject: [PATCH 09/12] reflect review comments --- .../join/SymmetricHashJoinStateManager.scala | 38 ++++++++----------- .../SymmetricHashJoinStateManagerSuite.scala | 24 ++++++++++++ 2 files changed, 39 insertions(+), 23 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala index 611f548f44b88..a6ffdfec6a74c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala @@ -34,8 +34,7 @@ import org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOper import org.apache.spark.sql.execution.streaming.operators.stateful.join.StreamingSymmetricHashJoinHelper._ import org.apache.spark.sql.execution.streaming.state.{DropLastNFieldsStatePartitionKeyExtractor, KeyStateEncoderSpec, NoopStatePartitionKeyExtractor, NoPrefixKeyStateEncoderSpec, StatePartitionKeyExtractor, StateSchemaBroadcast, StateStore, StateStoreCheckpointInfo, StateStoreColFamilySchema, StateStoreConf, StateStoreErrors, StateStoreId, StateStoreMetrics, StateStoreProvider, StateStoreProviderId, SupportsFineGrainedReplay, TimestampAsPostfixKeyStateEncoderSpec, TimestampAsPrefixKeyStateEncoderSpec, TimestampKeyStateEncoder} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{BinaryType, BooleanType, ByteType, DataType, DateType, DoubleType, FloatType, IntegerType, LongType, NullType, ShortType, StringType, StructField, StructType, TimestampNTZType, TimestampType} -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.sql.types.{BooleanType, DataType, LongType, NullType, StructField, StructType} import org.apache.spark.util.NextIterator /** @@ -686,6 +685,7 @@ class SymmetricHashJoinStateManagerV4( } private var currentTs = -1L + private var pastUpperBound = false private val valueAndMatchPairs = scala.collection.mutable.ArrayBuffer[ValueAndMatchPair]() private def flushAccumulated(): GetValuesResult = { @@ -703,16 +703,16 @@ class SymmetricHashJoinStateManagerV4( @tailrec override protected def getNext(): GetValuesResult = { - if (!iter.hasNext) { + if (pastUpperBound || !iter.hasNext) { flushAccumulated() } else { val unsafeRowPair = iter.next() val ts = TimestampKeyStateEncoder.extractTimestamp(unsafeRowPair.key) - // Filter out entries outside [minTs, maxTs]. This is essential when using - // prefixScan (which returns all timestamps for the key) and serves as a - // safety guard for rangeScan as well. - if (ts < minTs || ts > maxTs) { + if (ts > maxTs) { + pastUpperBound = true + getNext() + } else if (ts < minTs) { getNext() } else if (currentTs == -1L || currentTs == ts) { currentTs = ts @@ -803,22 +803,16 @@ class SymmetricHashJoinStateManagerV4( case class EvictedKeysResult(key: UnsafeRow, timestamp: Long, numValues: Int) private def defaultInternalRow(schema: StructType): InternalRow = { - InternalRow.fromSeq(schema.map(f => defaultValueForType(f.dataType))) + InternalRow.fromSeq(schema.map(f => Literal.default(f.dataType).value)) } - private def defaultValueForType(dt: DataType): Any = dt match { - case BooleanType => false - case ByteType => 0.toByte - case ShortType => 0.toShort - case IntegerType | DateType => 0 - case LongType | TimestampType | TimestampNTZType => 0L - case FloatType => 0.0f - case DoubleType => 0.0 - case StringType => UTF8String.EMPTY_UTF8 - case BinaryType => Array.emptyByteArray - case st: StructType => defaultInternalRow(st) - case _ => null - } + /** + * Reusable default key row for scan boundary construction. Safe to reuse because + * createKeyRow only reads this row (via BoundReference evaluations) and writes to + * the projection's own internal buffer. + */ + private lazy val defaultKey: UnsafeRow = UnsafeProjection.create(keySchema) + .apply(defaultInternalRow(keySchema)) /** * Build a scan boundary row for rangeScan. The TsWithKeyTypeStore uses @@ -828,8 +822,6 @@ class SymmetricHashJoinStateManagerV4( * timestamp matters for ordering in the prefix encoder. */ private def createScanBoundaryRow(timestamp: Long): UnsafeRow = { - val defaultKey = UnsafeProjection.create(keySchema) - .apply(defaultInternalRow(keySchema)) createKeyRow(defaultKey, timestamp).copy() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala index ae7dce78151a0..c9b81b12b735e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala @@ -1163,6 +1163,30 @@ class SymmetricHashJoinStateManagerEventTimeInValueSuite // startTimestamp=39 (exclusive) means entries >= 40 are scanned; endTimestamp=40 inclusive assert(evictByTs.evictByTimestamp(40, Some(39)) === 1) assert(get(40) === Seq(10, 50)) + + // --- overflow boundary: endTimestamp = Long.MaxValue --- + // Restore entries for a clean slate + Seq(20, 30, 40).foreach(append(40, _)) + // endTimestamp=Long.MaxValue with no startTimestamp: evicts all entries + assert(evictByTs.evictByTimestamp(Long.MaxValue) === 5) + assert(get(40) === Seq.empty) + + // --- overflow boundary: startTimestamp = Some(Long.MinValue) --- + Seq(10, 20, 30).foreach(append(40, _)) + // startTimestamp=Long.MinValue (exclusive), endTimestamp=20 (inclusive): + // Long.MinValue is excluded per the contract (already evicted), so the scan + // starts from Long.MinValue + 1. Since no real entry has timestamp Long.MinValue, + // this effectively scans all entries up to endTimestamp. + assert(evictByTs.evictByTimestamp(20, Some(Long.MinValue)) === 2) + assert(get(40) === Seq(30)) + + // --- overflow boundary: startTimestamp = Some(Long.MaxValue) --- + Seq(10, 20).foreach(append(40, _)) + // startTimestamp=Long.MaxValue (exclusive) means everything <= Long.MaxValue was already + // evicted. Since startKeyRow falls back to None, endTimestamp=50 bounds the scan. + // All remaining entries (10, 20, 30) have timestamps <= 50, so they are evicted. + assert(evictByTs.evictByTimestamp(50, Some(Long.MaxValue)) === 3) + assert(get(40) === Seq.empty) } } From 56596030385ca0b28c5174cd4a079e814adf75da Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 14 Apr 2026 12:33:16 +0900 Subject: [PATCH 10/12] fix --- .../stateful/join/SymmetricHashJoinStateManager.scala | 11 ++++++++--- .../state/SymmetricHashJoinStateManagerSuite.scala | 7 +++---- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala index a6ffdfec6a74c..66f0ee5f80a56 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala @@ -838,11 +838,16 @@ class SymmetricHashJoinStateManagerV4( def scanEvictedKeys( endTimestamp: Long, startTimestamp: Option[Long] = None): Iterator[EvictedKeysResult] = { + // If startTimestamp == Long.MaxValue, everything has already been evicted; + // nothing can match, so return immediately. + if (startTimestamp.contains(Long.MaxValue)) { + return Iterator.empty + } + // rangeScanWithMultiValues: startKey is inclusive, endKey is exclusive. // startTimestamp is exclusive (already evicted), so we seek from st + 1. - val startKeyRow = startTimestamp.flatMap { st => - if (st < Long.MaxValue) Some(createScanBoundaryRow(st + 1)) - else None + val startKeyRow = startTimestamp.map { st => + createScanBoundaryRow(st + 1) } // endTimestamp is inclusive, so we use endTimestamp + 1 as the exclusive upper bound. // When endTimestamp == Long.MaxValue we cannot add 1, so endKeyRow is None. This is diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala index c9b81b12b735e..d63eba59ff8fd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala @@ -1183,10 +1183,9 @@ class SymmetricHashJoinStateManagerEventTimeInValueSuite // --- overflow boundary: startTimestamp = Some(Long.MaxValue) --- Seq(10, 20).foreach(append(40, _)) // startTimestamp=Long.MaxValue (exclusive) means everything <= Long.MaxValue was already - // evicted. Since startKeyRow falls back to None, endTimestamp=50 bounds the scan. - // All remaining entries (10, 20, 30) have timestamps <= 50, so they are evicted. - assert(evictByTs.evictByTimestamp(50, Some(Long.MaxValue)) === 3) - assert(get(40) === Seq.empty) + // evicted. Nothing can remain, so the scan returns an empty iterator immediately. + assert(evictByTs.evictByTimestamp(50, Some(Long.MaxValue)) === 0) + assert(get(40) === Seq(10, 20, 30)) } } From a0d896a739d40a5034e5fcfc046bed4dd81f0bb5 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 14 Apr 2026 13:17:50 +0900 Subject: [PATCH 11/12] simple assertion --- .../stateful/join/SymmetricHashJoinStateManager.scala | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala index 66f0ee5f80a56..d3845dd25a1f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala @@ -709,6 +709,11 @@ class SymmetricHashJoinStateManagerV4( val unsafeRowPair = iter.next() val ts = TimestampKeyStateEncoder.extractTimestamp(unsafeRowPair.key) + if (useRangeScan) { + assert(ts >= minTs && ts <= maxTs, + s"rangeScan returned unexpected timestamp $ts outside [$minTs, $maxTs]") + } + if (ts > maxTs) { pastUpperBound = true getNext() From f3a321372ff656c0abf7b54d1f16890000fa8770 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 14 Apr 2026 14:30:01 +0900 Subject: [PATCH 12/12] empty commit to retrigger CI