diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala
index 1a61972f0ed0d..0dab85b205bad 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala
@@ -100,13 +100,16 @@ private[sql] trait StatefulProcessor[K, I, O] extends Serializable {
  */
 @Experimental
 @Evolving
-trait StatefulProcessorWithInitialState[K, I, O, S] extends StatefulProcessor[K, I, O] {
+private[sql] trait StatefulProcessorWithInitialState[K, I, O, S]
+  extends StatefulProcessor[K, I, O] {
 
   /**
    * Function that will be invoked only in the first batch for users to process initial states.
    *
    * @param key - grouping key
    * @param initialState - A row in the initial state to be processed
+   * @param timerValues - instance of TimerValues that provides access to current processing/event
+   *                      time if available
    */
-  def handleInitialState(key: K, initialState: S): Unit
+  def handleInitialState(key: K, initialState: S, timerValues: TimerValues): Unit
 }
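[Reviewer note] To make the API change concrete: with TimerValues threaded into handleInitialState, a processor can arm timers for keys that only ever appear in the initial state. Below is a minimal sketch, illustrative and not part of this patch; SeedCountProcessor is a made-up name, and it assumes the init(outputMode, timeoutMode)/ValueState plumbing the test helpers further down rely on. Note the patch also narrows the trait to private[sql], so such processors live inside the sql package, like the suite classes below.

import org.apache.spark.sql.Encoders
import org.apache.spark.sql.streaming._

// Illustrative only: seed a running count from the initial state and use the
// newly available TimerValues to arm a processing-time timer, so a key that is
// present only in the initial state can still be expired later.
class SeedCountProcessor
  extends StatefulProcessorWithInitialState[String, String, (String, String), Long] {

  @transient private var _countState: ValueState[Long] = _

  override def init(outputMode: OutputMode, timeoutMode: TimeoutMode): Unit = {
    _countState = getHandle.getValueState[Long]("countState", Encoders.scalaLong)
  }

  override def handleInitialState(
      key: String,
      initialState: Long,
      timerValues: TimerValues): Unit = {
    _countState.update(initialState)
    // Previously impossible: handleInitialState had no TimerValues argument.
    getHandle.registerTimer(timerValues.getCurrentProcessingTimeInMs() + 5000)
  }

  override def handleInputRows(
      key: String,
      inputRows: Iterator[String],
      timerValues: TimerValues,
      expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String)] = {
    if (expiredTimerInfo.isValid()) {
      _countState.clear()
      Iterator((key, "-1"))
    } else {
      val count = _countState.getOption().getOrElse(0L) + inputRows.size
      _countState.update(count)
      Iterator((key, count.toString))
    }
  }
}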
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
index 36b957f9d4307..66c19fa22304e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
@@ -172,7 +172,8 @@ case class TransformWithStateExec(
         seenInitStateOnKey = true
         statefulProcessor
           .asInstanceOf[StatefulProcessorWithInitialState[Any, Any, Any, Any]]
-          .handleInitialState(keyObj, initState)
+          .handleInitialState(keyObj, initState,
+            new TimerValuesImpl(batchTimestampMs, eventTimeWatermarkForEviction))
       }
       ImplicitGroupingKeyTracker.removeImplicitKey()
     }
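[Reviewer note] The executor change above is what gives the handler a live clock: batchTimestampMs backs the processing-time side of TimerValuesImpl and eventTimeWatermarkForEviction backs the watermark side. A hedged sketch of using the watermark during initial-state handling follows; it assumes TimerValues.getCurrentWatermarkInMs() as the watermark counterpart of the getCurrentProcessingTimeInMs() call used in the tests, and InitialStateTimers is a made-up helper.

import org.apache.spark.sql.streaming.{StatefulProcessorHandle, TimerValues}

// Illustrative helper, not part of this patch: register an event-time timer
// strictly above the current watermark so it cannot fire immediately.
// getCurrentWatermarkInMs() is assumed to be the watermark-side accessor
// backed by eventTimeWatermarkForEviction in the executor change above.
object InitialStateTimers {
  def registerAboveWatermark(
      handle: StatefulProcessorHandle,
      timerValues: TimerValues,
      desiredTimerMs: Long): Unit = {
    val watermarkMs = timerValues.getCurrentWatermarkInMs()
    handle.registerTimer(math.max(desiredTimerMs, watermarkMs + 1))
  }
}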
AddData(inputData, "b"), + AdvanceManualClock(8 * 1000), // ts = 16000, timer for "a" expired + CheckNewAnswer(("a", "-1"), ("b", "2")), // initial state for "b" is also processed + StopStream + ) + } + } + + test("transformWithStateWithInitialState - streaming with event time timer, " + + "can emit expired initial state rows when grouping key is not received for new input rows") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName) { + + def eventTimeDf(ds: Dataset[(String, Int)]): KeyValueGroupedDataset[String, (String, Long)] = + ds.select($"_1".as("key"), timestamp_seconds($"_2").as("eventTime")) + .withWatermark("eventTime", "10 seconds") + .as[(String, Long)].groupByKey(_._1) + + val initDf = eventTimeDf(Seq(("a", 11), ("b", 5)).toDS()) + + val inputData = MemoryStream[(String, Int)] + val result = eventTimeDf(inputData.toDS()) + .transformWithState( + new StatefulProcessorWithInitialStateEventTimerClass(), + TimeoutMode.EventTime(), + OutputMode.Update(), + initDf) + + testStream(result, OutputMode.Update())( + StartStream(), + AddData(inputData, ("a", 13), ("a", 15)), + CheckNewAnswer(("a", 15)), // Output = max event time of a + + AddData(inputData, ("c", 7)), // Add data older than watermark for "a" and "b" + CheckNewAnswer(("c", 7)), + + AddData(inputData, ("d", 31)), // Add data newer than watermark + // "b" is never received for new input rows, but emit here because it expired + CheckNewAnswer(("a", -1), ("b", -1), ("c", -1), ("d", 31)), + StopStream + ) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index 24e68e3db9d8e..2fd1eac179daa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -115,6 +115,27 @@ class RunningCountStatefulProcessorWithProcTimeTimerUpdates Iterator((key, "-1")) } + protected def processUnexpiredRows( + key: String, + currCount: Long, + count: Long, + timerValues: TimerValues): Unit = { + _countState.update(count) + if (key == "a") { + var nextTimerTs: Long = 0L + if (currCount == 0) { + nextTimerTs = timerValues.getCurrentProcessingTimeInMs() + 5000 + getHandle.registerTimer(nextTimerTs) + _timerState.update(nextTimerTs) + } else if (currCount == 1) { + getHandle.deleteTimer(_timerState.get()) + nextTimerTs = timerValues.getCurrentProcessingTimeInMs() + 7500 + getHandle.registerTimer(nextTimerTs) + _timerState.update(nextTimerTs) + } + } + } + override def handleInputRows( key: String, inputRows: Iterator[String], @@ -125,20 +146,7 @@ class RunningCountStatefulProcessorWithProcTimeTimerUpdates } else { val currCount = _countState.getOption().getOrElse(0L) val count = currCount + inputRows.size - _countState.update(count) - if (key == "a") { - var nextTimerTs: Long = 0L - if (currCount == 0) { - nextTimerTs = timerValues.getCurrentProcessingTimeInMs() + 5000 - getHandle.registerTimer(nextTimerTs) - _timerState.update(nextTimerTs) - } else if (currCount == 1) { - getHandle.deleteTimer(_timerState.get()) - nextTimerTs = timerValues.getCurrentProcessingTimeInMs() + 7500 - getHandle.registerTimer(nextTimerTs) - _timerState.update(nextTimerTs) - } - } + processUnexpiredRows(key, currCount, count, timerValues) Iterator((key, count.toString)) } } @@ -192,12 +200,24 @@ class MaxEventTimeStatefulProcessor _timerState = 
getHandle.getValueState[Long]("timerState", Encoders.scalaLong) } + protected def processUnexpiredRows(maxEventTimeSec: Long): Unit = { + val timeoutDelaySec = 5 + val timeoutTimestampMs = (maxEventTimeSec + timeoutDelaySec) * 1000 + _maxEventTimeState.update(maxEventTimeSec) + + val registeredTimerMs: Long = _timerState.getOption().getOrElse(0L) + if (registeredTimerMs < timeoutTimestampMs) { + getHandle.deleteTimer(registeredTimerMs) + getHandle.registerTimer(timeoutTimestampMs) + _timerState.update(timeoutTimestampMs) + } + } + override def handleInputRows( key: String, inputRows: Iterator[(String, Long)], timerValues: TimerValues, expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, Int)] = { - val timeoutDelaySec = 5 if (expiredTimerInfo.isValid()) { _maxEventTimeState.clear() Iterator((key, -1)) @@ -205,15 +225,7 @@ class MaxEventTimeStatefulProcessor val valuesSeq = inputRows.toSeq val maxEventTimeSec = math.max(valuesSeq.map(_._2).max, _maxEventTimeState.getOption().getOrElse(0L)) - val timeoutTimestampMs = (maxEventTimeSec + timeoutDelaySec) * 1000 - _maxEventTimeState.update(maxEventTimeSec) - - val registeredTimerMs: Long = _timerState.getOption().getOrElse(0L) - if (registeredTimerMs < timeoutTimestampMs) { - getHandle.deleteTimer(registeredTimerMs) - getHandle.registerTimer(timeoutTimestampMs) - _timerState.update(timeoutTimestampMs) - } + processUnexpiredRows(maxEventTimeSec) Iterator((key, maxEventTimeSec.toInt)) } }