From 2c2a2adc327564de69cbe8c491034ebacb2d2c77 Mon Sep 17 00:00:00 2001 From: jingz-db Date: Wed, 3 Apr 2024 12:21:29 +0900 Subject: [PATCH] [SPARK-47655][SS] Integrate timer with Initial State handling for state-v2 ### What changes were proposed in this pull request? We want to support timer feature during initial state handling. Currently, if grouping key introduced in initial state is never seen again in new input rows, those rows in initial state will never be emitted. We want to add a field in initial state handling function such that users can manipulate the timers with initial state rows. ### Why are the changes needed? These changes are needed to support timer with initial state. The changes are part of the work around adding new stateful streaming operator for arbitrary state mgmt that provides a bunch of new features listed in the SPIP JIRA here - https://issues.apache.org/jira/browse/SPARK-45939 ### Does this PR introduce _any_ user-facing change? Yes. We add a new `timerValues` field in `StatefulProcessorWithInitialState`: ``` /** * Function that will be invoked only in the first batch for users to process initial states. * * param key - grouping key * param initialState - A row in the initial state to be processed * param timerValues - instance of TimerValues that provides access to current processing/event * time if available */ def handleInitialState(key: K, initialState: S, timerValues: TimerValues): Unit ``` ### How was this patch tested? Add unit tests in `TransformWithStateWithInitialStateSuite` ### Was this patch authored or co-authored using generative AI tooling? No. Closes #45780 from jingz-db/timer-initState-integration. Authored-by: jingz-db Signed-off-by: Jungtaek Lim --- .../sql/streaming/StatefulProcessor.scala | 7 +- .../streaming/TransformWithStateExec.scala | 3 +- .../TransformWithStateInitialStateSuite.scala | 124 +++++++++++++++++- .../streaming/TransformWithStateSuite.scala | 60 +++++---- 4 files changed, 164 insertions(+), 30 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala index 1a61972f0ed0d..0dab85b205bad 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala @@ -100,13 +100,16 @@ private[sql] trait StatefulProcessor[K, I, O] extends Serializable { */ @Experimental @Evolving -trait StatefulProcessorWithInitialState[K, I, O, S] extends StatefulProcessor[K, I, O] { +private[sql] trait StatefulProcessorWithInitialState[K, I, O, S] + extends StatefulProcessor[K, I, O] { /** * Function that will be invoked only in the first batch for users to process initial states. * * @param key - grouping key * @param initialState - A row in the initial state to be processed + * @param timerValues - instance of TimerValues that provides access to current processing/event + * time if available */ - def handleInitialState(key: K, initialState: S): Unit + def handleInitialState(key: K, initialState: S, timerValues: TimerValues): Unit } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index 36b957f9d4307..66c19fa22304e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -172,7 +172,8 @@ case class TransformWithStateExec( seenInitStateOnKey = true statefulProcessor .asInstanceOf[StatefulProcessorWithInitialState[Any, Any, Any, Any]] - .handleInitialState(keyObj, initState) + .handleInitialState(keyObj, initState, + new TimerValuesImpl(batchTimestampMs, eventTimeWatermarkForEviction)) } ImplicitGroupingKeyTracker.removeImplicitKey() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala index 9f2e2c2d9f029..1c40788011c6d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala @@ -18,10 +18,12 @@ package org.apache.spark.sql.streaming import org.apache.spark.SparkUnsupportedOperationException -import org.apache.spark.sql.{Encoders, KeyValueGroupedDataset} +import org.apache.spark.sql.{Dataset, Encoders, KeyValueGroupedDataset} import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider} +import org.apache.spark.sql.functions.timestamp_seconds import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.streaming.util.StreamManualClock case class InitInputRow(key: String, action: String, value: Double) case class InputRowForInitialState( @@ -86,7 +88,8 @@ class AccumulateStatefulProcessorWithInitState extends StatefulProcessorWithInitialStateTestClass[(String, Double)] { override def handleInitialState( key: String, - initialState: (String, Double)): Unit = { + initialState: (String, Double), + timerValues: TimerValues): Unit = { _valState.update(initialState._2) } @@ -115,7 +118,8 @@ class InitialStateInMemoryTestClass extends StatefulProcessorWithInitialStateTestClass[InputRowForInitialState] { override def handleInitialState( key: String, - initialState: InputRowForInitialState): Unit = { + initialState: InputRowForInitialState, + timerValues: TimerValues): Unit = { _valState.update(initialState.value) _listState.appendList(initialState.entries.toArray) val inMemoryMap = initialState.mapping @@ -125,6 +129,45 @@ class InitialStateInMemoryTestClass } } +/** + * Class to test stateful processor with initial state and processing timers. + * Timers can be registered during initial state handling and output rows can be + * emitted after timer expires even if the grouping key in initial state is not + * seen in new input rows. + */ +class StatefulProcessorWithInitialStateProcTimerClass + extends RunningCountStatefulProcessorWithProcTimeTimerUpdates + with StatefulProcessorWithInitialState[String, String, (String, String), String] { + override def handleInitialState( + key: String, + initialState: String, + timerValues: TimerValues): Unit = { + // keep a _countState to count the occurrence of grouping key + // Will register a timer if key "a" is met for the first time + processUnexpiredRows(key, 0L, 1L, timerValues) + } +} + +/** + * Class to test stateful processor with initial state and event timers. + * Timers can be registered during initial state handling and output rows can be + * emitted after timer expires even if the grouping key in initial state is not + * seen in new input rows. + */ +class StatefulProcessorWithInitialStateEventTimerClass extends MaxEventTimeStatefulProcessor + with StatefulProcessorWithInitialState[String, (String, Long), (String, Int), (String, Long)]{ + override def handleInitialState( + key: String, + initialState: (String, Long), + timerValues: TimerValues): Unit = { + // keep a _maxEventTimeState to track the max eventTime seen so far + // register a timer if bigger eventTime is seen + val maxEventTimeSec = math.max(initialState._2, + _maxEventTimeState.getOption().getOrElse(0L)) + processUnexpiredRows(maxEventTimeSec) + } +} + /** * Class that adds tests for transformWithState stateful * streaming operator with user-defined initial state @@ -290,4 +333,79 @@ class TransformWithStateInitialStateSuite extends StateStoreMetricsTest ) } } + + test("transformWithStateWithInitialState - streaming with processing time timer, " + + "can emit expired initial state rows when grouping key is not received for new input rows") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName) { + val clock = new StreamManualClock + + val initDf = Seq("a", "b").toDS().groupByKey(x => x) + val inputData = MemoryStream[String] + val result = inputData.toDS().groupByKey(x => x) + .transformWithState( + new StatefulProcessorWithInitialStateProcTimerClass(), + TimeoutMode.ProcessingTime(), + OutputMode.Update(), + initDf) + + testStream(result, OutputMode.Update())( + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), + AddData(inputData, "c"), + AdvanceManualClock(1 * 1000), + // registered timer for "a" and "b" is 6000, first batch is processed at ts = 1000 + CheckNewAnswer(("c", "1")), + + AddData(inputData, "c"), + AdvanceManualClock(6 * 1000), // ts = 7000, "a" expires + // "a" is never received for new input rows, but emit here because it expires + CheckNewAnswer(("a", "-1"), ("c", "2")), + + AddData(inputData, "a", "a"), + AdvanceManualClock(1 * 1000), // ts = 8000, register timer for "a" as ts = 15500 + CheckNewAnswer(("a", "3")), + + AddData(inputData, "b"), + AdvanceManualClock(8 * 1000), // ts = 16000, timer for "a" expired + CheckNewAnswer(("a", "-1"), ("b", "2")), // initial state for "b" is also processed + StopStream + ) + } + } + + test("transformWithStateWithInitialState - streaming with event time timer, " + + "can emit expired initial state rows when grouping key is not received for new input rows") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName) { + + def eventTimeDf(ds: Dataset[(String, Int)]): KeyValueGroupedDataset[String, (String, Long)] = + ds.select($"_1".as("key"), timestamp_seconds($"_2").as("eventTime")) + .withWatermark("eventTime", "10 seconds") + .as[(String, Long)].groupByKey(_._1) + + val initDf = eventTimeDf(Seq(("a", 11), ("b", 5)).toDS()) + + val inputData = MemoryStream[(String, Int)] + val result = eventTimeDf(inputData.toDS()) + .transformWithState( + new StatefulProcessorWithInitialStateEventTimerClass(), + TimeoutMode.EventTime(), + OutputMode.Update(), + initDf) + + testStream(result, OutputMode.Update())( + StartStream(), + AddData(inputData, ("a", 13), ("a", 15)), + CheckNewAnswer(("a", 15)), // Output = max event time of a + + AddData(inputData, ("c", 7)), // Add data older than watermark for "a" and "b" + CheckNewAnswer(("c", 7)), + + AddData(inputData, ("d", 31)), // Add data newer than watermark + // "b" is never received for new input rows, but emit here because it expired + CheckNewAnswer(("a", -1), ("b", -1), ("c", -1), ("d", 31)), + StopStream + ) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index 24e68e3db9d8e..2fd1eac179daa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -115,6 +115,27 @@ class RunningCountStatefulProcessorWithProcTimeTimerUpdates Iterator((key, "-1")) } + protected def processUnexpiredRows( + key: String, + currCount: Long, + count: Long, + timerValues: TimerValues): Unit = { + _countState.update(count) + if (key == "a") { + var nextTimerTs: Long = 0L + if (currCount == 0) { + nextTimerTs = timerValues.getCurrentProcessingTimeInMs() + 5000 + getHandle.registerTimer(nextTimerTs) + _timerState.update(nextTimerTs) + } else if (currCount == 1) { + getHandle.deleteTimer(_timerState.get()) + nextTimerTs = timerValues.getCurrentProcessingTimeInMs() + 7500 + getHandle.registerTimer(nextTimerTs) + _timerState.update(nextTimerTs) + } + } + } + override def handleInputRows( key: String, inputRows: Iterator[String], @@ -125,20 +146,7 @@ class RunningCountStatefulProcessorWithProcTimeTimerUpdates } else { val currCount = _countState.getOption().getOrElse(0L) val count = currCount + inputRows.size - _countState.update(count) - if (key == "a") { - var nextTimerTs: Long = 0L - if (currCount == 0) { - nextTimerTs = timerValues.getCurrentProcessingTimeInMs() + 5000 - getHandle.registerTimer(nextTimerTs) - _timerState.update(nextTimerTs) - } else if (currCount == 1) { - getHandle.deleteTimer(_timerState.get()) - nextTimerTs = timerValues.getCurrentProcessingTimeInMs() + 7500 - getHandle.registerTimer(nextTimerTs) - _timerState.update(nextTimerTs) - } - } + processUnexpiredRows(key, currCount, count, timerValues) Iterator((key, count.toString)) } } @@ -192,12 +200,24 @@ class MaxEventTimeStatefulProcessor _timerState = getHandle.getValueState[Long]("timerState", Encoders.scalaLong) } + protected def processUnexpiredRows(maxEventTimeSec: Long): Unit = { + val timeoutDelaySec = 5 + val timeoutTimestampMs = (maxEventTimeSec + timeoutDelaySec) * 1000 + _maxEventTimeState.update(maxEventTimeSec) + + val registeredTimerMs: Long = _timerState.getOption().getOrElse(0L) + if (registeredTimerMs < timeoutTimestampMs) { + getHandle.deleteTimer(registeredTimerMs) + getHandle.registerTimer(timeoutTimestampMs) + _timerState.update(timeoutTimestampMs) + } + } + override def handleInputRows( key: String, inputRows: Iterator[(String, Long)], timerValues: TimerValues, expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, Int)] = { - val timeoutDelaySec = 5 if (expiredTimerInfo.isValid()) { _maxEventTimeState.clear() Iterator((key, -1)) @@ -205,15 +225,7 @@ class MaxEventTimeStatefulProcessor val valuesSeq = inputRows.toSeq val maxEventTimeSec = math.max(valuesSeq.map(_._2).max, _maxEventTimeState.getOption().getOrElse(0L)) - val timeoutTimestampMs = (maxEventTimeSec + timeoutDelaySec) * 1000 - _maxEventTimeState.update(maxEventTimeSec) - - val registeredTimerMs: Long = _timerState.getOption().getOrElse(0L) - if (registeredTimerMs < timeoutTimestampMs) { - getHandle.deleteTimer(registeredTimerMs) - getHandle.registerTimer(timeoutTimestampMs) - _timerState.update(timeoutTimestampMs) - } + processUnexpiredRows(maxEventTimeSec) Iterator((key, maxEventTimeSec.toInt)) } }