From f772b2c107ca194a45d0185ace285acf4520bbea Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Wed, 8 Apr 2026 04:06:02 +0900 Subject: [PATCH 1/5] [SPARK-XXXXX] Support stream-stream non-outer join in Update mode --- .../UnsupportedOperationChecker.scala | 18 ++++- .../analysis/UnsupportedOperationsSuite.scala | 19 +++-- .../join/StreamingSymmetricHashJoinExec.scala | 13 ++- .../runtime/IncrementalExecution.scala | 3 +- .../sql/streaming/StreamingJoinSuite.scala | 80 ++++++++++++++++++- 5 files changed, 120 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index 60b952b285e13..61e29a07fb0ea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -441,9 +441,21 @@ object UnsupportedOperationChecker extends Logging { } case j @ Join(left, right, joinType, condition, _) => - if (left.isStreaming && right.isStreaming && outputMode != InternalOutputModes.Append) { - throwError("Join between two streaming DataFrames/Datasets is not supported" + - s" in ${outputMode} output mode, only in Append output mode") + if (left.isStreaming && right.isStreaming) { + joinType match { + case LeftOuter | RightOuter | FullOuter => + if (outputMode != InternalOutputModes.Append) { + throwError(s"$joinType join between two streaming DataFrames/Datasets" + + s" is not supported in ${outputMode} output mode, only in Append output mode") + } + case _ => + if (outputMode != InternalOutputModes.Append && + outputMode != InternalOutputModes.Update) { + throwError(s"$joinType join between two streaming DataFrames/Datasets" + + s" is not supported in ${outputMode} output mode, only in Append and Update " + + "output modes") + } + } } joinType match { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index 425df0856a58a..40cc7b5f04d34 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -370,9 +370,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { testBinaryOperationInStreamingPlan( "inner join in update mode", _.join(_, joinType = Inner), - outputMode = Update, - streamStreamSupported = false, - expectedMsg = "is not supported in Update output mode") + outputMode = Update) // Full outer joins: stream-batch/batch-stream join are not allowed, // and stream-stream join is allowed 'conditionally' - see below check @@ -403,16 +401,25 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { streamStreamSupported = false, expectedMsg = "RightOuter join") - // Left outer, right outer, full outer, left semi joins - Seq(LeftOuter, RightOuter, FullOuter, LeftSemi).foreach { joinType => - // Update mode not allowed + // Left outer, right outer, full outer joins: Update mode not allowed + Seq(LeftOuter, RightOuter, FullOuter).foreach { joinType => assertNotSupportedInStreamingPlan( s"$joinType join with stream-stream relations and update mode", streamRelation.join(streamRelation, joinType = joinType, condition = Some(attribute === attribute)), OutputMode.Update(), Seq("is not supported in Update output mode")) + } + + // LeftSemi join: Update mode allowed (equivalent to Append mode for non-outer joins) + assertSupportedInStreamingPlan( + s"LeftSemi join with stream-stream relations and update mode", + streamRelation.join(streamRelation, joinType = LeftSemi, + condition = Some(attributeWithWatermark === attribute)), + OutputMode.Update()) + // Left outer, right outer, full outer, left semi joins + Seq(LeftOuter, RightOuter, FullOuter, LeftSemi).foreach { joinType => // Complete mode not allowed assertNotSupportedInStreamingPlan( s"$joinType join with stream-stream relations and complete mode", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala index 1c50e6802c323..ddb3e7c862b1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.StreamingJoinHelper import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, GenericInternalRow, JoinedRow, Literal, Predicate, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetric @@ -35,6 +36,7 @@ import org.apache.spark.sql.execution.streaming.operators.stateful.join.Streamin import org.apache.spark.sql.execution.streaming.operators.stateful.join.SymmetricHashJoinStateManager.KeyToValuePair import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.internal.{SessionState, SQLConf} +import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType import org.apache.spark.util.{CompletionIterator, SerializableConfiguration} @@ -142,7 +144,9 @@ case class StreamingSymmetricHashJoinExec( stateWatermarkPredicates: JoinStateWatermarkPredicates, stateFormatVersion: Int, left: SparkPlan, - right: SparkPlan) extends BinaryExecNode with StateStoreWriter with SchemaValidationUtils { + right: SparkPlan, + outputMode: Option[OutputMode] = None) + extends BinaryExecNode with StateStoreWriter with SchemaValidationUtils { def this( leftKeys: Seq[Expression], @@ -184,6 +188,13 @@ case class StreamingSymmetricHashJoinExec( joinType == LeftSemi, errorMessageForJoinType) + outputMode.foreach { mode => + if (mode == InternalOutputModes.Update) { + require(joinType == Inner || joinType == LeftSemi, + s"Update output mode is not supported for stream-stream $joinType join") + } + } + // The assertion against join keys is same as hash join for batch query. require(leftKeys.length == rightKeys.length && leftKeys.map(_.dataType) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala index 169ab6f606dae..8165117a028d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala @@ -411,7 +411,8 @@ class IncrementalExecution( j.copy( stateInfo = Some(nextStatefulOperationStateInfo()), eventTimeWatermarkForLateEvents = None, - eventTimeWatermarkForEviction = None + eventTimeWatermarkForEviction = None, + outputMode = Some(outputMode) ) case l: StreamingGlobalLimitExec => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index 7dc228feaff81..e1cb2e3ca97a2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -30,7 +30,7 @@ import org.scalatest.{BeforeAndAfter, Tag} import org.apache.spark.SparkUnsupportedOperationException import org.apache.spark.scheduler.ExecutorCacheTaskLocation -import org.apache.spark.sql.{DataFrame, Row, SparkSession} +import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SparkSession} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions @@ -352,6 +352,28 @@ abstract class StreamingInnerJoinBase extends StreamingJoinSuite { ) } + // Stream-stream non-outer join produces the same behavior between Append mode and Update mode. + // We only run a sanity test here rather than replicating the full Append mode test suite. + test("stream stream inner join with Update mode on non-time column") { + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + + val df1 = input1.toDF().select($"value" as "key", ($"value" * 2) as "leftValue") + val df2 = input2.toDF().select($"value" as "key", ($"value" * 3) as "rightValue") + val joined = df1.join(df2, "key") + + testStream(joined, OutputMode.Update())( + AddData(input1, 1), + CheckAnswer(), + AddData(input2, 1, 10), + CheckNewAnswer((1, 2, 3)), + AddData(input1, 10), + CheckNewAnswer((10, 20, 30)), + AddData(input2, 1), + CheckNewAnswer((1, 2, 3)) + ) + } + test("stream stream inner join on windows - without watermark") { val input1 = MemoryStream[Int] val input2 = MemoryStream[Int] @@ -669,7 +691,7 @@ abstract class StreamingInnerJoinBase extends StreamingJoinSuite { assert(query.lastExecution.executedPlan.collect { case j @ StreamingSymmetricHashJoinExec(_, _, _, _, _, _, _, _, _, ShuffleExchangeExec(opA: HashPartitioning, _, _, _), - ShuffleExchangeExec(opB: HashPartitioning, _, _, _)) + ShuffleExchangeExec(opB: HashPartitioning, _, _, _), _) if partitionExpressionsColumns(opA.expressions) === Seq("a", "b") && partitionExpressionsColumns(opB.expressions) === Seq("a", "b") && opA.numPartitions == numPartitions && opB.numPartitions == numPartitions => j @@ -1242,6 +1264,25 @@ abstract class StreamingOuterJoinBase extends StreamingJoinSuite { import testImplicits._ import org.apache.spark.sql.functions._ + Seq("left_outer", "right_outer").foreach { joinType => + test(s"stream-stream $joinType join does not support Update mode") { + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + + val df1 = input1.toDF().select($"value" as "key", ($"value" * 2) as "leftValue") + val df2 = input2.toDF().select($"value" as "key", ($"value" * 3) as "rightValue") + val joined = df1.join(df2, Seq("key"), joinType) + + val e = intercept[AnalysisException] { + testStream(joined, OutputMode.Update())( + AddData(input1, 1), + CheckAnswer() + ) + } + assert(e.getMessage.contains("is not supported in Update output mode")) + } + } + test("left outer early state exclusion on left") { withTempDir { checkpointDir => val (leftInput, rightInput, joined) = setupWindowedJoinWithLeftCondition("left_outer") @@ -1954,6 +1995,25 @@ abstract class StreamingOuterJoinSuite extends StreamingOuterJoinBase { @SlowSQLTest abstract class StreamingFullOuterJoinBase extends StreamingJoinSuite { + import testImplicits._ + + test("stream-stream full outer join does not support Update mode") { + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + + val df1 = input1.toDF().select($"value" as "key", ($"value" * 2) as "leftValue") + val df2 = input2.toDF().select($"value" as "key", ($"value" * 3) as "rightValue") + val joined = df1.join(df2, Seq("key"), "full_outer") + + val e = intercept[AnalysisException] { + testStream(joined, OutputMode.Update())( + AddData(input1, 1), + CheckAnswer() + ) + } + assert(e.getMessage.contains("is not supported in Update output mode")) + } + test("windowed full outer join") { withTempDir { checkpointDir => val (leftInput, rightInput, joined) = setupWindowedJoin("full_outer") @@ -2176,6 +2236,22 @@ abstract class StreamingLeftSemiJoinBase extends StreamingJoinSuite { import testImplicits._ + // Stream-stream non-outer join produces the same behavior between Append mode and Update mode. + // We only run a sanity test here rather than replicating the full Append mode test suite. + test("windowed left semi join with Update mode") { + withTempDir { checkpointDir => + val (leftInput, rightInput, joined) = setupWindowedJoin("left_semi") + + testStream(joined, OutputMode.Update())( + StartStream(checkpointLocation = checkpointDir.getCanonicalPath), + MultiAddData(leftInput, 1, 2, 3, 4, 5)(rightInput, 3, 4, 5, 6, 7), + CheckNewAnswer(Row(3, 10, 6), Row(4, 10, 8), Row(5, 10, 10)), + MultiAddData(leftInput, 21)(rightInput, 22), + CheckNewAnswer() + ) + } + } + test("windowed left semi join") { withTempDir { checkpointDir => val (leftInput, rightInput, joined) = setupWindowedJoin("left_semi") From fab4a3afecc5a3ad4edfa8ed7acb94922589cad6 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Thu, 9 Apr 2026 05:50:47 +0900 Subject: [PATCH 2/5] Add tests for ability of multiple stateful operator in update mode with stream-stream inner join --- .../analysis/UnsupportedOperationsSuite.scala | 35 +++++++++ .../MultiStatefulOperatorsSuite.scala | 74 +++++++++++++++++++ 2 files changed, 109 insertions(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index 40cc7b5f04d34..dc429c87346db 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -678,6 +678,21 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { outputMode = Append) } + assertPassOnGlobalWatermarkLimit( + "streaming aggregation after stream-stream inner join in Update mode", + streamRelation.join(streamRelation, joinType = Inner, + condition = Some(attributeWithWatermark === attribute)) + .groupBy("a")(count("*")), + outputMode = Update) + + assertFailOnGlobalWatermarkLimit( + "streaming aggregation on both sides followed by stream-stream inner join in Update mode", + streamRelation.groupBy("a")(count("*")).join( + streamRelation.groupBy("a")(count("*")), + joinType = Inner, + condition = Some(attributeWithWatermark === attribute)), + outputMode = Update) + // Cogroup: only batch-batch is allowed testBinaryOperationInStreamingPlan( "cogroup", @@ -858,6 +873,26 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { null, att, att, Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null, Deduplicate(Seq(attribute), streamRelation)), outputMode = Append) + + Seq(Append, Update).foreach { outputMode => + assertPassOnGlobalWatermarkLimit( + s"stream-stream inner join with deduplicate on both sides " + + s"(with event-time) in ${outputMode} mode", + Deduplicate(Seq(attributeWithWatermark), streamRelation).join( + Deduplicate(Seq(attributeWithWatermark), streamRelation), + joinType = Inner, + condition = Some(attributeWithWatermark === attribute)), + outputMode = outputMode) + + assertPassOnGlobalWatermarkLimit( + s"stream-stream inner join with deduplicate on both sides " + + s"(without event-time) in ${outputMode} mode", + Deduplicate(Seq(attribute), streamRelation).join( + Deduplicate(Seq(attribute), streamRelation), + joinType = Inner, + condition = Some(attributeWithWatermark === attribute)), + outputMode = outputMode) + } } /* diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MultiStatefulOperatorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MultiStatefulOperatorsSuite.scala index 21bf370f82a5f..f24d61af1cabd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MultiStatefulOperatorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MultiStatefulOperatorsSuite.scala @@ -934,6 +934,80 @@ class MultiStatefulOperatorsSuite ) } + test("dedup on both sides -> stream-stream inner join, update mode") { + val input1 = MemoryStream[Int] + val inputDF1 = input1.toDF() + .withColumnRenamed("value", "value1") + .withColumn("eventTime1", timestamp_seconds($"value1")) + .withWatermark("eventTime1", "10 seconds") + .dropDuplicates("value1", "eventTime1") + + val input2 = MemoryStream[Int] + val inputDF2 = input2.toDF() + .withColumnRenamed("value", "value2") + .withColumn("eventTime2", timestamp_seconds($"value2")) + .withWatermark("eventTime2", "10 seconds") + .dropDuplicates("value2", "eventTime2") + + val stream = inputDF1.join(inputDF2, expr("eventTime1 = eventTime2"), "inner") + .select($"value1", $"value2") + + testStream(stream, OutputMode.Update())( + // Send data with duplicates: input1 has duplicate 1, input2 has duplicate 2 + MultiAddData(input1, 1, 2, 3, 1)(input2, 1, 2, 3, 2), + // dedup1: filters second 1, passes 1, 2, 3 + // dedup2: filters second 2, passes 1, 2, 3 + // join: (1, 1), (2, 2), (3, 3) + CheckNewAnswer((1, 1), (2, 2), (3, 3)), + + // Send overlapping values: 1, 2 on left are dups from batch 1; 2, 3 on right are dups + MultiAddData(input1, 1, 2, 4)(input2, 2, 3, 4), + // dedup1: filters 1, 2 (already seen), passes only 4 + // dedup2: filters 2, 3 (already seen), passes only 4 + // join: only (4, 4) matches + CheckNewAnswer((4, 4)) + ) + } + + test("stream-stream inner join -> window agg, update mode") { + val input1 = MemoryStream[Int] + val inputDF1 = input1.toDF() + .withColumnRenamed("value", "value1") + .withColumn("eventTime1", timestamp_seconds($"value1")) + .withWatermark("eventTime1", "0 seconds") + + val input2 = MemoryStream[Int] + val inputDF2 = input2.toDF() + .withColumnRenamed("value", "value2") + .withColumn("eventTime2", timestamp_seconds($"value2")) + .withWatermark("eventTime2", "0 seconds") + + val stream = inputDF1.join(inputDF2, expr("eventTime1 = eventTime2"), "inner") + .groupBy(window($"eventTime1", "5 seconds").as("window")) + .agg(count("*").as("count")) + .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) + + testStream(stream, OutputMode.Update())( + MultiAddData(input1, 1, 2)(input2, 1, 2), + // join output: (1, 1), (2, 2) + // agg: [0, 5) count = 2 + CheckNewAnswer((0, 2)), + + // Add more data to the same window [0, 5) + MultiAddData(input1, 3, 4)(input2, 3, 4), + // join output: (3, 3), (4, 4) + // agg: [0, 5) count = 2 + 2 = 4 + // Update mode re-emits the window with updated count + CheckNewAnswer((0, 4)), + + MultiAddData(input1, 5 to 8: _*)(input2, 5 to 8: _*), + // join output: (5, 5), (6, 6), (7, 7), (8, 8) + // agg: [5, 10) count = 4 + // Only the new/updated window is emitted + CheckNewAnswer((5, 4)) + ) + } + private def assertNumStateRows(numTotalRows: Seq[Long]): AssertOnQuery = AssertOnQuery { q => q.processAllAvailable() val progressWithData = q.recentProgress.lastOption.get From 7160f2b9fed0acf0dd1b74f24b035103f5caa4b0 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Thu, 9 Apr 2026 10:34:30 +0900 Subject: [PATCH 3/5] Review comments --- .../UnsupportedOperationChecker.scala | 5 +- .../MultiStatefulOperatorsSuite.scala | 60 +- .../sql/streaming/StreamingJoinSuite.scala | 1041 ++++++++--------- 3 files changed, 547 insertions(+), 559 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index 61e29a07fb0ea..1883042f1c18a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -443,18 +443,21 @@ object UnsupportedOperationChecker extends Logging { case j @ Join(left, right, joinType, condition, _) => if (left.isStreaming && right.isStreaming) { joinType match { + // The behavior for unmatched rows in outer joins with update mode + // hasn't been defined yet. case LeftOuter | RightOuter | FullOuter => if (outputMode != InternalOutputModes.Append) { throwError(s"$joinType join between two streaming DataFrames/Datasets" + s" is not supported in ${outputMode} output mode, only in Append output mode") } - case _ => + case _: InnerLike | LeftSemi => if (outputMode != InternalOutputModes.Append && outputMode != InternalOutputModes.Update) { throwError(s"$joinType join between two streaming DataFrames/Datasets" + s" is not supported in ${outputMode} output mode, only in Append and Update " + "output modes") } + case _ => // we will throw an error in the next pattern match } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MultiStatefulOperatorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MultiStatefulOperatorsSuite.scala index f24d61af1cabd..7fbeea180f86b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MultiStatefulOperatorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MultiStatefulOperatorsSuite.scala @@ -934,39 +934,33 @@ class MultiStatefulOperatorsSuite ) } - test("dedup on both sides -> stream-stream inner join, update mode") { - val input1 = MemoryStream[Int] - val inputDF1 = input1.toDF() - .withColumnRenamed("value", "value1") - .withColumn("eventTime1", timestamp_seconds($"value1")) - .withWatermark("eventTime1", "10 seconds") - .dropDuplicates("value1", "eventTime1") - - val input2 = MemoryStream[Int] - val inputDF2 = input2.toDF() - .withColumnRenamed("value", "value2") - .withColumn("eventTime2", timestamp_seconds($"value2")) - .withWatermark("eventTime2", "10 seconds") - .dropDuplicates("value2", "eventTime2") - - val stream = inputDF1.join(inputDF2, expr("eventTime1 = eventTime2"), "inner") - .select($"value1", $"value2") - - testStream(stream, OutputMode.Update())( - // Send data with duplicates: input1 has duplicate 1, input2 has duplicate 2 - MultiAddData(input1, 1, 2, 3, 1)(input2, 1, 2, 3, 2), - // dedup1: filters second 1, passes 1, 2, 3 - // dedup2: filters second 2, passes 1, 2, 3 - // join: (1, 1), (2, 2), (3, 3) - CheckNewAnswer((1, 1), (2, 2), (3, 3)), - - // Send overlapping values: 1, 2 on left are dups from batch 1; 2, 3 on right are dups - MultiAddData(input1, 1, 2, 4)(input2, 2, 3, 4), - // dedup1: filters 1, 2 (already seen), passes only 4 - // dedup2: filters 2, 3 (already seen), passes only 4 - // join: only (4, 4) matches - CheckNewAnswer((4, 4)) - ) + Seq(OutputMode.Append(), OutputMode.Update()).foreach { outputMode => + test(s"dedup on both sides -> stream-stream inner join, ${outputMode} mode") { + val input1 = MemoryStream[Int] + val inputDF1 = input1.toDF() + .withColumnRenamed("value", "value1") + .withColumn("eventTime1", timestamp_seconds($"value1")) + .withWatermark("eventTime1", "10 seconds") + .dropDuplicates("value1", "eventTime1") + + val input2 = MemoryStream[Int] + val inputDF2 = input2.toDF() + .withColumnRenamed("value", "value2") + .withColumn("eventTime2", timestamp_seconds($"value2")) + .withWatermark("eventTime2", "10 seconds") + .dropDuplicates("value2", "eventTime2") + + val stream = inputDF1.join(inputDF2, expr("eventTime1 = eventTime2"), "inner") + .select($"value1", $"value2") + + testStream(stream, outputMode)( + MultiAddData(input1, 1, 2, 3, 1)(input2, 1, 2, 3, 2), + CheckNewAnswer((1, 1), (2, 2), (3, 3)), + + MultiAddData(input1, 1, 2, 4)(input2, 2, 3, 4), + CheckNewAnswer((4, 4)) + ) + } } test("stream-stream inner join -> window agg, update mode") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index e1cb2e3ca97a2..4e23e64423b06 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -323,242 +323,231 @@ abstract class StreamingJoinSuite abstract class StreamingInnerJoinBase extends StreamingJoinSuite { import testImplicits._ - test("stream stream inner join on non-time column") { - val input1 = MemoryStream[Int] - val input2 = MemoryStream[Int] - - val df1 = input1.toDF().select($"value" as "key", ($"value" * 2) as "leftValue") - val df2 = input2.toDF().select($"value" as "key", ($"value" * 3) as "rightValue") - val joined = df1.join(df2, "key") - - testStream(joined)( - AddData(input1, 1), - CheckAnswer(), - AddData(input2, 1, 10), // 1 arrived on input1 first, then input2, should join - CheckNewAnswer((1, 2, 3)), - AddData(input1, 10), // 10 arrived on input2 first, then input1, should join - CheckNewAnswer((10, 20, 30)), - AddData(input2, 1), // another 1 in input2 should join with 1 input1 - CheckNewAnswer((1, 2, 3)), - StopStream, - StartStream(), - AddData(input1, 1), // multiple 1s should be kept in state causing multiple (1, 2, 3) - CheckNewAnswer((1, 2, 3), (1, 2, 3)), - StopStream, - StartStream(), - AddData(input1, 100), - AddData(input2, 100), - CheckNewAnswer((100, 200, 300)) - ) - } - - // Stream-stream non-outer join produces the same behavior between Append mode and Update mode. - // We only run a sanity test here rather than replicating the full Append mode test suite. - test("stream stream inner join with Update mode on non-time column") { - val input1 = MemoryStream[Int] - val input2 = MemoryStream[Int] + Seq(OutputMode.Append(), OutputMode.Update()).foreach { outputMode => + test(s"stream stream inner join on non-time column - $outputMode") { + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] - val df1 = input1.toDF().select($"value" as "key", ($"value" * 2) as "leftValue") - val df2 = input2.toDF().select($"value" as "key", ($"value" * 3) as "rightValue") - val joined = df1.join(df2, "key") + val df1 = input1.toDF().select($"value" as "key", ($"value" * 2) as "leftValue") + val df2 = input2.toDF().select($"value" as "key", ($"value" * 3) as "rightValue") + val joined = df1.join(df2, "key") - testStream(joined, OutputMode.Update())( - AddData(input1, 1), - CheckAnswer(), - AddData(input2, 1, 10), - CheckNewAnswer((1, 2, 3)), - AddData(input1, 10), - CheckNewAnswer((10, 20, 30)), - AddData(input2, 1), - CheckNewAnswer((1, 2, 3)) - ) + testStream(joined, outputMode)( + AddData(input1, 1), + CheckAnswer(), + AddData(input2, 1, 10), // 1 arrived on input1 first, then input2, should join + CheckNewAnswer((1, 2, 3)), + AddData(input1, 10), // 10 arrived on input2 first, then input1, should join + CheckNewAnswer((10, 20, 30)), + AddData(input2, 1), // another 1 in input2 should join with 1 input1 + CheckNewAnswer((1, 2, 3)), + StopStream, + StartStream(), + AddData(input1, 1), // multiple 1s should be kept in state causing multiple (1, 2, 3) + CheckNewAnswer((1, 2, 3), (1, 2, 3)), + StopStream, + StartStream(), + AddData(input1, 100), + AddData(input2, 100), + CheckNewAnswer((100, 200, 300)) + ) + } } - test("stream stream inner join on windows - without watermark") { - val input1 = MemoryStream[Int] - val input2 = MemoryStream[Int] + Seq(OutputMode.Append(), OutputMode.Update()).foreach { outputMode => + test(s"stream stream inner join on windows - without watermark - $outputMode") { + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] - val df1 = input1.toDF() - .select($"value" as "key", timestamp_seconds($"value") as "timestamp", - ($"value" * 2) as "leftValue") - .select($"key", window($"timestamp", "10 second"), $"leftValue") + val df1 = input1.toDF() + .select($"value" as "key", timestamp_seconds($"value") as "timestamp", + ($"value" * 2) as "leftValue") + .select($"key", window($"timestamp", "10 second"), $"leftValue") - val df2 = input2.toDF() - .select($"value" as "key", timestamp_seconds($"value") as "timestamp", - ($"value" * 3) as "rightValue") - .select($"key", window($"timestamp", "10 second"), $"rightValue") + val df2 = input2.toDF() + .select($"value" as "key", timestamp_seconds($"value") as "timestamp", + ($"value" * 3) as "rightValue") + .select($"key", window($"timestamp", "10 second"), $"rightValue") - val joined = df1.join(df2, Seq("key", "window")) - .select($"key", $"window.end".cast("long"), $"leftValue", $"rightValue") + val joined = df1.join(df2, Seq("key", "window")) + .select($"key", $"window.end".cast("long"), $"leftValue", $"rightValue") - testStream(joined)( - AddData(input1, 1), - CheckNewAnswer(), - AddData(input2, 1), - CheckNewAnswer((1, 10, 2, 3)), - StopStream, - StartStream(), - AddData(input1, 25), - CheckNewAnswer(), - StopStream, - StartStream(), - AddData(input2, 25), - CheckNewAnswer((25, 30, 50, 75)), - AddData(input1, 1), - CheckNewAnswer((1, 10, 2, 3)), // State for 1 still around as there is no watermark - StopStream, - StartStream(), - AddData(input1, 5), - CheckNewAnswer(), - AddData(input2, 5), - CheckNewAnswer((5, 10, 10, 15)) // No filter by any watermark - ) + testStream(joined, outputMode)( + AddData(input1, 1), + CheckNewAnswer(), + AddData(input2, 1), + CheckNewAnswer((1, 10, 2, 3)), + StopStream, + StartStream(), + AddData(input1, 25), + CheckNewAnswer(), + StopStream, + StartStream(), + AddData(input2, 25), + CheckNewAnswer((25, 30, 50, 75)), + AddData(input1, 1), + CheckNewAnswer((1, 10, 2, 3)), // State for 1 still around as there is no watermark + StopStream, + StartStream(), + AddData(input1, 5), + CheckNewAnswer(), + AddData(input2, 5), + CheckNewAnswer((5, 10, 10, 15)) // No filter by any watermark + ) + } } - test("stream stream inner join with time range - with watermark - one side condition") { - import org.apache.spark.sql.functions._ + Seq(OutputMode.Append(), OutputMode.Update()).foreach { outputMode => + test("stream stream inner join with time range - with watermark" + + s" - one side condition - $outputMode") { + import org.apache.spark.sql.functions._ - val leftInput = MemoryStream[(Int, Int)] - val rightInput = MemoryStream[(Int, Int)] + val leftInput = MemoryStream[(Int, Int)] + val rightInput = MemoryStream[(Int, Int)] - val df1 = leftInput.toDF().toDF("leftKey", "time") - .select($"leftKey", timestamp_seconds($"time") as "leftTime", - ($"leftKey" * 2) as "leftValue") - .withWatermark("leftTime", "10 seconds") + val df1 = leftInput.toDF().toDF("leftKey", "time") + .select($"leftKey", timestamp_seconds($"time") as "leftTime", + ($"leftKey" * 2) as "leftValue") + .withWatermark("leftTime", "10 seconds") - val df2 = rightInput.toDF().toDF("rightKey", "time") - .select($"rightKey", timestamp_seconds($"time") as "rightTime", - ($"rightKey" * 3) as "rightValue") - .withWatermark("rightTime", "10 seconds") + val df2 = rightInput.toDF().toDF("rightKey", "time") + .select($"rightKey", timestamp_seconds($"time") as "rightTime", + ($"rightKey" * 3) as "rightValue") + .withWatermark("rightTime", "10 seconds") - val joined = - df1.join(df2, expr("leftKey = rightKey AND leftTime < rightTime - interval 5 seconds")) - .select($"leftKey", $"leftTime".cast("int"), $"rightTime".cast("int")) + val joined = + df1.join(df2, expr("leftKey = rightKey AND leftTime < rightTime - interval 5 seconds")) + .select($"leftKey", $"leftTime".cast("int"), $"rightTime".cast("int")) - testStream(joined)( - AddData(leftInput, (1, 5)), - CheckAnswer(), - AddData(rightInput, (1, 11)), - CheckNewAnswer((1, 5, 11)), - AddData(rightInput, (1, 10)), - CheckNewAnswer(), // no match as leftTime 5 is not < rightTime 10 - 5 - assertNumStateRows(total = 3, updated = 3), - - // Increase event time watermark to 20s by adding data with time = 30s on both inputs - AddData(leftInput, (1, 3), (1, 30)), - CheckNewAnswer((1, 3, 10), (1, 3, 11)), - assertNumStateRows(total = 5, updated = 2), - AddData(rightInput, (0, 30)), - CheckNewAnswer(), + testStream(joined, outputMode)( + AddData(leftInput, (1, 5)), + CheckAnswer(), + AddData(rightInput, (1, 11)), + CheckNewAnswer((1, 5, 11)), + AddData(rightInput, (1, 10)), + CheckNewAnswer(), // no match as leftTime 5 is not < rightTime 10 - 5 + assertNumStateRows(total = 3, updated = 3), - // event time watermark: max event time - 10 ==> 30 - 10 = 20 - // so left side going to only receive data where leftTime > 20 - // right side state constraint: 20 < leftTime < rightTime - 5 ==> rightTime > 25 - // right state where rightTime <= 25 will be cleared, (1, 11) and (1, 10) removed - assertNumStateRows(total = 4, updated = 1), - - // New data to right input should match with left side (1, 3) and (1, 5), as left state should - // not be cleared. But rows rightTime <= 20 should be filtered due to event time watermark and - // state rows with rightTime <= 25 should be removed from state. - // (1, 20) ==> filtered by event time watermark = 20 - // (1, 21) ==> passed filter, matched with left (1, 3) and (1, 5), not added to state - // as 21 < state watermark = 25 - // (1, 28) ==> passed filter, matched with left (1, 3) and (1, 5), added to state - AddData(rightInput, (1, 20), (1, 21), (1, 28)), - CheckNewAnswer((1, 3, 21), (1, 5, 21), (1, 3, 28), (1, 5, 28)), - assertNumStateRows(total = 5, updated = 1, droppedByWatermark = 1), - - // New data to left input with leftTime <= 20 should be filtered due to event time watermark - AddData(leftInput, (1, 20), (1, 21)), - CheckNewAnswer((1, 21, 28)), - assertNumStateRows(total = 6, updated = 1, droppedByWatermark = 1) - ) - } + // Increase event time watermark to 20s by adding data with time = 30s on both inputs + AddData(leftInput, (1, 3), (1, 30)), + CheckNewAnswer((1, 3, 10), (1, 3, 11)), + assertNumStateRows(total = 5, updated = 2), + AddData(rightInput, (0, 30)), + CheckNewAnswer(), - test("stream stream inner join with time range - with watermark - two side conditions") { - import org.apache.spark.sql.functions._ + // event time watermark: max event time - 10 ==> 30 - 10 = 20 + // so left side going to only receive data where leftTime > 20 + // right side state constraint: 20 < leftTime < rightTime - 5 ==> rightTime > 25 + // right state where rightTime <= 25 will be cleared, (1, 11) and (1, 10) removed + assertNumStateRows(total = 4, updated = 1), + + // New data to right input should match with left side (1, 3) and (1, 5), as left state + // should not be cleared. But rows rightTime <= 20 should be filtered due to event time + // watermark and state rows with rightTime <= 25 should be removed from state. + // (1, 20) ==> filtered by event time watermark = 20 + // (1, 21) ==> passed filter, matched with left (1, 3) and (1, 5), not added to state + // as 21 < state watermark = 25 + // (1, 28) ==> passed filter, matched with left (1, 3) and (1, 5), added to state + AddData(rightInput, (1, 20), (1, 21), (1, 28)), + CheckNewAnswer((1, 3, 21), (1, 5, 21), (1, 3, 28), (1, 5, 28)), + assertNumStateRows(total = 5, updated = 1, droppedByWatermark = 1), + + // New data to left input with leftTime <= 20 should be filtered due to event time watermark + AddData(leftInput, (1, 20), (1, 21)), + CheckNewAnswer((1, 21, 28)), + assertNumStateRows(total = 6, updated = 1, droppedByWatermark = 1) + ) + } + } - val leftInput = MemoryStream[(Int, Int)] - val rightInput = MemoryStream[(Int, Int)] + Seq(OutputMode.Append(), OutputMode.Update()).foreach { outputMode => + test("stream stream inner join with time range - with watermark" + + s" - two side conditions - $outputMode") { + import org.apache.spark.sql.functions._ - val df1 = leftInput.toDF().toDF("leftKey", "time") - .select($"leftKey", timestamp_seconds($"time") as "leftTime", - ($"leftKey" * 2) as "leftValue") - .withWatermark("leftTime", "20 seconds") + val leftInput = MemoryStream[(Int, Int)] + val rightInput = MemoryStream[(Int, Int)] - val df2 = rightInput.toDF().toDF("rightKey", "time") - .select($"rightKey", timestamp_seconds($"time") as "rightTime", - ($"rightKey" * 3) as "rightValue") - .withWatermark("rightTime", "30 seconds") - - val condition = expr( - "leftKey = rightKey AND " + - "leftTime BETWEEN rightTime - interval 10 seconds AND rightTime + interval 5 seconds") - - // This translates to leftTime <= rightTime + 5 seconds AND leftTime >= rightTime - 10 seconds - // So given leftTime, rightTime has to be BETWEEN leftTime - 5 seconds AND leftTime + 10 seconds - // - // =============== * ======================== * ============================== * ==> leftTime - // | | | - // |<---- 5s -->|<------ 10s ------>| |<------ 10s ------>|<---- 5s -->| - // | | | - // == * ============================== * =========>============== * ===============> rightTime - // - // E.g. - // if rightTime = 60, then it matches only leftTime = [50, 65] - // if leftTime = 20, then it match only with rightTime = [15, 30] - // - // State value predicates - // left side: - // values allowed: leftTime >= rightTime - 10s ==> leftTime > eventTimeWatermark - 10 - // drop state where leftTime < eventTime - 10 - // right side: - // values allowed: rightTime >= leftTime - 5s ==> rightTime > eventTimeWatermark - 5 - // drop state where rightTime < eventTime - 5 + val df1 = leftInput.toDF().toDF("leftKey", "time") + .select($"leftKey", timestamp_seconds($"time") as "leftTime", + ($"leftKey" * 2) as "leftValue") + .withWatermark("leftTime", "20 seconds") - val joined = - df1.join(df2, condition).select($"leftKey", $"leftTime".cast("int"), - $"rightTime".cast("int")) + val df2 = rightInput.toDF().toDF("rightKey", "time") + .select($"rightKey", timestamp_seconds($"time") as "rightTime", + ($"rightKey" * 3) as "rightValue") + .withWatermark("rightTime", "30 seconds") - testStream(joined)( - // If leftTime = 20, then it match only with rightTime = [15, 30] - AddData(leftInput, (1, 20)), - CheckAnswer(), - AddData(rightInput, (1, 14), (1, 15), (1, 25), (1, 26), (1, 30), (1, 31)), - CheckNewAnswer((1, 20, 15), (1, 20, 25), (1, 20, 26), (1, 20, 30)), - assertNumStateRows(total = 7, updated = 7), + val condition = expr( + "leftKey = rightKey AND " + + "leftTime BETWEEN rightTime - interval 10 seconds AND rightTime + interval 5 seconds") - // If rightTime = 60, then it matches only leftTime = [50, 65] - AddData(rightInput, (1, 60)), - CheckNewAnswer(), // matches with nothing on the left - AddData(leftInput, (1, 49), (1, 50), (1, 65), (1, 66)), - CheckNewAnswer((1, 50, 60), (1, 65, 60)), - - // Event time watermark = min(left: 66 - delay 20 = 46, right: 60 - delay 30 = 30) = 30 - // Left state value watermark = 30 - 10 = slightly less than 20 (since condition has <=) - // Should drop < 20 from left, i.e., none - // Right state value watermark = 30 - 5 = slightly less than 25 (since condition has <=) - // Should drop < 25 from the right, i.e., 14 and 15 - assertNumStateRows(total = 10, updated = 5), // 12 - 2 removed - - AddData(leftInput, (1, 30), (1, 31)), // 30 should not be processed or added to state - CheckNewAnswer((1, 31, 26), (1, 31, 30), (1, 31, 31)), - assertNumStateRows(total = 11, updated = 1, droppedByWatermark = 1), // only 31 added - - // Advance the watermark - AddData(rightInput, (1, 80)), - CheckNewAnswer(), - // Event time watermark = min(left: 66 - delay 20 = 46, right: 80 - delay 30 = 50) = 46 - // Left state value watermark = 46 - 10 = slightly less than 36 (since condition has <=) - // Should drop < 36 from left, i.e., 20, 31 (30 was not added) - // Right state value watermark = 46 - 5 = slightly less than 41 (since condition has <=) - // Should drop < 41 from the right, i.e., 25, 26, 30, 31 - assertNumStateRows(total = 6, updated = 1), // 12 - 6 removed - - AddData(rightInput, (1, 46), (1, 50)), // 46 should not be processed or added to state - CheckNewAnswer((1, 49, 50), (1, 50, 50)), - assertNumStateRows(total = 7, updated = 1, droppedByWatermark = 1) // 50 added - ) + // This translates to leftTime <= rightTime + 5 seconds AND + // leftTime >= rightTime - 10 seconds. So given leftTime, rightTime has to be + // BETWEEN leftTime - 5 seconds AND leftTime + 10 seconds + // + // ============ * ==================== * ======================== * ==> leftTime + // | | | + // |<--- 5s --->|<----- 10s ----->| |<----- 10s ----->|<- 5s->| + // | | | + // * ============================= * =================== * ========> rightTime + // + // E.g. + // if rightTime = 60, then it matches only leftTime = [50, 65] + // if leftTime = 20, then it match only with rightTime = [15, 30] + // + // State value predicates + // left side: + // values allowed: leftTime >= rightTime - 10s ==> leftTime > eventTimeWatermark - 10 + // drop state where leftTime < eventTime - 10 + // right side: + // values allowed: rightTime >= leftTime - 5s ==> rightTime > eventTimeWatermark - 5 + // drop state where rightTime < eventTime - 5 + + val joined = + df1.join(df2, condition).select($"leftKey", $"leftTime".cast("int"), + $"rightTime".cast("int")) + + testStream(joined, outputMode)( + // If leftTime = 20, then it match only with rightTime = [15, 30] + AddData(leftInput, (1, 20)), + CheckAnswer(), + AddData(rightInput, (1, 14), (1, 15), (1, 25), (1, 26), (1, 30), (1, 31)), + CheckNewAnswer((1, 20, 15), (1, 20, 25), (1, 20, 26), (1, 20, 30)), + assertNumStateRows(total = 7, updated = 7), + + // If rightTime = 60, then it matches only leftTime = [50, 65] + AddData(rightInput, (1, 60)), + CheckNewAnswer(), // matches with nothing on the left + AddData(leftInput, (1, 49), (1, 50), (1, 65), (1, 66)), + CheckNewAnswer((1, 50, 60), (1, 65, 60)), + + // Event time watermark = min(left: 66 - delay 20 = 46, right: 60 - delay 30 = 30) = 30 + // Left state value watermark = 30 - 10 = slightly less than 20 (since condition has <=) + // Should drop < 20 from left, i.e., none + // Right state value watermark = 30 - 5 = slightly less than 25 (since condition has <=) + // Should drop < 25 from the right, i.e., 14 and 15 + assertNumStateRows(total = 10, updated = 5), // 12 - 2 removed + + AddData(leftInput, (1, 30), (1, 31)), // 30 should not be processed or added to state + CheckNewAnswer((1, 31, 26), (1, 31, 30), (1, 31, 31)), + assertNumStateRows(total = 11, updated = 1, droppedByWatermark = 1), // only 31 added + + // Advance the watermark + AddData(rightInput, (1, 80)), + CheckNewAnswer(), + // Event time watermark = min(left: 66 - delay 20 = 46, right: 80 - delay 30 = 50) = 46 + // Left state value watermark = 46 - 10 = slightly less than 36 (since condition has <=) + // Should drop < 36 from left, i.e., 20, 31 (30 was not added) + // Right state value watermark = 46 - 5 = slightly less than 41 (since condition has <=) + // Should drop < 41 from the right, i.e., 25, 26, 30, 31 + assertNumStateRows(total = 6, updated = 1), // 12 - 6 removed + + AddData(rightInput, (1, 46), (1, 50)), // 46 should not be processed or added to state + CheckNewAnswer((1, 49, 50), (1, 50, 50)), + assertNumStateRows(total = 7, updated = 1, droppedByWatermark = 1) // 50 added + ) + } } testQuietly("stream stream inner join without equality predicate") { @@ -578,25 +567,27 @@ abstract class StreamingInnerJoinBase extends StreamingJoinSuite { assert(e.toString.contains("Stream-stream join without equality predicate is not supported")) } - test("stream stream self join") { - val input = MemoryStream[Int] - val df = input.toDF() - val join = - df.select($"value" % 5 as "key", $"value").join( - df.select($"value" % 5 as "key", $"value"), "key") - - testStream(join)( - AddData(input, 1, 2), - CheckAnswer((1, 1, 1), (2, 2, 2)), - StopStream, - StartStream(), - AddData(input, 3, 6), - /* - (1, 1) (1, 1) - (2, 2) x (2, 2) = (1, 1, 1), (1, 1, 6), (2, 2, 2), (1, 6, 1), (1, 6, 6) - (1, 6) (1, 6) - */ - CheckAnswer((3, 3, 3), (1, 1, 1), (1, 1, 6), (2, 2, 2), (1, 6, 1), (1, 6, 6))) + Seq(OutputMode.Append(), OutputMode.Update()).foreach { outputMode => + test(s"stream stream self join - $outputMode") { + val input = MemoryStream[Int] + val df = input.toDF() + val join = + df.select($"value" % 5 as "key", $"value").join( + df.select($"value" % 5 as "key", $"value"), "key") + + testStream(join, outputMode)( + AddData(input, 1, 2), + CheckAnswer((1, 1, 1), (2, 2, 2)), + StopStream, + StartStream(), + AddData(input, 3, 6), + /* + (1, 1) (1, 1) + (2, 2) x (2, 2) = (1, 1, 1), (1, 1, 6), (2, 2, 2), (1, 6, 1), (1, 6, 6) + (1, 6) (1, 6) + */ + CheckAnswer((3, 3, 3), (1, 1, 1), (1, 1, 6), (2, 2, 2), (1, 6, 1), (1, 6, 6))) + } } test("locality preferences of StateStoreAwareZippedRDD") { @@ -897,43 +888,45 @@ abstract class StreamingInnerJoinBase extends StreamingJoinSuite { ) } - test("joining non-nullable left join key with nullable right join key") { - val input1 = MemoryStream[Int] - val input2 = MemoryStream[JInteger] + Seq(OutputMode.Append(), OutputMode.Update()).foreach { outputMode => + test(s"joining non-nullable left join key with nullable right join key - $outputMode") { + val input1 = MemoryStream[Int] + val input2 = MemoryStream[JInteger] - val joined = testForJoinKeyNullability(input1.toDF(), input2.toDF()) - testStream(joined)( - AddData(input1, 1, 5), - AddData(input2, JInteger.valueOf(1), JInteger.valueOf(5), JInteger.valueOf(10), null), - CheckNewAnswer(Row(1, 1, 2, 3), Row(5, 5, 10, 15)) - ) - } + val joined = testForJoinKeyNullability(input1.toDF(), input2.toDF()) + testStream(joined, outputMode)( + AddData(input1, 1, 5), + AddData(input2, JInteger.valueOf(1), JInteger.valueOf(5), JInteger.valueOf(10), null), + CheckNewAnswer(Row(1, 1, 2, 3), Row(5, 5, 10, 15)) + ) + } - test("joining nullable left join key with non-nullable right join key") { - val input1 = MemoryStream[JInteger] - val input2 = MemoryStream[Int] + test(s"joining nullable left join key with non-nullable right join key - $outputMode") { + val input1 = MemoryStream[JInteger] + val input2 = MemoryStream[Int] - val joined = testForJoinKeyNullability(input1.toDF(), input2.toDF()) - testStream(joined)( - AddData(input1, JInteger.valueOf(1), JInteger.valueOf(5), JInteger.valueOf(10), null), - AddData(input2, 1, 5), - CheckNewAnswer(Row(1, 1, 2, 3), Row(5, 5, 10, 15)) - ) - } + val joined = testForJoinKeyNullability(input1.toDF(), input2.toDF()) + testStream(joined, outputMode)( + AddData(input1, JInteger.valueOf(1), JInteger.valueOf(5), JInteger.valueOf(10), null), + AddData(input2, 1, 5), + CheckNewAnswer(Row(1, 1, 2, 3), Row(5, 5, 10, 15)) + ) + } - test("joining nullable left join key with nullable right join key") { - val input1 = MemoryStream[JInteger] - val input2 = MemoryStream[JInteger] + test(s"joining nullable left join key with nullable right join key - $outputMode") { + val input1 = MemoryStream[JInteger] + val input2 = MemoryStream[JInteger] - val joined = testForJoinKeyNullability(input1.toDF(), input2.toDF()) - testStream(joined)( - AddData(input1, JInteger.valueOf(1), JInteger.valueOf(5), JInteger.valueOf(10), null), - AddData(input2, JInteger.valueOf(1), JInteger.valueOf(5), null), - CheckNewAnswer( - Row(JInteger.valueOf(1), JInteger.valueOf(1), JInteger.valueOf(2), JInteger.valueOf(3)), - Row(JInteger.valueOf(5), JInteger.valueOf(5), JInteger.valueOf(10), JInteger.valueOf(15)), - Row(null, null, null, null)) - ) + val joined = testForJoinKeyNullability(input1.toDF(), input2.toDF()) + testStream(joined, outputMode)( + AddData(input1, JInteger.valueOf(1), JInteger.valueOf(5), JInteger.valueOf(10), null), + AddData(input2, JInteger.valueOf(1), JInteger.valueOf(5), null), + CheckNewAnswer( + Row(JInteger.valueOf(1), JInteger.valueOf(1), JInteger.valueOf(2), JInteger.valueOf(3)), + Row(JInteger.valueOf(5), JInteger.valueOf(5), JInteger.valueOf(10), JInteger.valueOf(15)), + Row(null, null, null, null)) + ) + } } testWithVirtualColumnFamilyJoins( @@ -1070,53 +1063,57 @@ abstract class StreamingInnerJoinBase extends StreamingJoinSuite { abstract class StreamingInnerJoinSuite extends StreamingInnerJoinBase { import testImplicits._ - test("stream stream inner join on windows - with watermark") { - val input1 = MemoryStream[Int] - val input2 = MemoryStream[Int] + Seq(OutputMode.Append(), OutputMode.Update()).foreach { outputMode => + test(s"stream stream inner join on windows - with watermark - $outputMode") { + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] - val df1 = input1.toDF() - .select($"value" as "key", timestamp_seconds($"value") as "timestamp", - ($"value" * 2) as "leftValue") - .withWatermark("timestamp", "10 seconds") - .select($"key", window($"timestamp", "10 second"), $"leftValue") + val df1 = input1.toDF() + .select($"value" as "key", timestamp_seconds($"value") as "timestamp", + ($"value" * 2) as "leftValue") + .withWatermark("timestamp", "10 seconds") + .select($"key", window($"timestamp", "10 second"), $"leftValue") - val df2 = input2.toDF() - .select($"value" as "key", timestamp_seconds($"value") as "timestamp", - ($"value" * 3) as "rightValue") - .select($"key", window($"timestamp", "10 second"), $"rightValue") + val df2 = input2.toDF() + .select($"value" as "key", timestamp_seconds($"value") as "timestamp", + ($"value" * 3) as "rightValue") + .select($"key", window($"timestamp", "10 second"), $"rightValue") - val joined = df1.join(df2, Seq("key", "window")) - .select($"key", $"window.end".cast("long"), $"leftValue", $"rightValue") + val joined = df1.join(df2, Seq("key", "window")) + .select($"key", $"window.end".cast("long"), $"leftValue", $"rightValue") - testStream(joined)( - AddData(input1, 1), - CheckAnswer(), - assertNumStateRows(total = 1, updated = 1), + testStream(joined, outputMode)( + AddData(input1, 1), + CheckAnswer(), + assertNumStateRows(total = 1, updated = 1), - AddData(input2, 1), - CheckAnswer((1, 10, 2, 3)), - assertNumStateRows(total = 2, updated = 1), - StopStream, - StartStream(), + AddData(input2, 1), + CheckAnswer((1, 10, 2, 3)), + assertNumStateRows(total = 2, updated = 1), + StopStream, + StartStream(), - AddData(input1, 25), - CheckNewAnswer(), // watermark = 15, no-data-batch should remove 2 rows having window=[0,10] - assertNumStateRows(total = 1, updated = 1), + AddData(input1, 25), + // watermark = 15, no-data-batch should remove 2 rows + // having window=[0,10] + CheckNewAnswer(), + assertNumStateRows(total = 1, updated = 1), - AddData(input2, 25), - CheckNewAnswer((25, 30, 50, 75)), - assertNumStateRows(total = 2, updated = 1), - StopStream, - StartStream(), + AddData(input2, 25), + CheckNewAnswer((25, 30, 50, 75)), + assertNumStateRows(total = 2, updated = 1), + StopStream, + StartStream(), - AddData(input2, 1), - CheckNewAnswer(), // Should not join as < 15 removed - assertNumStateRows(total = 2, updated = 0), // row not add as 1 < state key watermark = 15 + AddData(input2, 1), + CheckNewAnswer(), // Should not join as < 15 removed + assertNumStateRows(total = 2, updated = 0), // row not add as 1 < state key watermark = 15 - AddData(input1, 5), - CheckNewAnswer(), // Same reason as above - assertNumStateRows(total = 2, updated = 0, droppedByWatermark = 1) - ) + AddData(input1, 5), + CheckNewAnswer(), // Same reason as above + assertNumStateRows(total = 2, updated = 0, droppedByWatermark = 1) + ) + } } test("SPARK-35896: metrics in StateOperatorProgress are output correctly") { @@ -2236,259 +2233,253 @@ abstract class StreamingLeftSemiJoinBase extends StreamingJoinSuite { import testImplicits._ - // Stream-stream non-outer join produces the same behavior between Append mode and Update mode. - // We only run a sanity test here rather than replicating the full Append mode test suite. - test("windowed left semi join with Update mode") { - withTempDir { checkpointDir => - val (leftInput, rightInput, joined) = setupWindowedJoin("left_semi") + Seq(OutputMode.Append(), OutputMode.Update()).foreach { outputMode => + test(s"windowed left semi join - $outputMode") { + withTempDir { checkpointDir => + val (leftInput, rightInput, joined) = setupWindowedJoin("left_semi") - testStream(joined, OutputMode.Update())( - StartStream(checkpointLocation = checkpointDir.getCanonicalPath), - MultiAddData(leftInput, 1, 2, 3, 4, 5)(rightInput, 3, 4, 5, 6, 7), - CheckNewAnswer(Row(3, 10, 6), Row(4, 10, 8), Row(5, 10, 10)), - MultiAddData(leftInput, 21)(rightInput, 22), - CheckNewAnswer() - ) + testStream(joined, outputMode)( + StartStream(checkpointLocation = checkpointDir.getCanonicalPath), + MultiAddData(leftInput, 1, 2, 3, 4, 5)(rightInput, 3, 4, 5, 6, 7), + CheckNewAnswer(Row(3, 10, 6), Row(4, 10, 8), Row(5, 10, 10)), + // states + // left: 1, 2 (left 3, 4, 5 matched right in the same batch, emitted without storing) + // right: 3, 4, 5, 6, 7 + assertNumStateRows( + total = Seq(7), updated = Seq(7), + droppedByWatermark = Seq(0), removed = Some(Seq(0))), + MultiAddData(leftInput, 21)(rightInput, 22), + // Watermark = 11, should remove rows having window=[0,10]. + CheckNewAnswer(), + // states + // left: 21 + // right: 22 + // + // states evicted + // left: 1, 2 (below watermark) + // right: 3, 4, 5, 6, 7 (below watermark) + assertNumStateRows( + total = Seq(2), updated = Seq(2), + droppedByWatermark = Seq(0), removed = Some(Seq(7))), + StopStream, + // Restart join query from the same checkpoint + StartStream(checkpointLocation = checkpointDir.getCanonicalPath), + AddData(leftInput, 22), + CheckNewAnswer(Row(22, 30, 44)), + // Unlike inner/outer joins, given left input row matches with right input row, + // we don't buffer the matched left input row to the state store. + // + // states + // left: 21 + // right: 22 + assertNumStateRows( + total = Seq(2), updated = Seq(0), + droppedByWatermark = Seq(0), removed = Some(Seq(0))), + StopStream, + // Restart the query from the same checkpoint + StartStream(checkpointLocation = checkpointDir.getCanonicalPath), + AddData(leftInput, 1), + // Row not add as 1 < state key watermark = 12. + CheckNewAnswer(), + // states + // left: 21 + // right: 22 + assertNumStateRows( + total = Seq(2), updated = Seq(0), + droppedByWatermark = Seq(1), removed = Some(Seq(0))), + AddData(rightInput, 5), + // Row not add as 5 < state key watermark = 12. + CheckNewAnswer(), + // states + // left: 21 + // right: 22 + assertNumStateRows( + total = Seq(2), updated = Seq(0), + droppedByWatermark = Seq(1), removed = Some(Seq(0))) + ) + } } } - test("windowed left semi join") { - withTempDir { checkpointDir => - val (leftInput, rightInput, joined) = setupWindowedJoin("left_semi") + Seq(OutputMode.Append(), OutputMode.Update()).foreach { outputMode => + test(s"left semi early state exclusion on left - $outputMode") { + val (leftInput, rightInput, joined) = setupWindowedJoinWithLeftCondition("left_semi") - testStream(joined)( - StartStream(checkpointLocation = checkpointDir.getCanonicalPath), - MultiAddData(leftInput, 1, 2, 3, 4, 5)(rightInput, 3, 4, 5, 6, 7), - CheckNewAnswer(Row(3, 10, 6), Row(4, 10, 8), Row(5, 10, 10)), + testStream(joined, outputMode)( + MultiAddData(leftInput, 1, 2, 3)(rightInput, 3, 4, 5), + // The left rows with leftValue <= 4 should not generate their semi join rows and + // not get added to the state. + CheckNewAnswer(Row(3, 10, 6)), // states - // left: 1, 2 (left 3, 4, 5 matched right in the same batch, emitted without storing) - // right: 3, 4, 5, 6, 7 + // left: (none - left 3 matched right in the same batch, emitted without storing) + // right: 3, 4, 5 assertNumStateRows( - total = Seq(7), updated = Seq(7), + total = Seq(3), updated = Seq(3), droppedByWatermark = Seq(0), removed = Some(Seq(0))), - MultiAddData(leftInput, 21)(rightInput, 22), - // Watermark = 11, should remove rows having window=[0,10]. + // We shouldn't get more semi join rows when the watermark advances. + MultiAddData(leftInput, 20)(rightInput, 21), CheckNewAnswer(), // states - // left: 21 - // right: 22 + // left: 20 + // right: 21 // // states evicted - // left: 1, 2 (below watermark) - // right: 3, 4, 5, 6, 7 (below watermark) + // right: 3, 4, 5 (below watermark) assertNumStateRows( total = Seq(2), updated = Seq(2), - droppedByWatermark = Seq(0), removed = Some(Seq(7))), - StopStream, - // Restart join query from the same checkpoint - StartStream(checkpointLocation = checkpointDir.getCanonicalPath), - AddData(leftInput, 22), - CheckNewAnswer(Row(22, 30, 44)), - // Unlike inner/outer joins, given left input row matches with right input row, - // we don't buffer the matched left input row to the state store. - // + droppedByWatermark = Seq(0), removed = Some(Seq(3))), + AddData(rightInput, 20), + CheckNewAnswer((20, 30, 40)), // states - // left: 21 - // right: 22 + // left: (empty -- 20 removed after matching right 20 via getJoinedRowsAndRemoveMatched) + // right: 21, 20 assertNumStateRows( - total = Seq(2), updated = Seq(0), + total = Seq(2), updated = Seq(1), + droppedByWatermark = Seq(0), removed = Some(Seq(1))) + ) + } + } + + Seq(OutputMode.Append(), OutputMode.Update()).foreach { outputMode => + test(s"left semi early state exclusion on right - $outputMode") { + val (leftInput, rightInput, joined) = setupWindowedJoinWithRightCondition("left_semi") + + testStream(joined, outputMode)( + MultiAddData(leftInput, 3, 4, 5)(rightInput, 1, 2, 3), + // The right rows with rightValue <= 7 should never be added to the state. + // The right row with rightValue = 9 > 7, hence joined and added to state. + CheckNewAnswer(Row(3, 10, 6)), + // states + // left: 4, 5 (left 3 matched right in the same batch, emitted without storing) + // right: 3 + assertNumStateRows( + total = Seq(3), updated = Seq(3), droppedByWatermark = Seq(0), removed = Some(Seq(0))), - StopStream, - // Restart the query from the same checkpoint - StartStream(checkpointLocation = checkpointDir.getCanonicalPath), - AddData(leftInput, 1), - // Row not add as 1 < state key watermark = 12. + // We shouldn't get more semi join rows when the watermark advances. + MultiAddData(leftInput, 20)(rightInput, 21), CheckNewAnswer(), // states - // left: 21 - // right: 22 + // left: 20 + // right: 21 + // + // states evicted + // left: 4, 5 (below watermark) + // right: 3 (below watermark) assertNumStateRows( - total = Seq(2), updated = Seq(0), - droppedByWatermark = Seq(1), removed = Some(Seq(0))), - AddData(rightInput, 5), - // Row not add as 5 < state key watermark = 12. - CheckNewAnswer(), + total = Seq(2), updated = Seq(2), + droppedByWatermark = Seq(0), removed = Some(Seq(3))), + AddData(rightInput, 20), + CheckNewAnswer((20, 30, 40)), // states - // left: 21 - // right: 22 + // left: (empty -- 20 removed after matching right 20 via getJoinedRowsAndRemoveMatched) + // right: 21, 20 assertNumStateRows( - total = Seq(2), updated = Seq(0), - droppedByWatermark = Seq(1), removed = Some(Seq(0))) + total = Seq(2), updated = Seq(1), + droppedByWatermark = Seq(0), removed = Some(Seq(1))) ) } } - test("left semi early state exclusion on left") { - val (leftInput, rightInput, joined) = setupWindowedJoinWithLeftCondition("left_semi") + Seq(OutputMode.Append(), OutputMode.Update()).foreach { outputMode => + test(s"left semi join with watermark range condition - $outputMode") { + val (leftInput, rightInput, joined) = setupJoinWithRangeCondition("left_semi") - testStream(joined)( - MultiAddData(leftInput, 1, 2, 3)(rightInput, 3, 4, 5), - // The left rows with leftValue <= 4 should not generate their semi join rows and - // not get added to the state. - CheckNewAnswer(Row(3, 10, 6)), - // states - // left: (none - left 3 matched right in the same batch, emitted without storing) - // right: 3, 4, 5 - assertNumStateRows( - total = Seq(3), updated = Seq(3), - droppedByWatermark = Seq(0), removed = Some(Seq(0))), - // We shouldn't get more semi join rows when the watermark advances. - MultiAddData(leftInput, 20)(rightInput, 21), - CheckNewAnswer(), - // states - // left: 20 - // right: 21 - // - // states evicted - // right: 3, 4, 5 (below watermark) - assertNumStateRows( - total = Seq(2), updated = Seq(2), - droppedByWatermark = Seq(0), removed = Some(Seq(3))), - AddData(rightInput, 20), - CheckNewAnswer((20, 30, 40)), - // states - // left: (empty -- 20 removed after matching right 20 via getJoinedRowsAndRemoveMatched) - // right: 21, 20 - assertNumStateRows( - total = Seq(2), updated = Seq(1), - droppedByWatermark = Seq(0), removed = Some(Seq(1))) - ) - } - - test("left semi early state exclusion on right") { - val (leftInput, rightInput, joined) = setupWindowedJoinWithRightCondition("left_semi") - - testStream(joined)( - MultiAddData(leftInput, 3, 4, 5)(rightInput, 1, 2, 3), - // The right rows with rightValue <= 7 should never be added to the state. - // The right row with rightValue = 9 > 7, hence joined and added to state. - CheckNewAnswer(Row(3, 10, 6)), - // states - // left: 4, 5 (left 3 matched right in the same batch, emitted without storing) - // right: 3 - assertNumStateRows( - total = Seq(3), updated = Seq(3), - droppedByWatermark = Seq(0), removed = Some(Seq(0))), - // We shouldn't get more semi join rows when the watermark advances. - MultiAddData(leftInput, 20)(rightInput, 21), - CheckNewAnswer(), - // states - // left: 20 - // right: 21 - // - // states evicted - // left: 4, 5 (below watermark) - // right: 3 (below watermark) - assertNumStateRows( - total = Seq(2), updated = Seq(2), - droppedByWatermark = Seq(0), removed = Some(Seq(3))), - AddData(rightInput, 20), - CheckNewAnswer((20, 30, 40)), - // states - // left: (empty -- 20 removed after matching right 20 via getJoinedRowsAndRemoveMatched) - // right: 21, 20 - assertNumStateRows( - total = Seq(2), updated = Seq(1), - droppedByWatermark = Seq(0), removed = Some(Seq(1))) - ) - } - - test("left semi join with watermark range condition") { - val (leftInput, rightInput, joined) = setupJoinWithRangeCondition("left_semi") - - testStream(joined)( - AddData(leftInput, (1, 5), (3, 5)), - CheckNewAnswer(), - // states - // left: (1, 5), (3, 5) - // right: nothing - assertNumStateRows( - total = Seq(2), updated = Seq(2), - droppedByWatermark = Seq(0), removed = Some(Seq(0))), - AddData(rightInput, (1, 10), (2, 5)), - // Match left row in the state. Matched left row (1, 5) is immediately removed from state - // via getJoinedRowsAndRemoveMatched. - CheckNewAnswer((1, 5)), - // states - // left: (3, 5) -- (1, 5) removed after matching right (1, 10) - // right: (1, 10), (2, 5) - assertNumStateRows( - total = Seq(3), updated = Seq(2), - droppedByWatermark = Seq(0), removed = Some(Seq(1))), - AddData(rightInput, (1, 9)), - // No match as left row (1, 5) was already removed from state. - CheckNewAnswer(), - // states - // left: (3, 5) - // right: (1, 10), (2, 5), (1, 9) - assertNumStateRows( - total = Seq(4), updated = Seq(1), - droppedByWatermark = Seq(0), removed = Some(Seq(0))), - // Increase event time watermark to 20s by adding data with time = 30s on both inputs. - AddData(leftInput, (1, 7), (1, 30)), - CheckNewAnswer((1, 7)), - // states - // left: (3, 5), (1, 30) - // right: (1, 10), (2, 5), (1, 9) - assertNumStateRows( - total = Seq(5), updated = Seq(1), - droppedByWatermark = Seq(0), removed = Some(Seq(0))), - // Watermark = 30 - 10 = 20, no matched row. - AddData(rightInput, (0, 30)), - CheckNewAnswer(), - // states - // left: (1, 30) - // right: (0, 30) - // - // states evicted - // left: (3, 5) (below watermark = 20) - // right: (1, 10), (2, 5), (1, 9) (below watermark = 20) - assertNumStateRows( - total = Seq(2), updated = Seq(1), - droppedByWatermark = Seq(0), removed = Some(Seq(4))) - ) + testStream(joined, outputMode)( + AddData(leftInput, (1, 5), (3, 5)), + CheckNewAnswer(), + // states + // left: (1, 5), (3, 5) + // right: nothing + assertNumStateRows( + total = Seq(2), updated = Seq(2), + droppedByWatermark = Seq(0), removed = Some(Seq(0))), + AddData(rightInput, (1, 10), (2, 5)), + // Match left row in the state. Matched left row (1, 5) is immediately removed from state + // via getJoinedRowsAndRemoveMatched. + CheckNewAnswer((1, 5)), + // states + // left: (3, 5) -- (1, 5) removed after matching right (1, 10) + // right: (1, 10), (2, 5) + assertNumStateRows( + total = Seq(3), updated = Seq(2), + droppedByWatermark = Seq(0), removed = Some(Seq(1))), + AddData(rightInput, (1, 9)), + // No match as left row (1, 5) was already removed from state. + CheckNewAnswer(), + // states + // left: (3, 5) + // right: (1, 10), (2, 5), (1, 9) + assertNumStateRows( + total = Seq(4), updated = Seq(1), + droppedByWatermark = Seq(0), removed = Some(Seq(0))), + // Increase event time watermark to 20s by adding data with time = 30s on both inputs. + AddData(leftInput, (1, 7), (1, 30)), + CheckNewAnswer((1, 7)), + // states + // left: (3, 5), (1, 30) + // right: (1, 10), (2, 5), (1, 9) + assertNumStateRows( + total = Seq(5), updated = Seq(1), + droppedByWatermark = Seq(0), removed = Some(Seq(0))), + // Watermark = 30 - 10 = 20, no matched row. + AddData(rightInput, (0, 30)), + CheckNewAnswer(), + // states + // left: (1, 30) + // right: (0, 30) + // + // states evicted + // left: (3, 5) (below watermark = 20) + // right: (1, 10), (2, 5), (1, 9) (below watermark = 20) + assertNumStateRows( + total = Seq(2), updated = Seq(1), + droppedByWatermark = Seq(0), removed = Some(Seq(4))) + ) + } } - test("self left semi join") { - val (inputStream, query) = setupSelfJoin("left_semi") + Seq(OutputMode.Append(), OutputMode.Update()).foreach { outputMode => + test(s"self left semi join - $outputMode") { + val (inputStream, query) = setupSelfJoin("left_semi") - testStream(query)( - AddData(inputStream, (1, 1L), (2, 2L), (3, 3L), (4, 4L), (5, 5L)), - CheckNewAnswer((2, 2), (4, 4)), - // batch 1 - global watermark = 0 - // states - // left: (none - left 2, 4 matched right in the same batch, emitted without storing) - // (left rows with value % 2 != 0 is filtered per [[PushPredicateThroughJoin]]) - // right: (2, 2L), (4, 4L) - // (right rows with value % 2 != 0 is filtered per [[PushPredicateThroughJoin]]) - assertNumStateRows( - total = Seq(2), updated = Seq(2), - droppedByWatermark = Seq(0), removed = Some(Seq(0))), - AddData(inputStream, (6, 6L), (7, 7L), (8, 8L), (9, 9L), (10, 10L)), - CheckNewAnswer((6, 6), (8, 8), (10, 10)), - // batch 2 - global watermark = 5 - // states - // left: (none - left 6, 8, 10 matched right in the same batch, emitted without storing) - // right: (6, 6L), (8, 8L), (10, 10L) - // - // states evicted - // right: (2, 2L), (4, 4L) - assertNumStateRows( - total = Seq(3), updated = Seq(3), - droppedByWatermark = Seq(0), removed = Some(Seq(2))), - AddData(inputStream, (11, 11L), (12, 12L), (13, 13L), (14, 14L), (15, 15L)), - CheckNewAnswer((12, 12), (14, 14)), - // batch 3 - global watermark = 9 - // states - // left: (none - left 12, 14 matched right in the same batch, emitted without storing) - // right: (10, 10L), (12, 12L), (14, 14L) - // - // states evicted - // right: (6, 6L), (8, 8L) - assertNumStateRows( - total = Seq(3), updated = Seq(2), - droppedByWatermark = Seq(0), removed = Some(Seq(2))) - ) + testStream(query, outputMode)( + AddData(inputStream, (1, 1L), (2, 2L), (3, 3L), (4, 4L), (5, 5L)), + CheckNewAnswer((2, 2), (4, 4)), + // batch 1 - global watermark = 0 + // states + // left: (none - left 2, 4 matched right in the same batch, emitted without storing) + // (left rows with value % 2 != 0 is filtered per [[PushPredicateThroughJoin]]) + // right: (2, 2L), (4, 4L) + // (right rows with value % 2 != 0 is filtered per [[PushPredicateThroughJoin]]) + assertNumStateRows( + total = Seq(2), updated = Seq(2), + droppedByWatermark = Seq(0), removed = Some(Seq(0))), + AddData(inputStream, (6, 6L), (7, 7L), (8, 8L), (9, 9L), (10, 10L)), + CheckNewAnswer((6, 6), (8, 8), (10, 10)), + // batch 2 - global watermark = 5 + // states + // left: (none - left 6, 8, 10 matched right in the same batch, emitted without storing) + // right: (6, 6L), (8, 8L), (10, 10L) + // + // states evicted + // right: (2, 2L), (4, 4L) + assertNumStateRows( + total = Seq(3), updated = Seq(3), + droppedByWatermark = Seq(0), removed = Some(Seq(2))), + AddData(inputStream, (11, 11L), (12, 12L), (13, 13L), (14, 14L), (15, 15L)), + CheckNewAnswer((12, 12), (14, 14)), + // batch 3 - global watermark = 9 + // states + // left: (none - left 12, 14 matched right in the same batch, emitted without storing) + // right: (10, 10L), (12, 12L), (14, 14L) + // + // states evicted + // right: (6, 6L), (8, 8L) + assertNumStateRows( + total = Seq(3), updated = Seq(2), + droppedByWatermark = Seq(0), removed = Some(Seq(2))) + ) + } } } From 808550e122837b12cffe9500bbdc0840796352ae Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Thu, 9 Apr 2026 10:56:56 +0900 Subject: [PATCH 4/5] smaller code diff --- .../MultiStatefulOperatorsSuite.scala | 59 +- .../sql/streaming/StreamingJoinSuite.scala | 1019 ++++++++--------- 2 files changed, 533 insertions(+), 545 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MultiStatefulOperatorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MultiStatefulOperatorsSuite.scala index 7fbeea180f86b..977078a649e98 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MultiStatefulOperatorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MultiStatefulOperatorsSuite.scala @@ -46,6 +46,13 @@ class MultiStatefulOperatorsSuite StateStore.stop() } + private def testWithAppendAndUpdate(testName: String)( + testBody: OutputMode => Any): Unit = { + Seq(OutputMode.Append(), OutputMode.Update()).foreach { outputMode => + test(s"$testName - $outputMode")(testBody(outputMode)) + } + } + test("window agg -> window agg, append mode") { val inputData = MemoryStream[Int] @@ -934,33 +941,31 @@ class MultiStatefulOperatorsSuite ) } - Seq(OutputMode.Append(), OutputMode.Update()).foreach { outputMode => - test(s"dedup on both sides -> stream-stream inner join, ${outputMode} mode") { - val input1 = MemoryStream[Int] - val inputDF1 = input1.toDF() - .withColumnRenamed("value", "value1") - .withColumn("eventTime1", timestamp_seconds($"value1")) - .withWatermark("eventTime1", "10 seconds") - .dropDuplicates("value1", "eventTime1") - - val input2 = MemoryStream[Int] - val inputDF2 = input2.toDF() - .withColumnRenamed("value", "value2") - .withColumn("eventTime2", timestamp_seconds($"value2")) - .withWatermark("eventTime2", "10 seconds") - .dropDuplicates("value2", "eventTime2") - - val stream = inputDF1.join(inputDF2, expr("eventTime1 = eventTime2"), "inner") - .select($"value1", $"value2") - - testStream(stream, outputMode)( - MultiAddData(input1, 1, 2, 3, 1)(input2, 1, 2, 3, 2), - CheckNewAnswer((1, 1), (2, 2), (3, 3)), - - MultiAddData(input1, 1, 2, 4)(input2, 2, 3, 4), - CheckNewAnswer((4, 4)) - ) - } + testWithAppendAndUpdate("dedup on both sides -> stream-stream inner join") { outputMode => + val input1 = MemoryStream[Int] + val inputDF1 = input1.toDF() + .withColumnRenamed("value", "value1") + .withColumn("eventTime1", timestamp_seconds($"value1")) + .withWatermark("eventTime1", "10 seconds") + .dropDuplicates("value1", "eventTime1") + + val input2 = MemoryStream[Int] + val inputDF2 = input2.toDF() + .withColumnRenamed("value", "value2") + .withColumn("eventTime2", timestamp_seconds($"value2")) + .withWatermark("eventTime2", "10 seconds") + .dropDuplicates("value2", "eventTime2") + + val stream = inputDF1.join(inputDF2, expr("eventTime1 = eventTime2"), "inner") + .select($"value1", $"value2") + + testStream(stream, outputMode)( + MultiAddData(input1, 1, 2, 3, 1)(input2, 1, 2, 3, 2), + CheckNewAnswer((1, 1), (2, 2), (3, 3)), + + MultiAddData(input1, 1, 2, 4)(input2, 2, 3, 4), + CheckNewAnswer((4, 4)) + ) } test("stream-stream inner join -> window agg, update mode") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index 4e23e64423b06..d6cbcff7430cf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -117,6 +117,13 @@ abstract class StreamingJoinSuite } } + protected def testWithAppendAndUpdate(testName: String, testTags: Tag*)( + testBody: OutputMode => Any): Unit = { + Seq(OutputMode.Append(), OutputMode.Update()).foreach { outputMode => + test(s"$testName - $outputMode", testTags: _*)(testBody(outputMode)) + } + } + import testImplicits._ before { @@ -323,231 +330,222 @@ abstract class StreamingJoinSuite abstract class StreamingInnerJoinBase extends StreamingJoinSuite { import testImplicits._ - Seq(OutputMode.Append(), OutputMode.Update()).foreach { outputMode => - test(s"stream stream inner join on non-time column - $outputMode") { - val input1 = MemoryStream[Int] - val input2 = MemoryStream[Int] + testWithAppendAndUpdate("stream stream inner join on non-time column") { outputMode => + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] - val df1 = input1.toDF().select($"value" as "key", ($"value" * 2) as "leftValue") - val df2 = input2.toDF().select($"value" as "key", ($"value" * 3) as "rightValue") - val joined = df1.join(df2, "key") + val df1 = input1.toDF().select($"value" as "key", ($"value" * 2) as "leftValue") + val df2 = input2.toDF().select($"value" as "key", ($"value" * 3) as "rightValue") + val joined = df1.join(df2, "key") - testStream(joined, outputMode)( - AddData(input1, 1), - CheckAnswer(), - AddData(input2, 1, 10), // 1 arrived on input1 first, then input2, should join - CheckNewAnswer((1, 2, 3)), - AddData(input1, 10), // 10 arrived on input2 first, then input1, should join - CheckNewAnswer((10, 20, 30)), - AddData(input2, 1), // another 1 in input2 should join with 1 input1 - CheckNewAnswer((1, 2, 3)), - StopStream, - StartStream(), - AddData(input1, 1), // multiple 1s should be kept in state causing multiple (1, 2, 3) - CheckNewAnswer((1, 2, 3), (1, 2, 3)), - StopStream, - StartStream(), - AddData(input1, 100), - AddData(input2, 100), - CheckNewAnswer((100, 200, 300)) - ) - } + testStream(joined, outputMode)( + AddData(input1, 1), + CheckAnswer(), + AddData(input2, 1, 10), // 1 arrived on input1 first, then input2, should join + CheckNewAnswer((1, 2, 3)), + AddData(input1, 10), // 10 arrived on input2 first, then input1, should join + CheckNewAnswer((10, 20, 30)), + AddData(input2, 1), // another 1 in input2 should join with 1 input1 + CheckNewAnswer((1, 2, 3)), + StopStream, + StartStream(), + AddData(input1, 1), // multiple 1s should be kept in state causing multiple (1, 2, 3) + CheckNewAnswer((1, 2, 3), (1, 2, 3)), + StopStream, + StartStream(), + AddData(input1, 100), + AddData(input2, 100), + CheckNewAnswer((100, 200, 300)) + ) } - Seq(OutputMode.Append(), OutputMode.Update()).foreach { outputMode => - test(s"stream stream inner join on windows - without watermark - $outputMode") { - val input1 = MemoryStream[Int] - val input2 = MemoryStream[Int] + testWithAppendAndUpdate("stream stream inner join on windows - without watermark") { outputMode => + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] - val df1 = input1.toDF() - .select($"value" as "key", timestamp_seconds($"value") as "timestamp", - ($"value" * 2) as "leftValue") - .select($"key", window($"timestamp", "10 second"), $"leftValue") + val df1 = input1.toDF() + .select($"value" as "key", timestamp_seconds($"value") as "timestamp", + ($"value" * 2) as "leftValue") + .select($"key", window($"timestamp", "10 second"), $"leftValue") - val df2 = input2.toDF() - .select($"value" as "key", timestamp_seconds($"value") as "timestamp", - ($"value" * 3) as "rightValue") - .select($"key", window($"timestamp", "10 second"), $"rightValue") + val df2 = input2.toDF() + .select($"value" as "key", timestamp_seconds($"value") as "timestamp", + ($"value" * 3) as "rightValue") + .select($"key", window($"timestamp", "10 second"), $"rightValue") - val joined = df1.join(df2, Seq("key", "window")) - .select($"key", $"window.end".cast("long"), $"leftValue", $"rightValue") + val joined = df1.join(df2, Seq("key", "window")) + .select($"key", $"window.end".cast("long"), $"leftValue", $"rightValue") - testStream(joined, outputMode)( - AddData(input1, 1), - CheckNewAnswer(), - AddData(input2, 1), - CheckNewAnswer((1, 10, 2, 3)), - StopStream, - StartStream(), - AddData(input1, 25), - CheckNewAnswer(), - StopStream, - StartStream(), - AddData(input2, 25), - CheckNewAnswer((25, 30, 50, 75)), - AddData(input1, 1), - CheckNewAnswer((1, 10, 2, 3)), // State for 1 still around as there is no watermark - StopStream, - StartStream(), - AddData(input1, 5), - CheckNewAnswer(), - AddData(input2, 5), - CheckNewAnswer((5, 10, 10, 15)) // No filter by any watermark - ) - } + testStream(joined, outputMode)( + AddData(input1, 1), + CheckNewAnswer(), + AddData(input2, 1), + CheckNewAnswer((1, 10, 2, 3)), + StopStream, + StartStream(), + AddData(input1, 25), + CheckNewAnswer(), + StopStream, + StartStream(), + AddData(input2, 25), + CheckNewAnswer((25, 30, 50, 75)), + AddData(input1, 1), + CheckNewAnswer((1, 10, 2, 3)), // State for 1 still around as there is no watermark + StopStream, + StartStream(), + AddData(input1, 5), + CheckNewAnswer(), + AddData(input2, 5), + CheckNewAnswer((5, 10, 10, 15)) // No filter by any watermark + ) } - Seq(OutputMode.Append(), OutputMode.Update()).foreach { outputMode => - test("stream stream inner join with time range - with watermark" + - s" - one side condition - $outputMode") { - import org.apache.spark.sql.functions._ + testWithAppendAndUpdate("stream stream inner join with time range - with watermark" + + " - one side condition") { outputMode => + import org.apache.spark.sql.functions._ - val leftInput = MemoryStream[(Int, Int)] - val rightInput = MemoryStream[(Int, Int)] - - val df1 = leftInput.toDF().toDF("leftKey", "time") - .select($"leftKey", timestamp_seconds($"time") as "leftTime", - ($"leftKey" * 2) as "leftValue") - .withWatermark("leftTime", "10 seconds") + val leftInput = MemoryStream[(Int, Int)] + val rightInput = MemoryStream[(Int, Int)] - val df2 = rightInput.toDF().toDF("rightKey", "time") - .select($"rightKey", timestamp_seconds($"time") as "rightTime", - ($"rightKey" * 3) as "rightValue") - .withWatermark("rightTime", "10 seconds") + val df1 = leftInput.toDF().toDF("leftKey", "time") + .select($"leftKey", timestamp_seconds($"time") as "leftTime", + ($"leftKey" * 2) as "leftValue") + .withWatermark("leftTime", "10 seconds") - val joined = - df1.join(df2, expr("leftKey = rightKey AND leftTime < rightTime - interval 5 seconds")) - .select($"leftKey", $"leftTime".cast("int"), $"rightTime".cast("int")) + val df2 = rightInput.toDF().toDF("rightKey", "time") + .select($"rightKey", timestamp_seconds($"time") as "rightTime", + ($"rightKey" * 3) as "rightValue") + .withWatermark("rightTime", "10 seconds") - testStream(joined, outputMode)( - AddData(leftInput, (1, 5)), - CheckAnswer(), - AddData(rightInput, (1, 11)), - CheckNewAnswer((1, 5, 11)), - AddData(rightInput, (1, 10)), - CheckNewAnswer(), // no match as leftTime 5 is not < rightTime 10 - 5 - assertNumStateRows(total = 3, updated = 3), + val joined = + df1.join(df2, expr("leftKey = rightKey AND leftTime < rightTime - interval 5 seconds")) + .select($"leftKey", $"leftTime".cast("int"), $"rightTime".cast("int")) - // Increase event time watermark to 20s by adding data with time = 30s on both inputs - AddData(leftInput, (1, 3), (1, 30)), - CheckNewAnswer((1, 3, 10), (1, 3, 11)), - assertNumStateRows(total = 5, updated = 2), - AddData(rightInput, (0, 30)), - CheckNewAnswer(), + testStream(joined, outputMode)( + AddData(leftInput, (1, 5)), + CheckAnswer(), + AddData(rightInput, (1, 11)), + CheckNewAnswer((1, 5, 11)), + AddData(rightInput, (1, 10)), + CheckNewAnswer(), // no match as leftTime 5 is not < rightTime 10 - 5 + assertNumStateRows(total = 3, updated = 3), + + // Increase event time watermark to 20s by adding data with time = 30s on both inputs + AddData(leftInput, (1, 3), (1, 30)), + CheckNewAnswer((1, 3, 10), (1, 3, 11)), + assertNumStateRows(total = 5, updated = 2), + AddData(rightInput, (0, 30)), + CheckNewAnswer(), - // event time watermark: max event time - 10 ==> 30 - 10 = 20 - // so left side going to only receive data where leftTime > 20 - // right side state constraint: 20 < leftTime < rightTime - 5 ==> rightTime > 25 - // right state where rightTime <= 25 will be cleared, (1, 11) and (1, 10) removed - assertNumStateRows(total = 4, updated = 1), - - // New data to right input should match with left side (1, 3) and (1, 5), as left state - // should not be cleared. But rows rightTime <= 20 should be filtered due to event time - // watermark and state rows with rightTime <= 25 should be removed from state. - // (1, 20) ==> filtered by event time watermark = 20 - // (1, 21) ==> passed filter, matched with left (1, 3) and (1, 5), not added to state - // as 21 < state watermark = 25 - // (1, 28) ==> passed filter, matched with left (1, 3) and (1, 5), added to state - AddData(rightInput, (1, 20), (1, 21), (1, 28)), - CheckNewAnswer((1, 3, 21), (1, 5, 21), (1, 3, 28), (1, 5, 28)), - assertNumStateRows(total = 5, updated = 1, droppedByWatermark = 1), - - // New data to left input with leftTime <= 20 should be filtered due to event time watermark - AddData(leftInput, (1, 20), (1, 21)), - CheckNewAnswer((1, 21, 28)), - assertNumStateRows(total = 6, updated = 1, droppedByWatermark = 1) - ) - } + // event time watermark: max event time - 10 ==> 30 - 10 = 20 + // so left side going to only receive data where leftTime > 20 + // right side state constraint: 20 < leftTime < rightTime - 5 ==> rightTime > 25 + // right state where rightTime <= 25 will be cleared, (1, 11) and (1, 10) removed + assertNumStateRows(total = 4, updated = 1), + + // New data to right input should match with left side (1, 3) and (1, 5), as left state should + // not be cleared. But rows rightTime <= 20 should be filtered due to event time watermark and + // state rows with rightTime <= 25 should be removed from state. + // (1, 20) ==> filtered by event time watermark = 20 + // (1, 21) ==> passed filter, matched with left (1, 3) and (1, 5), not added to state + // as 21 < state watermark = 25 + // (1, 28) ==> passed filter, matched with left (1, 3) and (1, 5), added to state + AddData(rightInput, (1, 20), (1, 21), (1, 28)), + CheckNewAnswer((1, 3, 21), (1, 5, 21), (1, 3, 28), (1, 5, 28)), + assertNumStateRows(total = 5, updated = 1, droppedByWatermark = 1), + + // New data to left input with leftTime <= 20 should be filtered due to event time watermark + AddData(leftInput, (1, 20), (1, 21)), + CheckNewAnswer((1, 21, 28)), + assertNumStateRows(total = 6, updated = 1, droppedByWatermark = 1) + ) } - Seq(OutputMode.Append(), OutputMode.Update()).foreach { outputMode => - test("stream stream inner join with time range - with watermark" + - s" - two side conditions - $outputMode") { - import org.apache.spark.sql.functions._ + testWithAppendAndUpdate("stream stream inner join with time range - with watermark" + + " - two side conditions") { outputMode => + import org.apache.spark.sql.functions._ - val leftInput = MemoryStream[(Int, Int)] - val rightInput = MemoryStream[(Int, Int)] + val leftInput = MemoryStream[(Int, Int)] + val rightInput = MemoryStream[(Int, Int)] - val df1 = leftInput.toDF().toDF("leftKey", "time") - .select($"leftKey", timestamp_seconds($"time") as "leftTime", - ($"leftKey" * 2) as "leftValue") - .withWatermark("leftTime", "20 seconds") + val df1 = leftInput.toDF().toDF("leftKey", "time") + .select($"leftKey", timestamp_seconds($"time") as "leftTime", + ($"leftKey" * 2) as "leftValue") + .withWatermark("leftTime", "20 seconds") - val df2 = rightInput.toDF().toDF("rightKey", "time") - .select($"rightKey", timestamp_seconds($"time") as "rightTime", - ($"rightKey" * 3) as "rightValue") - .withWatermark("rightTime", "30 seconds") + val df2 = rightInput.toDF().toDF("rightKey", "time") + .select($"rightKey", timestamp_seconds($"time") as "rightTime", + ($"rightKey" * 3) as "rightValue") + .withWatermark("rightTime", "30 seconds") + + val condition = expr( + "leftKey = rightKey AND " + + "leftTime BETWEEN rightTime - interval 10 seconds AND rightTime + interval 5 seconds") + + // This translates to leftTime <= rightTime + 5 seconds AND leftTime >= rightTime - 10 seconds + // So given leftTime, rightTime has to be BETWEEN leftTime - 5 seconds AND leftTime + 10 seconds + // + // =============== * ======================== * ============================== * ==> leftTime + // | | | + // |<---- 5s -->|<------ 10s ------>| |<------ 10s ------>|<---- 5s -->| + // | | | + // == * ============================== * =========>============== * ===============> rightTime + // + // E.g. + // if rightTime = 60, then it matches only leftTime = [50, 65] + // if leftTime = 20, then it match only with rightTime = [15, 30] + // + // State value predicates + // left side: + // values allowed: leftTime >= rightTime - 10s ==> leftTime > eventTimeWatermark - 10 + // drop state where leftTime < eventTime - 10 + // right side: + // values allowed: rightTime >= leftTime - 5s ==> rightTime > eventTimeWatermark - 5 + // drop state where rightTime < eventTime - 5 - val condition = expr( - "leftKey = rightKey AND " + - "leftTime BETWEEN rightTime - interval 10 seconds AND rightTime + interval 5 seconds") + val joined = + df1.join(df2, condition).select($"leftKey", $"leftTime".cast("int"), + $"rightTime".cast("int")) - // This translates to leftTime <= rightTime + 5 seconds AND - // leftTime >= rightTime - 10 seconds. So given leftTime, rightTime has to be - // BETWEEN leftTime - 5 seconds AND leftTime + 10 seconds - // - // ============ * ==================== * ======================== * ==> leftTime - // | | | - // |<--- 5s --->|<----- 10s ----->| |<----- 10s ----->|<- 5s->| - // | | | - // * ============================= * =================== * ========> rightTime - // - // E.g. - // if rightTime = 60, then it matches only leftTime = [50, 65] - // if leftTime = 20, then it match only with rightTime = [15, 30] - // - // State value predicates - // left side: - // values allowed: leftTime >= rightTime - 10s ==> leftTime > eventTimeWatermark - 10 - // drop state where leftTime < eventTime - 10 - // right side: - // values allowed: rightTime >= leftTime - 5s ==> rightTime > eventTimeWatermark - 5 - // drop state where rightTime < eventTime - 5 - - val joined = - df1.join(df2, condition).select($"leftKey", $"leftTime".cast("int"), - $"rightTime".cast("int")) + testStream(joined, outputMode)( + // If leftTime = 20, then it match only with rightTime = [15, 30] + AddData(leftInput, (1, 20)), + CheckAnswer(), + AddData(rightInput, (1, 14), (1, 15), (1, 25), (1, 26), (1, 30), (1, 31)), + CheckNewAnswer((1, 20, 15), (1, 20, 25), (1, 20, 26), (1, 20, 30)), + assertNumStateRows(total = 7, updated = 7), - testStream(joined, outputMode)( - // If leftTime = 20, then it match only with rightTime = [15, 30] - AddData(leftInput, (1, 20)), - CheckAnswer(), - AddData(rightInput, (1, 14), (1, 15), (1, 25), (1, 26), (1, 30), (1, 31)), - CheckNewAnswer((1, 20, 15), (1, 20, 25), (1, 20, 26), (1, 20, 30)), - assertNumStateRows(total = 7, updated = 7), - - // If rightTime = 60, then it matches only leftTime = [50, 65] - AddData(rightInput, (1, 60)), - CheckNewAnswer(), // matches with nothing on the left - AddData(leftInput, (1, 49), (1, 50), (1, 65), (1, 66)), - CheckNewAnswer((1, 50, 60), (1, 65, 60)), - - // Event time watermark = min(left: 66 - delay 20 = 46, right: 60 - delay 30 = 30) = 30 - // Left state value watermark = 30 - 10 = slightly less than 20 (since condition has <=) - // Should drop < 20 from left, i.e., none - // Right state value watermark = 30 - 5 = slightly less than 25 (since condition has <=) - // Should drop < 25 from the right, i.e., 14 and 15 - assertNumStateRows(total = 10, updated = 5), // 12 - 2 removed - - AddData(leftInput, (1, 30), (1, 31)), // 30 should not be processed or added to state - CheckNewAnswer((1, 31, 26), (1, 31, 30), (1, 31, 31)), - assertNumStateRows(total = 11, updated = 1, droppedByWatermark = 1), // only 31 added - - // Advance the watermark - AddData(rightInput, (1, 80)), - CheckNewAnswer(), - // Event time watermark = min(left: 66 - delay 20 = 46, right: 80 - delay 30 = 50) = 46 - // Left state value watermark = 46 - 10 = slightly less than 36 (since condition has <=) - // Should drop < 36 from left, i.e., 20, 31 (30 was not added) - // Right state value watermark = 46 - 5 = slightly less than 41 (since condition has <=) - // Should drop < 41 from the right, i.e., 25, 26, 30, 31 - assertNumStateRows(total = 6, updated = 1), // 12 - 6 removed - - AddData(rightInput, (1, 46), (1, 50)), // 46 should not be processed or added to state - CheckNewAnswer((1, 49, 50), (1, 50, 50)), - assertNumStateRows(total = 7, updated = 1, droppedByWatermark = 1) // 50 added - ) - } + // If rightTime = 60, then it matches only leftTime = [50, 65] + AddData(rightInput, (1, 60)), + CheckNewAnswer(), // matches with nothing on the left + AddData(leftInput, (1, 49), (1, 50), (1, 65), (1, 66)), + CheckNewAnswer((1, 50, 60), (1, 65, 60)), + + // Event time watermark = min(left: 66 - delay 20 = 46, right: 60 - delay 30 = 30) = 30 + // Left state value watermark = 30 - 10 = slightly less than 20 (since condition has <=) + // Should drop < 20 from left, i.e., none + // Right state value watermark = 30 - 5 = slightly less than 25 (since condition has <=) + // Should drop < 25 from the right, i.e., 14 and 15 + assertNumStateRows(total = 10, updated = 5), // 12 - 2 removed + + AddData(leftInput, (1, 30), (1, 31)), // 30 should not be processed or added to state + CheckNewAnswer((1, 31, 26), (1, 31, 30), (1, 31, 31)), + assertNumStateRows(total = 11, updated = 1, droppedByWatermark = 1), // only 31 added + + // Advance the watermark + AddData(rightInput, (1, 80)), + CheckNewAnswer(), + // Event time watermark = min(left: 66 - delay 20 = 46, right: 80 - delay 30 = 50) = 46 + // Left state value watermark = 46 - 10 = slightly less than 36 (since condition has <=) + // Should drop < 36 from left, i.e., 20, 31 (30 was not added) + // Right state value watermark = 46 - 5 = slightly less than 41 (since condition has <=) + // Should drop < 41 from the right, i.e., 25, 26, 30, 31 + assertNumStateRows(total = 6, updated = 1), // 12 - 6 removed + + AddData(rightInput, (1, 46), (1, 50)), // 46 should not be processed or added to state + CheckNewAnswer((1, 49, 50), (1, 50, 50)), + assertNumStateRows(total = 7, updated = 1, droppedByWatermark = 1) // 50 added + ) } testQuietly("stream stream inner join without equality predicate") { @@ -567,27 +565,25 @@ abstract class StreamingInnerJoinBase extends StreamingJoinSuite { assert(e.toString.contains("Stream-stream join without equality predicate is not supported")) } - Seq(OutputMode.Append(), OutputMode.Update()).foreach { outputMode => - test(s"stream stream self join - $outputMode") { - val input = MemoryStream[Int] - val df = input.toDF() - val join = - df.select($"value" % 5 as "key", $"value").join( - df.select($"value" % 5 as "key", $"value"), "key") - - testStream(join, outputMode)( - AddData(input, 1, 2), - CheckAnswer((1, 1, 1), (2, 2, 2)), - StopStream, - StartStream(), - AddData(input, 3, 6), - /* - (1, 1) (1, 1) - (2, 2) x (2, 2) = (1, 1, 1), (1, 1, 6), (2, 2, 2), (1, 6, 1), (1, 6, 6) - (1, 6) (1, 6) - */ - CheckAnswer((3, 3, 3), (1, 1, 1), (1, 1, 6), (2, 2, 2), (1, 6, 1), (1, 6, 6))) - } + testWithAppendAndUpdate("stream stream self join") { outputMode => + val input = MemoryStream[Int] + val df = input.toDF() + val join = + df.select($"value" % 5 as "key", $"value").join( + df.select($"value" % 5 as "key", $"value"), "key") + + testStream(join, outputMode)( + AddData(input, 1, 2), + CheckAnswer((1, 1, 1), (2, 2, 2)), + StopStream, + StartStream(), + AddData(input, 3, 6), + /* + (1, 1) (1, 1) + (2, 2) x (2, 2) = (1, 1, 1), (1, 1, 6), (2, 2, 2), (1, 6, 1), (1, 6, 6) + (1, 6) (1, 6) + */ + CheckAnswer((3, 3, 3), (1, 1, 1), (1, 1, 6), (2, 2, 2), (1, 6, 1), (1, 6, 6))) } test("locality preferences of StateStoreAwareZippedRDD") { @@ -888,45 +884,46 @@ abstract class StreamingInnerJoinBase extends StreamingJoinSuite { ) } - Seq(OutputMode.Append(), OutputMode.Update()).foreach { outputMode => - test(s"joining non-nullable left join key with nullable right join key - $outputMode") { - val input1 = MemoryStream[Int] - val input2 = MemoryStream[JInteger] + testWithAppendAndUpdate("joining non-nullable left join key with nullable right join key") { + outputMode => + val input1 = MemoryStream[Int] + val input2 = MemoryStream[JInteger] - val joined = testForJoinKeyNullability(input1.toDF(), input2.toDF()) - testStream(joined, outputMode)( - AddData(input1, 1, 5), - AddData(input2, JInteger.valueOf(1), JInteger.valueOf(5), JInteger.valueOf(10), null), - CheckNewAnswer(Row(1, 1, 2, 3), Row(5, 5, 10, 15)) - ) - } + val joined = testForJoinKeyNullability(input1.toDF(), input2.toDF()) + testStream(joined, outputMode)( + AddData(input1, 1, 5), + AddData(input2, JInteger.valueOf(1), JInteger.valueOf(5), JInteger.valueOf(10), null), + CheckNewAnswer(Row(1, 1, 2, 3), Row(5, 5, 10, 15)) + ) + } - test(s"joining nullable left join key with non-nullable right join key - $outputMode") { - val input1 = MemoryStream[JInteger] - val input2 = MemoryStream[Int] + testWithAppendAndUpdate("joining nullable left join key with non-nullable right join key") { + outputMode => + val input1 = MemoryStream[JInteger] + val input2 = MemoryStream[Int] - val joined = testForJoinKeyNullability(input1.toDF(), input2.toDF()) - testStream(joined, outputMode)( - AddData(input1, JInteger.valueOf(1), JInteger.valueOf(5), JInteger.valueOf(10), null), - AddData(input2, 1, 5), - CheckNewAnswer(Row(1, 1, 2, 3), Row(5, 5, 10, 15)) - ) - } + val joined = testForJoinKeyNullability(input1.toDF(), input2.toDF()) + testStream(joined, outputMode)( + AddData(input1, JInteger.valueOf(1), JInteger.valueOf(5), JInteger.valueOf(10), null), + AddData(input2, 1, 5), + CheckNewAnswer(Row(1, 1, 2, 3), Row(5, 5, 10, 15)) + ) + } - test(s"joining nullable left join key with nullable right join key - $outputMode") { - val input1 = MemoryStream[JInteger] - val input2 = MemoryStream[JInteger] + testWithAppendAndUpdate("joining nullable left join key with nullable right join key") { + outputMode => + val input1 = MemoryStream[JInteger] + val input2 = MemoryStream[JInteger] - val joined = testForJoinKeyNullability(input1.toDF(), input2.toDF()) - testStream(joined, outputMode)( - AddData(input1, JInteger.valueOf(1), JInteger.valueOf(5), JInteger.valueOf(10), null), - AddData(input2, JInteger.valueOf(1), JInteger.valueOf(5), null), - CheckNewAnswer( - Row(JInteger.valueOf(1), JInteger.valueOf(1), JInteger.valueOf(2), JInteger.valueOf(3)), - Row(JInteger.valueOf(5), JInteger.valueOf(5), JInteger.valueOf(10), JInteger.valueOf(15)), - Row(null, null, null, null)) - ) - } + val joined = testForJoinKeyNullability(input1.toDF(), input2.toDF()) + testStream(joined, outputMode)( + AddData(input1, JInteger.valueOf(1), JInteger.valueOf(5), JInteger.valueOf(10), null), + AddData(input2, JInteger.valueOf(1), JInteger.valueOf(5), null), + CheckNewAnswer( + Row(JInteger.valueOf(1), JInteger.valueOf(1), JInteger.valueOf(2), JInteger.valueOf(3)), + Row(JInteger.valueOf(5), JInteger.valueOf(5), JInteger.valueOf(10), JInteger.valueOf(15)), + Row(null, null, null, null)) + ) } testWithVirtualColumnFamilyJoins( @@ -1063,57 +1060,53 @@ abstract class StreamingInnerJoinBase extends StreamingJoinSuite { abstract class StreamingInnerJoinSuite extends StreamingInnerJoinBase { import testImplicits._ - Seq(OutputMode.Append(), OutputMode.Update()).foreach { outputMode => - test(s"stream stream inner join on windows - with watermark - $outputMode") { - val input1 = MemoryStream[Int] - val input2 = MemoryStream[Int] + testWithAppendAndUpdate("stream stream inner join on windows - with watermark") { outputMode => + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] - val df1 = input1.toDF() - .select($"value" as "key", timestamp_seconds($"value") as "timestamp", - ($"value" * 2) as "leftValue") - .withWatermark("timestamp", "10 seconds") - .select($"key", window($"timestamp", "10 second"), $"leftValue") + val df1 = input1.toDF() + .select($"value" as "key", timestamp_seconds($"value") as "timestamp", + ($"value" * 2) as "leftValue") + .withWatermark("timestamp", "10 seconds") + .select($"key", window($"timestamp", "10 second"), $"leftValue") - val df2 = input2.toDF() - .select($"value" as "key", timestamp_seconds($"value") as "timestamp", - ($"value" * 3) as "rightValue") - .select($"key", window($"timestamp", "10 second"), $"rightValue") + val df2 = input2.toDF() + .select($"value" as "key", timestamp_seconds($"value") as "timestamp", + ($"value" * 3) as "rightValue") + .select($"key", window($"timestamp", "10 second"), $"rightValue") - val joined = df1.join(df2, Seq("key", "window")) - .select($"key", $"window.end".cast("long"), $"leftValue", $"rightValue") + val joined = df1.join(df2, Seq("key", "window")) + .select($"key", $"window.end".cast("long"), $"leftValue", $"rightValue") - testStream(joined, outputMode)( - AddData(input1, 1), - CheckAnswer(), - assertNumStateRows(total = 1, updated = 1), + testStream(joined, outputMode)( + AddData(input1, 1), + CheckAnswer(), + assertNumStateRows(total = 1, updated = 1), - AddData(input2, 1), - CheckAnswer((1, 10, 2, 3)), - assertNumStateRows(total = 2, updated = 1), - StopStream, - StartStream(), + AddData(input2, 1), + CheckAnswer((1, 10, 2, 3)), + assertNumStateRows(total = 2, updated = 1), + StopStream, + StartStream(), - AddData(input1, 25), - // watermark = 15, no-data-batch should remove 2 rows - // having window=[0,10] - CheckNewAnswer(), - assertNumStateRows(total = 1, updated = 1), + AddData(input1, 25), + CheckNewAnswer(), // watermark = 15, no-data-batch should remove 2 rows having window=[0,10] + assertNumStateRows(total = 1, updated = 1), - AddData(input2, 25), - CheckNewAnswer((25, 30, 50, 75)), - assertNumStateRows(total = 2, updated = 1), - StopStream, - StartStream(), + AddData(input2, 25), + CheckNewAnswer((25, 30, 50, 75)), + assertNumStateRows(total = 2, updated = 1), + StopStream, + StartStream(), - AddData(input2, 1), - CheckNewAnswer(), // Should not join as < 15 removed - assertNumStateRows(total = 2, updated = 0), // row not add as 1 < state key watermark = 15 + AddData(input2, 1), + CheckNewAnswer(), // Should not join as < 15 removed + assertNumStateRows(total = 2, updated = 0), // row not add as 1 < state key watermark = 15 - AddData(input1, 5), - CheckNewAnswer(), // Same reason as above - assertNumStateRows(total = 2, updated = 0, droppedByWatermark = 1) - ) - } + AddData(input1, 5), + CheckNewAnswer(), // Same reason as above + assertNumStateRows(total = 2, updated = 0, droppedByWatermark = 1) + ) } test("SPARK-35896: metrics in StateOperatorProgress are output correctly") { @@ -2233,253 +2226,243 @@ abstract class StreamingLeftSemiJoinBase extends StreamingJoinSuite { import testImplicits._ - Seq(OutputMode.Append(), OutputMode.Update()).foreach { outputMode => - test(s"windowed left semi join - $outputMode") { - withTempDir { checkpointDir => - val (leftInput, rightInput, joined) = setupWindowedJoin("left_semi") - - testStream(joined, outputMode)( - StartStream(checkpointLocation = checkpointDir.getCanonicalPath), - MultiAddData(leftInput, 1, 2, 3, 4, 5)(rightInput, 3, 4, 5, 6, 7), - CheckNewAnswer(Row(3, 10, 6), Row(4, 10, 8), Row(5, 10, 10)), - // states - // left: 1, 2 (left 3, 4, 5 matched right in the same batch, emitted without storing) - // right: 3, 4, 5, 6, 7 - assertNumStateRows( - total = Seq(7), updated = Seq(7), - droppedByWatermark = Seq(0), removed = Some(Seq(0))), - MultiAddData(leftInput, 21)(rightInput, 22), - // Watermark = 11, should remove rows having window=[0,10]. - CheckNewAnswer(), - // states - // left: 21 - // right: 22 - // - // states evicted - // left: 1, 2 (below watermark) - // right: 3, 4, 5, 6, 7 (below watermark) - assertNumStateRows( - total = Seq(2), updated = Seq(2), - droppedByWatermark = Seq(0), removed = Some(Seq(7))), - StopStream, - // Restart join query from the same checkpoint - StartStream(checkpointLocation = checkpointDir.getCanonicalPath), - AddData(leftInput, 22), - CheckNewAnswer(Row(22, 30, 44)), - // Unlike inner/outer joins, given left input row matches with right input row, - // we don't buffer the matched left input row to the state store. - // - // states - // left: 21 - // right: 22 - assertNumStateRows( - total = Seq(2), updated = Seq(0), - droppedByWatermark = Seq(0), removed = Some(Seq(0))), - StopStream, - // Restart the query from the same checkpoint - StartStream(checkpointLocation = checkpointDir.getCanonicalPath), - AddData(leftInput, 1), - // Row not add as 1 < state key watermark = 12. - CheckNewAnswer(), - // states - // left: 21 - // right: 22 - assertNumStateRows( - total = Seq(2), updated = Seq(0), - droppedByWatermark = Seq(1), removed = Some(Seq(0))), - AddData(rightInput, 5), - // Row not add as 5 < state key watermark = 12. - CheckNewAnswer(), - // states - // left: 21 - // right: 22 - assertNumStateRows( - total = Seq(2), updated = Seq(0), - droppedByWatermark = Seq(1), removed = Some(Seq(0))) - ) - } - } - } - - Seq(OutputMode.Append(), OutputMode.Update()).foreach { outputMode => - test(s"left semi early state exclusion on left - $outputMode") { - val (leftInput, rightInput, joined) = setupWindowedJoinWithLeftCondition("left_semi") + testWithAppendAndUpdate("windowed left semi join") { outputMode => + withTempDir { checkpointDir => + val (leftInput, rightInput, joined) = setupWindowedJoin("left_semi") testStream(joined, outputMode)( - MultiAddData(leftInput, 1, 2, 3)(rightInput, 3, 4, 5), - // The left rows with leftValue <= 4 should not generate their semi join rows and - // not get added to the state. - CheckNewAnswer(Row(3, 10, 6)), + StartStream(checkpointLocation = checkpointDir.getCanonicalPath), + MultiAddData(leftInput, 1, 2, 3, 4, 5)(rightInput, 3, 4, 5, 6, 7), + CheckNewAnswer(Row(3, 10, 6), Row(4, 10, 8), Row(5, 10, 10)), // states - // left: (none - left 3 matched right in the same batch, emitted without storing) - // right: 3, 4, 5 + // left: 1, 2 (left 3, 4, 5 matched right in the same batch, emitted without storing) + // right: 3, 4, 5, 6, 7 assertNumStateRows( - total = Seq(3), updated = Seq(3), + total = Seq(7), updated = Seq(7), droppedByWatermark = Seq(0), removed = Some(Seq(0))), - // We shouldn't get more semi join rows when the watermark advances. - MultiAddData(leftInput, 20)(rightInput, 21), + MultiAddData(leftInput, 21)(rightInput, 22), + // Watermark = 11, should remove rows having window=[0,10]. CheckNewAnswer(), // states - // left: 20 - // right: 21 + // left: 21 + // right: 22 // // states evicted - // right: 3, 4, 5 (below watermark) + // left: 1, 2 (below watermark) + // right: 3, 4, 5, 6, 7 (below watermark) assertNumStateRows( total = Seq(2), updated = Seq(2), - droppedByWatermark = Seq(0), removed = Some(Seq(3))), - AddData(rightInput, 20), - CheckNewAnswer((20, 30, 40)), - // states - // left: (empty -- 20 removed after matching right 20 via getJoinedRowsAndRemoveMatched) - // right: 21, 20 - assertNumStateRows( - total = Seq(2), updated = Seq(1), - droppedByWatermark = Seq(0), removed = Some(Seq(1))) - ) - } - } - - Seq(OutputMode.Append(), OutputMode.Update()).foreach { outputMode => - test(s"left semi early state exclusion on right - $outputMode") { - val (leftInput, rightInput, joined) = setupWindowedJoinWithRightCondition("left_semi") - - testStream(joined, outputMode)( - MultiAddData(leftInput, 3, 4, 5)(rightInput, 1, 2, 3), - // The right rows with rightValue <= 7 should never be added to the state. - // The right row with rightValue = 9 > 7, hence joined and added to state. - CheckNewAnswer(Row(3, 10, 6)), + droppedByWatermark = Seq(0), removed = Some(Seq(7))), + StopStream, + // Restart join query from the same checkpoint + StartStream(checkpointLocation = checkpointDir.getCanonicalPath), + AddData(leftInput, 22), + CheckNewAnswer(Row(22, 30, 44)), + // Unlike inner/outer joins, given left input row matches with right input row, + // we don't buffer the matched left input row to the state store. + // // states - // left: 4, 5 (left 3 matched right in the same batch, emitted without storing) - // right: 3 + // left: 21 + // right: 22 assertNumStateRows( - total = Seq(3), updated = Seq(3), + total = Seq(2), updated = Seq(0), droppedByWatermark = Seq(0), removed = Some(Seq(0))), - // We shouldn't get more semi join rows when the watermark advances. - MultiAddData(leftInput, 20)(rightInput, 21), + StopStream, + // Restart the query from the same checkpoint + StartStream(checkpointLocation = checkpointDir.getCanonicalPath), + AddData(leftInput, 1), + // Row not add as 1 < state key watermark = 12. CheckNewAnswer(), // states - // left: 20 - // right: 21 - // - // states evicted - // left: 4, 5 (below watermark) - // right: 3 (below watermark) + // left: 21 + // right: 22 assertNumStateRows( - total = Seq(2), updated = Seq(2), - droppedByWatermark = Seq(0), removed = Some(Seq(3))), - AddData(rightInput, 20), - CheckNewAnswer((20, 30, 40)), + total = Seq(2), updated = Seq(0), + droppedByWatermark = Seq(1), removed = Some(Seq(0))), + AddData(rightInput, 5), + // Row not add as 5 < state key watermark = 12. + CheckNewAnswer(), // states - // left: (empty -- 20 removed after matching right 20 via getJoinedRowsAndRemoveMatched) - // right: 21, 20 + // left: 21 + // right: 22 assertNumStateRows( - total = Seq(2), updated = Seq(1), - droppedByWatermark = Seq(0), removed = Some(Seq(1))) + total = Seq(2), updated = Seq(0), + droppedByWatermark = Seq(1), removed = Some(Seq(0))) ) } } - Seq(OutputMode.Append(), OutputMode.Update()).foreach { outputMode => - test(s"left semi join with watermark range condition - $outputMode") { - val (leftInput, rightInput, joined) = setupJoinWithRangeCondition("left_semi") + testWithAppendAndUpdate("left semi early state exclusion on left") { outputMode => + val (leftInput, rightInput, joined) = setupWindowedJoinWithLeftCondition("left_semi") - testStream(joined, outputMode)( - AddData(leftInput, (1, 5), (3, 5)), - CheckNewAnswer(), - // states - // left: (1, 5), (3, 5) - // right: nothing - assertNumStateRows( - total = Seq(2), updated = Seq(2), - droppedByWatermark = Seq(0), removed = Some(Seq(0))), - AddData(rightInput, (1, 10), (2, 5)), - // Match left row in the state. Matched left row (1, 5) is immediately removed from state - // via getJoinedRowsAndRemoveMatched. - CheckNewAnswer((1, 5)), - // states - // left: (3, 5) -- (1, 5) removed after matching right (1, 10) - // right: (1, 10), (2, 5) - assertNumStateRows( - total = Seq(3), updated = Seq(2), - droppedByWatermark = Seq(0), removed = Some(Seq(1))), - AddData(rightInput, (1, 9)), - // No match as left row (1, 5) was already removed from state. - CheckNewAnswer(), - // states - // left: (3, 5) - // right: (1, 10), (2, 5), (1, 9) - assertNumStateRows( - total = Seq(4), updated = Seq(1), - droppedByWatermark = Seq(0), removed = Some(Seq(0))), - // Increase event time watermark to 20s by adding data with time = 30s on both inputs. - AddData(leftInput, (1, 7), (1, 30)), - CheckNewAnswer((1, 7)), - // states - // left: (3, 5), (1, 30) - // right: (1, 10), (2, 5), (1, 9) - assertNumStateRows( - total = Seq(5), updated = Seq(1), - droppedByWatermark = Seq(0), removed = Some(Seq(0))), - // Watermark = 30 - 10 = 20, no matched row. - AddData(rightInput, (0, 30)), - CheckNewAnswer(), - // states - // left: (1, 30) - // right: (0, 30) - // - // states evicted - // left: (3, 5) (below watermark = 20) - // right: (1, 10), (2, 5), (1, 9) (below watermark = 20) - assertNumStateRows( - total = Seq(2), updated = Seq(1), - droppedByWatermark = Seq(0), removed = Some(Seq(4))) - ) - } + testStream(joined, outputMode)( + MultiAddData(leftInput, 1, 2, 3)(rightInput, 3, 4, 5), + // The left rows with leftValue <= 4 should not generate their semi join rows and + // not get added to the state. + CheckNewAnswer(Row(3, 10, 6)), + // states + // left: (none - left 3 matched right in the same batch, emitted without storing) + // right: 3, 4, 5 + assertNumStateRows( + total = Seq(3), updated = Seq(3), + droppedByWatermark = Seq(0), removed = Some(Seq(0))), + // We shouldn't get more semi join rows when the watermark advances. + MultiAddData(leftInput, 20)(rightInput, 21), + CheckNewAnswer(), + // states + // left: 20 + // right: 21 + // + // states evicted + // right: 3, 4, 5 (below watermark) + assertNumStateRows( + total = Seq(2), updated = Seq(2), + droppedByWatermark = Seq(0), removed = Some(Seq(3))), + AddData(rightInput, 20), + CheckNewAnswer((20, 30, 40)), + // states + // left: (empty -- 20 removed after matching right 20 via getJoinedRowsAndRemoveMatched) + // right: 21, 20 + assertNumStateRows( + total = Seq(2), updated = Seq(1), + droppedByWatermark = Seq(0), removed = Some(Seq(1))) + ) } - Seq(OutputMode.Append(), OutputMode.Update()).foreach { outputMode => - test(s"self left semi join - $outputMode") { - val (inputStream, query) = setupSelfJoin("left_semi") + testWithAppendAndUpdate("left semi early state exclusion on right") { outputMode => + val (leftInput, rightInput, joined) = setupWindowedJoinWithRightCondition("left_semi") - testStream(query, outputMode)( - AddData(inputStream, (1, 1L), (2, 2L), (3, 3L), (4, 4L), (5, 5L)), - CheckNewAnswer((2, 2), (4, 4)), - // batch 1 - global watermark = 0 - // states - // left: (none - left 2, 4 matched right in the same batch, emitted without storing) - // (left rows with value % 2 != 0 is filtered per [[PushPredicateThroughJoin]]) - // right: (2, 2L), (4, 4L) - // (right rows with value % 2 != 0 is filtered per [[PushPredicateThroughJoin]]) - assertNumStateRows( - total = Seq(2), updated = Seq(2), - droppedByWatermark = Seq(0), removed = Some(Seq(0))), - AddData(inputStream, (6, 6L), (7, 7L), (8, 8L), (9, 9L), (10, 10L)), - CheckNewAnswer((6, 6), (8, 8), (10, 10)), - // batch 2 - global watermark = 5 - // states - // left: (none - left 6, 8, 10 matched right in the same batch, emitted without storing) - // right: (6, 6L), (8, 8L), (10, 10L) - // - // states evicted - // right: (2, 2L), (4, 4L) - assertNumStateRows( - total = Seq(3), updated = Seq(3), - droppedByWatermark = Seq(0), removed = Some(Seq(2))), - AddData(inputStream, (11, 11L), (12, 12L), (13, 13L), (14, 14L), (15, 15L)), - CheckNewAnswer((12, 12), (14, 14)), - // batch 3 - global watermark = 9 - // states - // left: (none - left 12, 14 matched right in the same batch, emitted without storing) - // right: (10, 10L), (12, 12L), (14, 14L) - // - // states evicted - // right: (6, 6L), (8, 8L) - assertNumStateRows( - total = Seq(3), updated = Seq(2), - droppedByWatermark = Seq(0), removed = Some(Seq(2))) - ) - } + testStream(joined, outputMode)( + MultiAddData(leftInput, 3, 4, 5)(rightInput, 1, 2, 3), + // The right rows with rightValue <= 7 should never be added to the state. + // The right row with rightValue = 9 > 7, hence joined and added to state. + CheckNewAnswer(Row(3, 10, 6)), + // states + // left: 4, 5 (left 3 matched right in the same batch, emitted without storing) + // right: 3 + assertNumStateRows( + total = Seq(3), updated = Seq(3), + droppedByWatermark = Seq(0), removed = Some(Seq(0))), + // We shouldn't get more semi join rows when the watermark advances. + MultiAddData(leftInput, 20)(rightInput, 21), + CheckNewAnswer(), + // states + // left: 20 + // right: 21 + // + // states evicted + // left: 4, 5 (below watermark) + // right: 3 (below watermark) + assertNumStateRows( + total = Seq(2), updated = Seq(2), + droppedByWatermark = Seq(0), removed = Some(Seq(3))), + AddData(rightInput, 20), + CheckNewAnswer((20, 30, 40)), + // states + // left: (empty -- 20 removed after matching right 20 via getJoinedRowsAndRemoveMatched) + // right: 21, 20 + assertNumStateRows( + total = Seq(2), updated = Seq(1), + droppedByWatermark = Seq(0), removed = Some(Seq(1))) + ) + } + + testWithAppendAndUpdate("left semi join with watermark range condition") { outputMode => + val (leftInput, rightInput, joined) = setupJoinWithRangeCondition("left_semi") + + testStream(joined, outputMode)( + AddData(leftInput, (1, 5), (3, 5)), + CheckNewAnswer(), + // states + // left: (1, 5), (3, 5) + // right: nothing + assertNumStateRows( + total = Seq(2), updated = Seq(2), + droppedByWatermark = Seq(0), removed = Some(Seq(0))), + AddData(rightInput, (1, 10), (2, 5)), + // Match left row in the state. Matched left row (1, 5) is immediately removed from state + // via getJoinedRowsAndRemoveMatched. + CheckNewAnswer((1, 5)), + // states + // left: (3, 5) -- (1, 5) removed after matching right (1, 10) + // right: (1, 10), (2, 5) + assertNumStateRows( + total = Seq(3), updated = Seq(2), + droppedByWatermark = Seq(0), removed = Some(Seq(1))), + AddData(rightInput, (1, 9)), + // No match as left row (1, 5) was already removed from state. + CheckNewAnswer(), + // states + // left: (3, 5) + // right: (1, 10), (2, 5), (1, 9) + assertNumStateRows( + total = Seq(4), updated = Seq(1), + droppedByWatermark = Seq(0), removed = Some(Seq(0))), + // Increase event time watermark to 20s by adding data with time = 30s on both inputs. + AddData(leftInput, (1, 7), (1, 30)), + CheckNewAnswer((1, 7)), + // states + // left: (3, 5), (1, 30) + // right: (1, 10), (2, 5), (1, 9) + assertNumStateRows( + total = Seq(5), updated = Seq(1), + droppedByWatermark = Seq(0), removed = Some(Seq(0))), + // Watermark = 30 - 10 = 20, no matched row. + AddData(rightInput, (0, 30)), + CheckNewAnswer(), + // states + // left: (1, 30) + // right: (0, 30) + // + // states evicted + // left: (3, 5) (below watermark = 20) + // right: (1, 10), (2, 5), (1, 9) (below watermark = 20) + assertNumStateRows( + total = Seq(2), updated = Seq(1), + droppedByWatermark = Seq(0), removed = Some(Seq(4))) + ) + } + + testWithAppendAndUpdate("self left semi join") { outputMode => + val (inputStream, query) = setupSelfJoin("left_semi") + + testStream(query, outputMode)( + AddData(inputStream, (1, 1L), (2, 2L), (3, 3L), (4, 4L), (5, 5L)), + CheckNewAnswer((2, 2), (4, 4)), + // batch 1 - global watermark = 0 + // states + // left: (none - left 2, 4 matched right in the same batch, emitted without storing) + // (left rows with value % 2 != 0 is filtered per [[PushPredicateThroughJoin]]) + // right: (2, 2L), (4, 4L) + // (right rows with value % 2 != 0 is filtered per [[PushPredicateThroughJoin]]) + assertNumStateRows( + total = Seq(2), updated = Seq(2), + droppedByWatermark = Seq(0), removed = Some(Seq(0))), + AddData(inputStream, (6, 6L), (7, 7L), (8, 8L), (9, 9L), (10, 10L)), + CheckNewAnswer((6, 6), (8, 8), (10, 10)), + // batch 2 - global watermark = 5 + // states + // left: (none - left 6, 8, 10 matched right in the same batch, emitted without storing) + // right: (6, 6L), (8, 8L), (10, 10L) + // + // states evicted + // right: (2, 2L), (4, 4L) + assertNumStateRows( + total = Seq(3), updated = Seq(3), + droppedByWatermark = Seq(0), removed = Some(Seq(2))), + AddData(inputStream, (11, 11L), (12, 12L), (13, 13L), (14, 14L), (15, 15L)), + CheckNewAnswer((12, 12), (14, 14)), + // batch 3 - global watermark = 9 + // states + // left: (none - left 12, 14 matched right in the same batch, emitted without storing) + // right: (10, 10L), (12, 12L), (14, 14L) + // + // states evicted + // right: (6, 6L), (8, 8L) + assertNumStateRows( + total = Seq(3), updated = Seq(2), + droppedByWatermark = Seq(0), removed = Some(Seq(2))) + ) } } From 25390da6f0664ab79946538fe346cf95495e0cd0 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Fri, 10 Apr 2026 08:54:21 +0900 Subject: [PATCH 5/5] Review comments --- .../join/StreamingSymmetricHashJoinExec.scala | 5 +++-- .../sql/streaming/MultiStatefulOperatorsSuite.scala | 12 ++++-------- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala index ddb3e7c862b1e..4346f1096a15f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala @@ -145,7 +145,7 @@ case class StreamingSymmetricHashJoinExec( stateFormatVersion: Int, left: SparkPlan, right: SparkPlan, - outputMode: Option[OutputMode] = None) + outputMode: Option[OutputMode]) extends BinaryExecNode with StateStoreWriter with SchemaValidationUtils { def this( @@ -161,7 +161,8 @@ case class StreamingSymmetricHashJoinExec( leftKeys, rightKeys, joinType, JoinConditionSplitPredicates(condition, left, right), stateInfo = None, eventTimeWatermarkForLateEvents = None, eventTimeWatermarkForEviction = None, - stateWatermarkPredicates = JoinStateWatermarkPredicates(), stateFormatVersion, left, right) + stateWatermarkPredicates = JoinStateWatermarkPredicates(), stateFormatVersion, left, right, + None) } if (stateFormatVersion < 2 && joinType != Inner) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MultiStatefulOperatorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MultiStatefulOperatorsSuite.scala index 977078a649e98..cd901deae8e14 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MultiStatefulOperatorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MultiStatefulOperatorsSuite.scala @@ -944,15 +944,13 @@ class MultiStatefulOperatorsSuite testWithAppendAndUpdate("dedup on both sides -> stream-stream inner join") { outputMode => val input1 = MemoryStream[Int] val inputDF1 = input1.toDF() - .withColumnRenamed("value", "value1") - .withColumn("eventTime1", timestamp_seconds($"value1")) + .select($"value".as("value1"), timestamp_seconds($"value").as("eventTime1")) .withWatermark("eventTime1", "10 seconds") .dropDuplicates("value1", "eventTime1") val input2 = MemoryStream[Int] val inputDF2 = input2.toDF() - .withColumnRenamed("value", "value2") - .withColumn("eventTime2", timestamp_seconds($"value2")) + .select($"value".as("value2"), timestamp_seconds($"value").as("eventTime2")) .withWatermark("eventTime2", "10 seconds") .dropDuplicates("value2", "eventTime2") @@ -971,14 +969,12 @@ class MultiStatefulOperatorsSuite test("stream-stream inner join -> window agg, update mode") { val input1 = MemoryStream[Int] val inputDF1 = input1.toDF() - .withColumnRenamed("value", "value1") - .withColumn("eventTime1", timestamp_seconds($"value1")) + .select($"value".as("value1"), timestamp_seconds($"value").as("eventTime1")) .withWatermark("eventTime1", "0 seconds") val input2 = MemoryStream[Int] val inputDF2 = input2.toDF() - .withColumnRenamed("value", "value2") - .withColumn("eventTime2", timestamp_seconds($"value2")) + .select($"value".as("value2"), timestamp_seconds($"value").as("eventTime2")) .withWatermark("eventTime2", "0 seconds") val stream = inputDF1.join(inputDF2, expr("eventTime1 = eventTime2"), "inner")