[SPARK-47655][SS] Integrate timer with Initial State handling for state-v2

### What changes were proposed in this pull request?

We want to support the timer feature during initial state handling. Currently, if a grouping key introduced in the initial state is never seen again in new input rows, the rows seeded from the initial state will never be emitted. We add a parameter to the initial state handling function so that users can manipulate timers for initial state rows.

### Why are the changes needed?

These changes are needed to support timers with initial state. They are part of the work on adding a new stateful streaming operator for arbitrary state management, which provides the new features listed in the SPIP JIRA: https://issues.apache.org/jira/browse/SPARK-45939

### Does this PR introduce _any_ user-facing change?

Yes. We add a new `timerValues` parameter to `handleInitialState` in `StatefulProcessorWithInitialState`:
```scala
  /**
   * Function that will be invoked only in the first batch for users to process initial states.
   *
   * @param key - grouping key
   * @param initialState - A row in the initial state to be processed
   * @param timerValues - instance of TimerValues that provides access to current processing/event
   *                      time if available
   */
  def handleInitialState(key: K, initialState: S, timerValues: TimerValues): Unit
```
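
For illustration, a minimal sketch of how a processor might use the new parameter (the class name, the `Double`-typed initial state, the `_valState` variable, and the 5-second delay are hypothetical; `getHandle.registerTimer` and `timerValues.getCurrentProcessingTimeInMs` are the handle/timer APIs exercised elsewhere in this PR):

```scala
// Hypothetical sketch, not part of this PR: seed per-key state from the initial
// state row, then register a processing-time timer so this key can still emit
// output even if it never appears again in new input rows.
class InitialStateTimerProcessor
  extends StatefulProcessorWithInitialState[String, String, (String, String), Double] {

  override def handleInitialState(
      key: String,
      initialState: Double,
      timerValues: TimerValues): Unit = {
    // _valState is a ValueState[Double] assumed to be created in init()
    _valState.update(initialState)
    getHandle.registerTimer(timerValues.getCurrentProcessingTimeInMs() + 5000)
  }

  // init() and handleInputRows() omitted for brevity
}
```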

### How was this patch tested?

Added unit tests in `TransformWithStateInitialStateSuite`.
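
At the query level, the new tests wire the processor and the initial state dataset together along these lines (condensed from the processing-time test in the suite below):

```scala
val initDf = Seq("a", "b").toDS().groupByKey(x => x)
val inputData = MemoryStream[String]
val result = inputData.toDS().groupByKey(x => x)
  .transformWithState(
    new StatefulProcessorWithInitialStateProcTimerClass(),
    TimeoutMode.ProcessingTime(),
    OutputMode.Update(),
    initDf)
```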

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#45780 from jingz-db/timer-initState-integration.

Authored-by: jingz-db <jing.zhan@databricks.com>
Signed-off-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
jingz-db authored and HeartSaVioR committed Apr 3, 2024
1 parent 49b7b6b commit 2c2a2ad
Showing 4 changed files with 164 additions and 30 deletions.
@@ -100,13 +100,16 @@ private[sql] trait StatefulProcessor[K, I, O] extends Serializable {
*/
@Experimental
@Evolving
-trait StatefulProcessorWithInitialState[K, I, O, S] extends StatefulProcessor[K, I, O] {
+private[sql] trait StatefulProcessorWithInitialState[K, I, O, S]
+    extends StatefulProcessor[K, I, O] {

/**
* Function that will be invoked only in the first batch for users to process initial states.
*
* @param key - grouping key
* @param initialState - A row in the initial state to be processed
+ * @param timerValues - instance of TimerValues that provides access to current processing/event
+ *                      time if available
*/
-  def handleInitialState(key: K, initialState: S): Unit
+  def handleInitialState(key: K, initialState: S, timerValues: TimerValues): Unit
}
@@ -172,7 +172,8 @@ case class TransformWithStateExec(
seenInitStateOnKey = true
statefulProcessor
.asInstanceOf[StatefulProcessorWithInitialState[Any, Any, Any, Any]]
-  .handleInitialState(keyObj, initState)
+  .handleInitialState(keyObj, initState,
+    new TimerValuesImpl(batchTimestampMs, eventTimeWatermarkForEviction))
}
ImplicitGroupingKeyTracker.removeImplicitKey()
}
@@ -18,10 +18,12 @@
package org.apache.spark.sql.streaming

import org.apache.spark.SparkUnsupportedOperationException
-import org.apache.spark.sql.{Encoders, KeyValueGroupedDataset}
+import org.apache.spark.sql.{Dataset, Encoders, KeyValueGroupedDataset}
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider}
+import org.apache.spark.sql.functions.timestamp_seconds
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.util.StreamManualClock

case class InitInputRow(key: String, action: String, value: Double)
case class InputRowForInitialState(
@@ -86,7 +88,8 @@ class AccumulateStatefulProcessorWithInitState
extends StatefulProcessorWithInitialStateTestClass[(String, Double)] {
override def handleInitialState(
key: String,
-      initialState: (String, Double)): Unit = {
+      initialState: (String, Double),
+      timerValues: TimerValues): Unit = {
_valState.update(initialState._2)
}

@@ -115,7 +118,8 @@ class InitialStateInMemoryTestClass
extends StatefulProcessorWithInitialStateTestClass[InputRowForInitialState] {
override def handleInitialState(
key: String,
-      initialState: InputRowForInitialState): Unit = {
+      initialState: InputRowForInitialState,
+      timerValues: TimerValues): Unit = {
_valState.update(initialState.value)
_listState.appendList(initialState.entries.toArray)
val inMemoryMap = initialState.mapping
@@ -125,6 +129,45 @@
}
}

+/**
+ * Class to test stateful processor with initial state and processing-time timers.
+ * Timers can be registered during initial state handling, and output rows can be
+ * emitted after a timer expires even if the grouping key in the initial state is
+ * never seen again in new input rows.
+ */
+class StatefulProcessorWithInitialStateProcTimerClass
+  extends RunningCountStatefulProcessorWithProcTimeTimerUpdates
+  with StatefulProcessorWithInitialState[String, String, (String, String), String] {
+  override def handleInitialState(
+      key: String,
+      initialState: String,
+      timerValues: TimerValues): Unit = {
+    // keep a _countState to count the occurrences of the grouping key;
+    // a timer will be registered when key "a" is seen for the first time
+    processUnexpiredRows(key, 0L, 1L, timerValues)
+  }
+}
+
+/**
+ * Class to test stateful processor with initial state and event-time timers.
+ * Timers can be registered during initial state handling, and output rows can be
+ * emitted after a timer expires even if the grouping key in the initial state is
+ * never seen again in new input rows.
+ */
+class StatefulProcessorWithInitialStateEventTimerClass
+  extends MaxEventTimeStatefulProcessor
+  with StatefulProcessorWithInitialState[String, (String, Long), (String, Int), (String, Long)] {
+  override def handleInitialState(
+      key: String,
+      initialState: (String, Long),
+      timerValues: TimerValues): Unit = {
+    // keep a _maxEventTimeState to track the max event time seen so far;
+    // register a timer if a larger event time is seen
+    val maxEventTimeSec = math.max(initialState._2,
+      _maxEventTimeState.getOption().getOrElse(0L))
+    processUnexpiredRows(maxEventTimeSec)
+  }
+}

/**
* Class that adds tests for transformWithState stateful
* streaming operator with user-defined initial state
@@ -290,4 +333,79 @@ class TransformWithStateInitialStateSuite extends StateStoreMetricsTest
)
}
}
+
+  test("transformWithStateWithInitialState - streaming with processing time timer, " +
+    "can emit expired initial state rows when grouping key is not received for new input rows") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName) {
+      val clock = new StreamManualClock
+
+      val initDf = Seq("a", "b").toDS().groupByKey(x => x)
+      val inputData = MemoryStream[String]
+      val result = inputData.toDS().groupByKey(x => x)
+        .transformWithState(
+          new StatefulProcessorWithInitialStateProcTimerClass(),
+          TimeoutMode.ProcessingTime(),
+          OutputMode.Update(),
+          initDf)
+
+      testStream(result, OutputMode.Update())(
+        StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock),
+        AddData(inputData, "c"),
+        AdvanceManualClock(1 * 1000),
+        // registered timer for "a" and "b" is 6000, first batch is processed at ts = 1000
+        CheckNewAnswer(("c", "1")),
+
+        AddData(inputData, "c"),
+        AdvanceManualClock(6 * 1000), // ts = 7000, "a" expires
+        // "a" is never received in new input rows, but is emitted here because its timer expired
+        CheckNewAnswer(("a", "-1"), ("c", "2")),
+
+        AddData(inputData, "a", "a"),
+        AdvanceManualClock(1 * 1000), // ts = 8000, register timer for "a" as ts = 15500
+        CheckNewAnswer(("a", "3")),
+
+        AddData(inputData, "b"),
+        AdvanceManualClock(8 * 1000), // ts = 16000, timer for "a" expired
+        CheckNewAnswer(("a", "-1"), ("b", "2")), // initial state for "b" is also processed
+        StopStream
+      )
+    }
+  }
+
+  test("transformWithStateWithInitialState - streaming with event time timer, " +
+    "can emit expired initial state rows when grouping key is not received for new input rows") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName) {
+
+      def eventTimeDf(ds: Dataset[(String, Int)]): KeyValueGroupedDataset[String, (String, Long)] =
+        ds.select($"_1".as("key"), timestamp_seconds($"_2").as("eventTime"))
+          .withWatermark("eventTime", "10 seconds")
+          .as[(String, Long)].groupByKey(_._1)
+
+      val initDf = eventTimeDf(Seq(("a", 11), ("b", 5)).toDS())
+
+      val inputData = MemoryStream[(String, Int)]
+      val result = eventTimeDf(inputData.toDS())
+        .transformWithState(
+          new StatefulProcessorWithInitialStateEventTimerClass(),
+          TimeoutMode.EventTime(),
+          OutputMode.Update(),
+          initDf)
+
+      testStream(result, OutputMode.Update())(
+        StartStream(),
+        AddData(inputData, ("a", 13), ("a", 15)),
+        CheckNewAnswer(("a", 15)), // Output = max event time of "a"
+
+        AddData(inputData, ("c", 7)), // Add data older than watermark for "a" and "b"
+        CheckNewAnswer(("c", 7)),
+
+        AddData(inputData, ("d", 31)), // Add data newer than watermark
+        // "b" is never received in new input rows, but is emitted here because its timer expired
+        CheckNewAnswer(("a", -1), ("b", -1), ("c", -1), ("d", 31)),
+        StopStream
+      )
+    }
+  }
}
@@ -115,6 +115,27 @@ class RunningCountStatefulProcessorWithProcTimeTimerUpdates
Iterator((key, "-1"))
}

+  protected def processUnexpiredRows(
+      key: String,
+      currCount: Long,
+      count: Long,
+      timerValues: TimerValues): Unit = {
+    _countState.update(count)
+    if (key == "a") {
+      var nextTimerTs: Long = 0L
+      if (currCount == 0) {
+        nextTimerTs = timerValues.getCurrentProcessingTimeInMs() + 5000
+        getHandle.registerTimer(nextTimerTs)
+        _timerState.update(nextTimerTs)
+      } else if (currCount == 1) {
+        getHandle.deleteTimer(_timerState.get())
+        nextTimerTs = timerValues.getCurrentProcessingTimeInMs() + 7500
+        getHandle.registerTimer(nextTimerTs)
+        _timerState.update(nextTimerTs)
+      }
+    }
+  }

override def handleInputRows(
key: String,
inputRows: Iterator[String],
@@ -125,20 +146,7 @@
} else {
val currCount = _countState.getOption().getOrElse(0L)
val count = currCount + inputRows.size
-      _countState.update(count)
-      if (key == "a") {
-        var nextTimerTs: Long = 0L
-        if (currCount == 0) {
-          nextTimerTs = timerValues.getCurrentProcessingTimeInMs() + 5000
-          getHandle.registerTimer(nextTimerTs)
-          _timerState.update(nextTimerTs)
-        } else if (currCount == 1) {
-          getHandle.deleteTimer(_timerState.get())
-          nextTimerTs = timerValues.getCurrentProcessingTimeInMs() + 7500
-          getHandle.registerTimer(nextTimerTs)
-          _timerState.update(nextTimerTs)
-        }
-      }
+      processUnexpiredRows(key, currCount, count, timerValues)
Iterator((key, count.toString))
}
}
@@ -192,28 +200,32 @@ class MaxEventTimeStatefulProcessor
_timerState = getHandle.getValueState[Long]("timerState", Encoders.scalaLong)
}

+  protected def processUnexpiredRows(maxEventTimeSec: Long): Unit = {
+    val timeoutDelaySec = 5
+    val timeoutTimestampMs = (maxEventTimeSec + timeoutDelaySec) * 1000
+    _maxEventTimeState.update(maxEventTimeSec)
+
+    val registeredTimerMs: Long = _timerState.getOption().getOrElse(0L)
+    if (registeredTimerMs < timeoutTimestampMs) {
+      getHandle.deleteTimer(registeredTimerMs)
+      getHandle.registerTimer(timeoutTimestampMs)
+      _timerState.update(timeoutTimestampMs)
+    }
+  }

override def handleInputRows(
key: String,
inputRows: Iterator[(String, Long)],
timerValues: TimerValues,
expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, Int)] = {
-    val timeoutDelaySec = 5
if (expiredTimerInfo.isValid()) {
_maxEventTimeState.clear()
Iterator((key, -1))
} else {
val valuesSeq = inputRows.toSeq
val maxEventTimeSec = math.max(valuesSeq.map(_._2).max,
_maxEventTimeState.getOption().getOrElse(0L))
-      val timeoutTimestampMs = (maxEventTimeSec + timeoutDelaySec) * 1000
-      _maxEventTimeState.update(maxEventTimeSec)
-
-      val registeredTimerMs: Long = _timerState.getOption().getOrElse(0L)
-      if (registeredTimerMs < timeoutTimestampMs) {
-        getHandle.deleteTimer(registeredTimerMs)
-        getHandle.registerTimer(timeoutTimestampMs)
-        _timerState.update(timeoutTimestampMs)
-      }
+      processUnexpiredRows(maxEventTimeSec)
Iterator((key, maxEventTimeSec.toInt))
}
}