Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,7 @@ case class StreamingSymmetricHashJoinExec(
// Unsafe projection created from `joinKeys`, with `inputAttributes` as the input schema.
private[this] val keyGenerator = UnsafeProjection.create(joinKeys, inputAttributes)

private[this] val stateKeyWatermarkPredicateFunc = stateWatermarkPredicate match {
case Some(JoinStateKeyWatermarkPredicate(expr, _)) =>
case Some(JoinStateKeyWatermarkPredicate(expr, _, _)) =>
// inputSchema can be empty as expr should only have BoundReferences and does not require
// the schema to generate the predicate. See [[StreamingSymmetricHashJoinHelper]].
Predicate.create(expr, Seq.empty).eval _
Expand All @@ -672,7 +672,7 @@ case class StreamingSymmetricHashJoinExec(
}

private[this] val stateValueWatermarkPredicateFunc = stateWatermarkPredicate match {
case Some(JoinStateValueWatermarkPredicate(expr, _)) =>
case Some(JoinStateValueWatermarkPredicate(expr, _, _)) =>
Predicate.create(expr, inputAttributes).eval _
case _ =>
Predicate.create(Literal(false), Seq.empty).eval _ // false = do not remove if no predicate
Expand Down Expand Up @@ -893,21 +893,25 @@ case class StreamingSymmetricHashJoinExec(
*/
def removeOldState(): Long = {
stateWatermarkPredicate match {
case Some(JoinStateKeyWatermarkPredicate(_, stateWatermark)) =>
case Some(JoinStateKeyWatermarkPredicate(_, stateWatermark, prevStateWatermark)) =>
joinStateManager match {
case s: SupportsEvictByCondition =>
s.evictByKeyCondition(stateKeyWatermarkPredicateFunc)

case s: SupportsEvictByTimestamp =>
s.evictByTimestamp(watermarkMsToStateTimestamp(stateWatermark))
s.evictByTimestamp(
watermarkMsToStateTimestamp(stateWatermark),
prevStateWatermark.map(watermarkMsToStateTimestamp))
}
case Some(JoinStateValueWatermarkPredicate(_, stateWatermark)) =>
case Some(JoinStateValueWatermarkPredicate(_, stateWatermark, prevStateWatermark)) =>
joinStateManager match {
case s: SupportsEvictByCondition =>
s.evictByValueCondition(stateValueWatermarkPredicateFunc)

case s: SupportsEvictByTimestamp =>
s.evictByTimestamp(watermarkMsToStateTimestamp(stateWatermark))
s.evictByTimestamp(
watermarkMsToStateTimestamp(stateWatermark),
prevStateWatermark.map(watermarkMsToStateTimestamp))
}
case _ => 0L
}
Expand All @@ -925,21 +929,25 @@ case class StreamingSymmetricHashJoinExec(
*/
def removeAndReturnOldState(): Iterator[KeyToValuePair] = {
stateWatermarkPredicate match {
case Some(JoinStateKeyWatermarkPredicate(_, stateWatermark)) =>
case Some(JoinStateKeyWatermarkPredicate(_, stateWatermark, prevStateWatermark)) =>
joinStateManager match {
case s: SupportsEvictByCondition =>
s.evictAndReturnByKeyCondition(stateKeyWatermarkPredicateFunc)

case s: SupportsEvictByTimestamp =>
s.evictAndReturnByTimestamp(watermarkMsToStateTimestamp(stateWatermark))
s.evictAndReturnByTimestamp(
watermarkMsToStateTimestamp(stateWatermark),
prevStateWatermark.map(watermarkMsToStateTimestamp))
}
case Some(JoinStateValueWatermarkPredicate(_, stateWatermark)) =>
case Some(JoinStateValueWatermarkPredicate(_, stateWatermark, prevStateWatermark)) =>
joinStateManager match {
case s: SupportsEvictByCondition =>
s.evictAndReturnByValueCondition(stateValueWatermarkPredicateFunc)

case s: SupportsEvictByTimestamp =>
s.evictAndReturnByTimestamp(watermarkMsToStateTimestamp(stateWatermark))
s.evictAndReturnByTimestamp(
watermarkMsToStateTimestamp(stateWatermark),
prevStateWatermark.map(watermarkMsToStateTimestamp))
}
case _ => Iterator.empty
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,18 @@ object StreamingSymmetricHashJoinHelper extends Logging {
override def toString: String = s"$desc: $expr"
}
/**
 * Predicate for watermark on state keys.
 *
 * @param expr expression from which the join operator builds the key-eviction predicate
 * @param stateWatermark watermark (ms) used as the upper bound when evicting state; for the
 *   key predicate it is populated from the eviction watermark
 * @param prevStateWatermark optional watermark (ms) up to which state is assumed to have been
 *   evicted already; timestamp-based eviction passes it as the lower bound of the scan.
 *   Defaults to None.
 */
case class JoinStateKeyWatermarkPredicate(expr: Expression, stateWatermark: Long)
case class JoinStateKeyWatermarkPredicate(
expr: Expression,
stateWatermark: Long,
prevStateWatermark: Option[Long] = None)
extends JoinStateWatermarkPredicate {
def desc: String = "key predicate"
}
/**
 * Predicate for watermark on state values.
 *
 * @param expr expression from which the join operator builds the value-eviction predicate
 * @param stateWatermark state value watermark (ms) used as the upper bound when evicting
 *   state; derived from the eviction watermark via the join condition
 * @param prevStateWatermark optional state value watermark (ms) up to which state is assumed
 *   to have been evicted already; timestamp-based eviction passes it as the lower bound of
 *   the scan. Defaults to None.
 */
case class JoinStateValueWatermarkPredicate(expr: Expression, stateWatermark: Long)
case class JoinStateValueWatermarkPredicate(
expr: Expression,
stateWatermark: Long,
prevStateWatermark: Option[Long] = None)
extends JoinStateWatermarkPredicate {
def desc: String = "value predicate"
}
Expand Down Expand Up @@ -185,6 +191,7 @@ object StreamingSymmetricHashJoinHelper extends Logging {
rightKeys: Seq[Expression],
condition: Option[Expression],
eventTimeWatermarkForEviction: Option[Long],
eventTimeWatermarkForLateEvents: Option[Long],
useFirstEventTimeColumn: Boolean): JoinStateWatermarkPredicates = {

// Perform assertions against multiple event time columns in the same DataFrame. This method
Expand Down Expand Up @@ -215,20 +222,30 @@ object StreamingSymmetricHashJoinHelper extends Logging {
expr.map { e =>
// watermarkExpression only provides the expression when eventTimeWatermarkForEviction
// is defined
JoinStateKeyWatermarkPredicate(e, eventTimeWatermarkForEviction.get)
JoinStateKeyWatermarkPredicate(
e,
eventTimeWatermarkForEviction.get,
eventTimeWatermarkForLateEvents)
}
} else if (isWatermarkDefinedOnInput) { // case 2 in the StreamingSymmetricHashJoinExec docs
val stateValueWatermark = StreamingJoinHelper.getStateValueWatermark(
attributesToFindStateWatermarkFor = AttributeSet(oneSideInputAttributes),
attributesWithEventWatermark = AttributeSet(otherSideInputAttributes),
condition,
eventTimeWatermarkForEviction)
val prevStateValueWatermark = eventTimeWatermarkForLateEvents.flatMap { _ =>
StreamingJoinHelper.getStateValueWatermark(
attributesToFindStateWatermarkFor = AttributeSet(oneSideInputAttributes),
attributesWithEventWatermark = AttributeSet(otherSideInputAttributes),
condition,
eventTimeWatermarkForLateEvents)
}
val inputAttributeWithWatermark = oneSideInputAttributes.find(_.metadata.contains(delayKey))
val expr = watermarkExpression(inputAttributeWithWatermark, stateValueWatermark)
expr.map { e =>
// watermarkExpression only provides the expression when eventTimeWatermarkForEviction
// is defined
JoinStateValueWatermarkPredicate(e, stateValueWatermark.get)
JoinStateValueWatermarkPredicate(e, stateValueWatermark.get, prevStateValueWatermark)
}
} else {
None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,15 +184,28 @@ trait SupportsEvictByCondition { self: SymmetricHashJoinStateManager =>
trait SupportsEvictByTimestamp { self: SymmetricHashJoinStateManager =>
import SymmetricHashJoinStateManager._

// Both eviction methods operate on the timestamp window (startTimestamp, endTimestamp]:
// the lower bound is exclusive and optional, the upper bound is inclusive.

/** Evict the state by timestamp. Returns the number of values evicted. */
def evictByTimestamp(endTimestamp: Long): Long
/**
* Evict the state by timestamp. Returns the number of values evicted.
*
* @param endTimestamp Inclusive upper bound: evicts entries with timestamp <= endTimestamp.
* @param startTimestamp Exclusive lower bound: entries with timestamp <= startTimestamp are
* assumed to have been evicted already (e.g. from the previous batch). When provided,
* the scan starts from startTimestamp + 1. Defaults to None.
*/
def evictByTimestamp(endTimestamp: Long, startTimestamp: Option[Long] = None): Long

/**
* Evict the state by timestamp and return the evicted key-value pairs.
*
* It is the caller's responsibility to consume the whole iterator.
*
* @param endTimestamp Inclusive upper bound: evicts entries with timestamp <= endTimestamp.
* @param startTimestamp Exclusive lower bound: entries with timestamp <= startTimestamp are
* assumed to have been evicted already (e.g. from the previous batch). When provided,
* the scan starts from startTimestamp + 1. Defaults to None.
*/
def evictAndReturnByTimestamp(endTimestamp: Long): Iterator[KeyToValuePair]
def evictAndReturnByTimestamp(
endTimestamp: Long, startTimestamp: Option[Long] = None): Iterator[KeyToValuePair]
}

/**
Expand Down Expand Up @@ -507,9 +520,9 @@ class SymmetricHashJoinStateManagerV4(
}
}

override def evictByTimestamp(endTimestamp: Long): Long = {
override def evictByTimestamp(endTimestamp: Long, startTimestamp: Option[Long] = None): Long = {
var removed = 0L
tsWithKey.scanEvictedKeys(endTimestamp).foreach { evicted =>
tsWithKey.scanEvictedKeys(endTimestamp, startTimestamp).foreach { evicted =>
val key = evicted.key
val timestamp = evicted.timestamp
val numValues = evicted.numValues
Expand All @@ -523,10 +536,11 @@ class SymmetricHashJoinStateManagerV4(
removed
}

override def evictAndReturnByTimestamp(endTimestamp: Long): Iterator[KeyToValuePair] = {
override def evictAndReturnByTimestamp(
endTimestamp: Long, startTimestamp: Option[Long] = None): Iterator[KeyToValuePair] = {
val reusableKeyToValuePair = KeyToValuePair()

tsWithKey.scanEvictedKeys(endTimestamp).flatMap { evicted =>
tsWithKey.scanEvictedKeys(endTimestamp, startTimestamp).flatMap { evicted =>
val key = evicted.key
val timestamp = evicted.timestamp
val values = keyWithTsToValues.get(key, timestamp)
Expand Down Expand Up @@ -647,14 +661,28 @@ class SymmetricHashJoinStateManagerV4(

/**
* Returns entries where minTs <= timestamp <= maxTs (both inclusive), grouped by timestamp.
* Skips entries before minTs and stops iterating past maxTs (timestamps are sorted).
* When maxTs is bounded (< Long.MaxValue), uses rangeScanWithMultiValues for efficient
* range access; falls back to prefixScan otherwise to stay within the key's scope.
*
* When prefixScan is used (maxTs == Long.MaxValue), entries outside [minTs, maxTs] are
* filtered out so both code paths produce identical results.
*/
def getValuesInRange(
key: UnsafeRow, minTs: Long, maxTs: Long): Iterator[GetValuesResult] = {
val reusableGetValuesResult = new GetValuesResult()
// Only use rangeScan when maxTs < Long.MaxValue, since rangeScan requires
// an exclusive end key (maxTs + 1) which would overflow at Long.MaxValue.
val useRangeScan = maxTs < Long.MaxValue

new NextIterator[GetValuesResult] {
private val iter = stateStore.prefixScanWithMultiValues(key, colFamilyName)
private val iter = if (useRangeScan) {
val startKey = createKeyRow(key, minTs).copy()
// rangeScanWithMultiValues endKey is exclusive, so use maxTs + 1
val endKey = Some(createKeyRow(key, maxTs + 1))
stateStore.rangeScanWithMultiValues(Some(startKey), endKey, colFamilyName)
} else {
stateStore.prefixScanWithMultiValues(key, colFamilyName)
}

private var currentTs = -1L
private var pastUpperBound = false
Expand Down Expand Up @@ -757,6 +785,8 @@ class SymmetricHashJoinStateManagerV4(
isInternal = true
)

// Returns an UnsafeRow backed by a reused projection buffer. Callers that need to
// hold the row beyond the immediate state store call must invoke copy() on the result.
private def createKeyRow(key: UnsafeRow, timestamp: Long): UnsafeRow = {
TimestampKeyStateEncoder.attachTimestamp(
attachTimestampProjection, keySchemaWithTimestamp, key, timestamp)
Expand All @@ -772,9 +802,63 @@ class SymmetricHashJoinStateManagerV4(

/**
 * A single entry yielded by the eviction scan.
 *
 * @param key the join key of the entry
 * @param timestamp the timestamp the entry is stored under
 * @param numValues the value count reported for this entry
 */
case class EvictedKeysResult(
    key: UnsafeRow,
    timestamp: Long,
    numValues: Int)

// NOTE: This assumes we consume the whole iterator to trigger completion.
def scanEvictedKeys(endTimestamp: Long): Iterator[EvictedKeysResult] = {
val evictIterator = stateStore.iteratorWithMultiValues(colFamilyName)
/** Builds an [[InternalRow]] holding the default literal value for every field of `schema`. */
private def defaultInternalRow(schema: StructType): InternalRow = {
  val fieldDefaults = schema.map { field =>
    Literal.default(field.dataType).value
  }
  InternalRow.fromSeq(fieldDefaults)
}

/**
 * Key row filled with default values, shared by all scan-boundary constructions.
 *
 * Sharing one instance is safe: createKeyRow only reads from this row (through
 * BoundReference evaluations) and writes its output into the projection's own buffer.
 */
private lazy val defaultKey: UnsafeRow = {
  val toUnsafe = UnsafeProjection.create(keySchema)
  toUnsafe(defaultInternalRow(keySchema))
}

/**
 * Creates a boundary row for rangeScan at the given timestamp.
 *
 * TsWithKeyTypeStore uses TimestampAsPrefixKeyStateEncoder, which lays the row out as
 * [timestamp][key_fields]. The encoder expects every key column to be present, so a
 * full-schema row is built rather than a timestamp-only one; the key columns carry
 * default values because only the timestamp matters for ordering in the prefix encoder.
 *
 * The result is a copy, so it remains valid after createKeyRow's buffer is reused.
 */
private def createScanBoundaryRow(timestamp: Long): UnsafeRow = {
  val boundary = createKeyRow(defaultKey, timestamp)
  boundary.copy()
}

/**
* Scan keys eligible for eviction within the timestamp range.
*
* This assumes we consume the whole iterator to trigger completion.
*
* @param endTimestamp Inclusive upper bound: entries with timestamp <= endTimestamp are
* eligible for eviction.
* @param startTimestamp Exclusive lower bound: entries with timestamp <= startTimestamp
* are assumed to have been evicted already. The scan starts from startTimestamp + 1.
*/
def scanEvictedKeys(
endTimestamp: Long,
startTimestamp: Option[Long] = None): Iterator[EvictedKeysResult] = {
// If startTimestamp == Long.MaxValue, everything has already been evicted;
// nothing can match, so return immediately.
if (startTimestamp.contains(Long.MaxValue)) {
return Iterator.empty
}

// rangeScanWithMultiValues: startKey is inclusive, endKey is exclusive.
// startTimestamp is exclusive (already evicted), so we seek from st + 1.
val startKeyRow = startTimestamp.map { st =>
createScanBoundaryRow(st + 1)
}
// endTimestamp is inclusive, so we use endTimestamp + 1 as the exclusive upper bound.
// When endTimestamp == Long.MaxValue we cannot add 1, so endKeyRow is None. This is
// safe because rangeScanWithMultiValues with no end key uses the column-family prefix
// as the upper bound, naturally scoping the scan within this column family.
val endKeyRow = if (endTimestamp < Long.MaxValue) {
Some(createScanBoundaryRow(endTimestamp + 1))
} else {
None
}
val evictIterator = stateStore.rangeScanWithMultiValues(startKeyRow, endKeyRow, colFamilyName)
new NextIterator[EvictedKeysResult]() {
var currentKeyRow: UnsafeRow = null
var currentEventTime: Long = -1L
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -528,13 +528,19 @@ class IncrementalExecution(
case j: StreamingSymmetricHashJoinExec =>
val iwLateEvents = inputWatermarkForLateEvents(j.stateInfo.get)
val iwEviction = inputWatermarkForEviction(j.stateInfo.get)
// Only use the late-events watermark as the scan lower bound when a previous
// batch actually existed. In the very first batch the watermark propagation
// yields Some(0) even though no state has been evicted yet, which would
// incorrectly skip entries at timestamp 0.
val prevBatchLateEventsWm =
if (prevOffsetSeqMetadata.isDefined) iwLateEvents else None
j.copy(
eventTimeWatermarkForLateEvents = iwLateEvents,
eventTimeWatermarkForEviction = iwEviction,
stateWatermarkPredicates =
StreamingSymmetricHashJoinHelper.getStateWatermarkPredicates(
j.left.output, j.right.output, j.leftKeys, j.rightKeys, j.condition.full,
iwEviction, !allowMultipleStatefulOperators)
iwEviction, prevBatchLateEventsWm, !allowMultipleStatefulOperators)
)
}
}
Expand Down
Loading