diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 0187f985770ba..15b851a78d629 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -3824,6 +3824,12 @@
     ],
     "sqlState" : "42802"
   },
+  "STATEFUL_PROCESSOR_DUPLICATE_STATE_VARIABLE_DEFINED" : {
+    "message" : [
+      "State variable with name <stateVarName> has already been defined in the StatefulProcessor."
+    ],
+    "sqlState" : "42802"
+  },
   "STATEFUL_PROCESSOR_INCORRECT_TIME_MODE_TO_ASSIGN_TTL" : {
     "message" : [
       "Cannot use TTL for state=<stateName> in timeMode=<timeMode>, use TimeMode.ProcessingTime() instead."
@@ -3873,12 +3879,24 @@
     ],
     "sqlState" : "42802"
   },
+  "STATE_STORE_INVALID_CONFIG_AFTER_RESTART" : {
+    "message" : [
+      "Cannot change <configName> from <oldConfig> to <newConfig> between restarts. Please set <configName> to <oldConfig>, or restart with a new checkpoint directory."
+    ],
+    "sqlState" : "42K06"
+  },
   "STATE_STORE_INVALID_PROVIDER" : {
     "message" : [
       "The given State Store Provider <inputClass> does not extend org.apache.spark.sql.execution.streaming.state.StateStoreProvider."
     ],
     "sqlState" : "42K06"
   },
+  "STATE_STORE_INVALID_VARIABLE_TYPE_CHANGE" : {
+    "message" : [
+      "Cannot change <stateVarName> to <newType> between query restarts. Please set <stateVarName> to <oldType>, or restart with a new checkpoint directory."
+    ],
+    "sqlState" : "42K06"
+  },
   "STATE_STORE_KEY_ROW_FORMAT_VALIDATION_FAILURE" : {
     "message" : [
       "The streaming query failed to validate written state for key row.",
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index 567fb1b98f14c..c09f9fa7b994b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -37,7 +37,7 @@ import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadat
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
 import org.apache.spark.sql.execution.python.FlatMapGroupsInPandasWithStateExec
 import org.apache.spark.sql.execution.streaming.sources.WriteToMicroBatchDataSourceV1
-import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadataV1, OperatorStateMetadataV2, OperatorStateMetadataWriter}
+import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadataReader, OperatorStateMetadataV1, OperatorStateMetadataV2, OperatorStateMetadataWriter}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.util.{SerializableConfiguration, Utils}
@@ -213,6 +213,27 @@ class IncrementalExecution(
       statefulOp match {
         case ssw: StateStoreWriter =>
           val metadata = ssw.operatorStateMetadata(stateSchemaPaths)
+          // validate metadata
+          if (isFirstBatch && currentBatchId != 0) {
+            // If we are restarting from a different checkpoint directory,
+            // there may be a mismatch between the stateful operators in the
+            // physical plan and the metadata.
+            val oldMetadata = try {
+              OperatorStateMetadataReader.createReader(
+                new Path(checkpointLocation, ssw.getStateInfo.operatorId.toString),
+                hadoopConf, ssw.operatorStateMetadataVersion).read()
+            } catch {
+              case e: Exception =>
+                logWarning(log"Error reading metadata path for stateful operator. This " +
This " + + log"may due to no prior committed batch, or previously run on lower " + + log"versions: ${MDC(ERROR, e.getMessage)}") + None + } + oldMetadata match { + case Some(oldMetadata) => ssw.validateNewMetadata(oldMetadata, metadata) + case None => + } + } val metadataWriter = OperatorStateMetadataWriter.createWriter( new Path(checkpointLocation, ssw.getStateInfo.operatorId.toString), hadoopConf, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala index 44e2e6838f4a3..942d395dec0e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala @@ -301,17 +301,32 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi extends StatefulProcessorHandleImplBase(timeMode, keyExprEnc) { // Because this is only happening on the driver side, there is only - // one task modifying and accessing this map at a time + // one task modifying and accessing these maps at a time private[sql] val columnFamilySchemas: mutable.Map[String, StateStoreColFamilySchema] = new mutable.HashMap[String, StateStoreColFamilySchema]() + private val stateVariableInfos: mutable.Map[String, TransformWithStateVariableInfo] = + new mutable.HashMap[String, TransformWithStateVariableInfo]() + def getColumnFamilySchemas: Map[String, StateStoreColFamilySchema] = columnFamilySchemas.toMap + def getStateVariableInfos: Map[String, TransformWithStateVariableInfo] = stateVariableInfos.toMap + + private def checkIfDuplicateVariableDefined(stateVarName: String): Unit = { + if (columnFamilySchemas.contains(stateVarName)) { + throw StateStoreErrors.duplicateStateVariableDefined(stateVarName) + } + } + override def getValueState[T](stateName: String, valEncoder: Encoder[T]): ValueState[T] = { verifyStateVarOperations("get_value_state", PRE_INIT) val colFamilySchema = StateStoreColumnFamilySchemaUtils. getValueStateSchema(stateName, keyExprEnc, valEncoder, false) + checkIfDuplicateVariableDefined(stateName) columnFamilySchemas.put(stateName, colFamilySchema) + val stateVariableInfo = TransformWithStateVariableUtils. + getValueState(stateName, ttlEnabled = false) + stateVariableInfos.put(stateName, stateVariableInfo) null.asInstanceOf[ValueState[T]] } @@ -322,7 +337,11 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi verifyStateVarOperations("get_value_state", PRE_INIT) val colFamilySchema = StateStoreColumnFamilySchemaUtils. getValueStateSchema(stateName, keyExprEnc, valEncoder, true) + checkIfDuplicateVariableDefined(stateName) columnFamilySchemas.put(stateName, colFamilySchema) + val stateVariableInfo = TransformWithStateVariableUtils. + getValueState(stateName, ttlEnabled = true) + stateVariableInfos.put(stateName, stateVariableInfo) null.asInstanceOf[ValueState[T]] } @@ -330,7 +349,11 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi verifyStateVarOperations("get_list_state", PRE_INIT) val colFamilySchema = StateStoreColumnFamilySchemaUtils. getListStateSchema(stateName, keyExprEnc, valEncoder, false) + checkIfDuplicateVariableDefined(stateName) columnFamilySchemas.put(stateName, colFamilySchema) + val stateVariableInfo = TransformWithStateVariableUtils. 
+      getListState(stateName, ttlEnabled = false)
+    stateVariableInfos.put(stateName, stateVariableInfo)
     null.asInstanceOf[ListState[T]]
   }
 
@@ -341,7 +364,11 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
     verifyStateVarOperations("get_list_state", PRE_INIT)
     val colFamilySchema = StateStoreColumnFamilySchemaUtils.
       getListStateSchema(stateName, keyExprEnc, valEncoder, true)
+    checkIfDuplicateVariableDefined(stateName)
     columnFamilySchemas.put(stateName, colFamilySchema)
+    val stateVariableInfo = TransformWithStateVariableUtils.
+      getListState(stateName, ttlEnabled = true)
+    stateVariableInfos.put(stateName, stateVariableInfo)
     null.asInstanceOf[ListState[T]]
   }
 
@@ -352,7 +379,11 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
     verifyStateVarOperations("get_map_state", PRE_INIT)
     val colFamilySchema = StateStoreColumnFamilySchemaUtils.
       getMapStateSchema(stateName, keyExprEnc, userKeyEnc, valEncoder, false)
+    checkIfDuplicateVariableDefined(stateName)
     columnFamilySchemas.put(stateName, colFamilySchema)
+    val stateVariableInfo = TransformWithStateVariableUtils.
+      getMapState(stateName, ttlEnabled = false)
+    stateVariableInfos.put(stateName, stateVariableInfo)
     null.asInstanceOf[MapState[K, V]]
   }
 
@@ -365,6 +396,9 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
     val colFamilySchema = StateStoreColumnFamilySchemaUtils.
       getMapStateSchema(stateName, keyExprEnc, userKeyEnc, valEncoder, true)
     columnFamilySchemas.put(stateName, colFamilySchema)
+    val stateVariableInfo = TransformWithStateVariableUtils.
+      getMapState(stateName, ttlEnabled = true)
+    stateVariableInfos.put(stateName, stateVariableInfo)
     null.asInstanceOf[MapState[K, V]]
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
index d2b8f92aa985b..acdb8372a67d1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
@@ -21,10 +21,6 @@ import java.util.concurrent.TimeUnit.NANOSECONDS
 
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
-import org.json4s.JsonAST.JValue
-import org.json4s.JsonDSL._
-import org.json4s.JString
-import org.json4s.jackson.JsonMethods.{compact, render}
 
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
@@ -125,6 +121,12 @@ case class TransformWithStateExec(
     columnFamilySchemas
   }
 
+  private def getStateVariableInfos(): Map[String, TransformWithStateVariableInfo] = {
+    val stateVariableInfos = getDriverProcessorHandle().getStateVariableInfos
+    closeProcessorHandle()
+    stateVariableInfos
+  }
+
   /**
    * This method is used for the driver-side stateful processor after we
    * have collected all the necessary schemas.
@@ -423,12 +425,12 @@ case class TransformWithStateExec(
       Array(StateStoreMetadataV2(
         StateStoreId.DEFAULT_STORE_NAME, 0, info.numPartitions, stateSchemaPaths.head))
 
-    val operatorPropertiesJson: JValue =
-      ("timeMode" -> JString(timeMode.toString)) ~
-      ("outputMode" -> JString(outputMode.toString))
-
-    val json = compact(render(operatorPropertiesJson))
-    OperatorStateMetadataV2(operatorInfo, stateStoreInfo, json)
+    val operatorProperties = TransformWithStateOperatorProperties(
+      timeMode.toString,
+      outputMode.toString,
+      getStateVariableInfos().values.toList
+    )
+    OperatorStateMetadataV2(operatorInfo, stateStoreInfo, operatorProperties.json)
   }
 
   private def stateSchemaDirPath(): Path = {
@@ -441,6 +443,23 @@ case class TransformWithStateExec(
     new Path(new Path(storeNamePath, "_metadata"), "schema")
   }
 
+  override def validateNewMetadata(
+      oldOperatorMetadata: OperatorStateMetadata,
+      newOperatorMetadata: OperatorStateMetadata): Unit = {
+    (oldOperatorMetadata, newOperatorMetadata) match {
+      case (
+        oldMetadataV2: OperatorStateMetadataV2,
+        newMetadataV2: OperatorStateMetadataV2) =>
+        val oldOperatorProps = TransformWithStateOperatorProperties.fromJson(
+          oldMetadataV2.operatorPropertiesJson)
+        val newOperatorProps = TransformWithStateOperatorProperties.fromJson(
+          newMetadataV2.operatorPropertiesJson)
+        TransformWithStateOperatorProperties.validateOperatorProperties(
+          oldOperatorProps, newOperatorProps)
+      case (_, _) =>
+    }
+  }
+
   override protected def doExecute(): RDD[InternalRow] = {
     metrics // force lazy init at driver
 
@@ -666,4 +685,3 @@ object TransformWithStateExec {
   }
 }
 // scalastyle:on argcount
-
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateVariableUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateVariableUtils.scala
new file mode 100644
index 0000000000000..0a32564f973a3
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateVariableUtils.scala
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.streaming
+
+import org.json4s.DefaultFormats
+import org.json4s.JsonAST._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods
+import org.json4s.jackson.JsonMethods.{compact, render}
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.execution.streaming.StateVariableType.StateVariableType
+import org.apache.spark.sql.execution.streaming.state.StateStoreErrors
+
+/**
+ * This file contains utility classes and functions for managing state variables in
+ * the operatorProperties field of the OperatorStateMetadata for TransformWithState.
+ * We use these utils to read and write state variable information for validation purposes.
+ */
+object TransformWithStateVariableUtils {
+  def getValueState(stateName: String, ttlEnabled: Boolean): TransformWithStateVariableInfo = {
+    TransformWithStateVariableInfo(stateName, StateVariableType.ValueState, ttlEnabled)
+  }
+
+  def getListState(stateName: String, ttlEnabled: Boolean): TransformWithStateVariableInfo = {
+    TransformWithStateVariableInfo(stateName, StateVariableType.ListState, ttlEnabled)
+  }
+
+  def getMapState(stateName: String, ttlEnabled: Boolean): TransformWithStateVariableInfo = {
+    TransformWithStateVariableInfo(stateName, StateVariableType.MapState, ttlEnabled)
+  }
+}
+
+// Enum of possible State Variable types
+object StateVariableType extends Enumeration {
+  type StateVariableType = Value
+  val ValueState, ListState, MapState = Value
+}
+
+case class TransformWithStateVariableInfo(
+    stateName: String,
+    stateVariableType: StateVariableType,
+    ttlEnabled: Boolean) {
+  def jsonValue: JValue = {
+    ("stateName" -> JString(stateName)) ~
+      ("stateVariableType" -> JString(stateVariableType.toString)) ~
+      ("ttlEnabled" -> JBool(ttlEnabled))
+  }
+
+  def json: String = {
+    compact(render(jsonValue))
+  }
+}
+
+object TransformWithStateVariableInfo {
+
+  def fromJson(json: String): TransformWithStateVariableInfo = {
+    implicit val formats: DefaultFormats.type = DefaultFormats
+    val parsed = JsonMethods.parse(json).extract[Map[String, Any]]
+    fromMap(parsed)
+  }
+
+  def fromMap(map: Map[String, Any]): TransformWithStateVariableInfo = {
+    val stateName = map("stateName").asInstanceOf[String]
+    val stateVariableType = StateVariableType.withName(
+      map("stateVariableType").asInstanceOf[String])
+    val ttlEnabled = map("ttlEnabled").asInstanceOf[Boolean]
+    TransformWithStateVariableInfo(stateName, stateVariableType, ttlEnabled)
+  }
+}
+
+case class TransformWithStateOperatorProperties(
+    timeMode: String,
+    outputMode: String,
+    stateVariables: List[TransformWithStateVariableInfo]) {
+
+  def json: String = {
+    val stateVariablesJson = stateVariables.map(_.jsonValue)
+    val json =
+      ("timeMode" -> timeMode) ~
+        ("outputMode" -> outputMode) ~
+        ("stateVariables" -> stateVariablesJson)
+    compact(render(json))
+  }
+}
+
+object TransformWithStateOperatorProperties extends Logging {
+  def fromJson(json: String): TransformWithStateOperatorProperties = {
+    implicit val formats: DefaultFormats.type = DefaultFormats
+    val jsonMap = JsonMethods.parse(json).extract[Map[String, Any]]
+    TransformWithStateOperatorProperties(
+      jsonMap("timeMode").asInstanceOf[String],
+      jsonMap("outputMode").asInstanceOf[String],
+      jsonMap("stateVariables").asInstanceOf[List[Map[String, Any]]].map { stateVarMap =>
+        TransformWithStateVariableInfo.fromMap(stateVarMap)
+      }
+    )
+  }
+
+  // This function is to confirm that the operator properties and state variables have
+  // only changed in an acceptable way after query restart. If the properties have changed
+  // in an unacceptable way, this function will throw an exception.
+  def validateOperatorProperties(
+      oldOperatorProperties: TransformWithStateOperatorProperties,
+      newOperatorProperties: TransformWithStateOperatorProperties): Unit = {
+    if (oldOperatorProperties.timeMode != newOperatorProperties.timeMode) {
+      throw StateStoreErrors.invalidConfigChangedAfterRestart(
+        "timeMode", oldOperatorProperties.timeMode, newOperatorProperties.timeMode)
+    }
+
+    if (oldOperatorProperties.outputMode != newOperatorProperties.outputMode) {
+      throw StateStoreErrors.invalidConfigChangedAfterRestart(
+        "outputMode", oldOperatorProperties.outputMode, newOperatorProperties.outputMode)
+    }
+
+    val oldStateVariableInfos = oldOperatorProperties.stateVariables
+    val newStateVariableInfos = newOperatorProperties.stateVariables.map { stateVarInfo =>
+      stateVarInfo.stateName -> stateVarInfo
+    }.toMap
+    oldStateVariableInfos.foreach { oldInfo =>
+      val newInfo = newStateVariableInfos.get(oldInfo.stateName)
+      newInfo match {
+        case Some(stateVarInfo) =>
+          if (oldInfo.stateVariableType != stateVarInfo.stateVariableType) {
+            throw StateStoreErrors.invalidVariableTypeChange(
+              stateVarInfo.stateName,
+              oldInfo.stateVariableType.toString,
+              stateVarInfo.stateVariableType.toString
+            )
+          }
+        case None =>
+      }
+    }
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
index 4ac813291c00b..e4b370e67b018 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
@@ -173,8 +173,51 @@ object StateStoreErrors {
       StateStoreProviderDoesNotSupportFineGrainedReplay = {
     new StateStoreProviderDoesNotSupportFineGrainedReplay(inputClass)
   }
+
+  def invalidConfigChangedAfterRestart(configName: String, oldConfig: String, newConfig: String):
+    StateStoreInvalidConfigAfterRestart = {
+    new StateStoreInvalidConfigAfterRestart(configName, oldConfig, newConfig)
+  }
+
+  def duplicateStateVariableDefined(stateName: String):
+    StateStoreDuplicateStateVariableDefined = {
+    new StateStoreDuplicateStateVariableDefined(stateName)
+  }
+
+  def invalidVariableTypeChange(stateName: String, oldType: String, newType: String):
+    StateStoreInvalidVariableTypeChange = {
+    new StateStoreInvalidVariableTypeChange(stateName, oldType, newType)
+  }
 }
 
+class StateStoreDuplicateStateVariableDefined(stateVarName: String)
+  extends SparkRuntimeException(
+    errorClass = "STATEFUL_PROCESSOR_DUPLICATE_STATE_VARIABLE_DEFINED",
+    messageParameters = Map(
+      "stateVarName" -> stateVarName
+    )
+  )
+
+class StateStoreInvalidConfigAfterRestart(configName: String, oldConfig: String, newConfig: String)
+  extends SparkUnsupportedOperationException(
+    errorClass = "STATE_STORE_INVALID_CONFIG_AFTER_RESTART",
+    messageParameters = Map(
+      "configName" -> configName,
+      "oldConfig" -> oldConfig,
+      "newConfig" -> newConfig
+    )
+  )
+
+class StateStoreInvalidVariableTypeChange(stateVarName: String, oldType: String, newType: String)
+  extends SparkUnsupportedOperationException(
+    errorClass = "STATE_STORE_INVALID_VARIABLE_TYPE_CHANGE",
+    messageParameters = Map(
+      "stateVarName" -> stateVarName,
+      "oldType" -> oldType,
+      "newType" -> newType
+    )
+  )
+
 class StateStoreMultipleColumnFamiliesNotSupportedException(stateStoreProvider: String)
   extends SparkUnsupportedOperationException(
     errorClass = "UNSUPPORTED_FEATURE.STATE_STORE_MULTIPLE_COLUMN_FAMILIES",
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
index 43d75c4b4d137..6937bbadfab5a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -276,6 +276,10 @@ trait StateStoreWriter extends StatefulOperator with PythonSQLMetrics { self: Sp
   /** Name to output in [[StreamingOperatorProgress]] to identify operator type */
   def shortName: String = "defaultName"
 
+  def validateNewMetadata(
+      oldMetadata: OperatorStateMetadata,
+      newMetadata: OperatorStateMetadata): Unit = {}
+
   /**
    * Should the MicroBatchExecution run another batch based on this stateful operator and the
    * new input watermark.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
index d55a16a60eac0..5b91e0de1903f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
@@ -18,11 +18,12 @@
 package org.apache.spark.sql.streaming
 
 import java.io.File
+import java.time.Duration
 import java.util.UUID
 
 import org.apache.hadoop.fs.Path
 
-import org.apache.spark.SparkRuntimeException
+import org.apache.spark.{SparkRuntimeException, SparkUnsupportedOperationException}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{Dataset, Encoders, Row}
 import org.apache.spark.sql.catalyst.util.stringToFile
@@ -63,6 +64,57 @@ class RunningCountStatefulProcessor extends StatefulProcessor[String, String, (S
   }
 }
 
+class RunningCountStatefulProcessorWithTTL
+  extends StatefulProcessor[String, String, (String, String)]
+  with Logging {
+  @transient protected var _countState: ValueState[Long] = _
+
+  override def init(
+      outputMode: OutputMode,
+      timeMode: TimeMode): Unit = {
+    _countState = getHandle.getValueState[Long]("countState",
+      Encoders.scalaLong, TTLConfig(Duration.ofMillis(1000)))
+  }
+
+  override def handleInputRows(
+      key: String,
+      inputRows: Iterator[String],
+      timerValues: TimerValues,
+      expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String)] = {
+    val count = _countState.getOption().getOrElse(0L) + 1
+    if (count == 3) {
+      _countState.clear()
+      Iterator.empty
+    } else {
+      _countState.update(count)
+      Iterator((key, count.toString))
+    }
+  }
+}
+
+// Class to test that changing between Value and List State fails
+// between query runs
+class RunningCountListStatefulProcessor
+  extends StatefulProcessor[String, String, (String, String)]
+  with Logging {
+  @transient protected var _countState: ListState[Long] = _
+
+  override def init(
+      outputMode: OutputMode,
+      timeMode: TimeMode): Unit = {
+    _countState = getHandle.getListState[Long](
+      "countState", Encoders.scalaLong)
+  }
+
+  override def handleInputRows(
+      key: String,
+      inputRows: Iterator[String],
+      timerValues: TimerValues,
+      expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String)] = {
+    Iterator.empty
+  }
+}
+
 class RunningCountStatefulProcessorInt
   extends StatefulProcessor[String, String, (String, String)] {
   @transient protected var _countState: ValueState[Int] = _
@@ -939,10 +991,20 @@ class TransformWithStateSuite extends StateStoreMetricsTest
       )
 
       val df = spark.read.format("state-metadata").load(checkpointDir.toString)
-      checkAnswer(df, Seq(
-        Row(0, "transformWithStateExec", "default", 5, 0L, 1L,
-          """{"timeMode":"NoTime","outputMode":"Update"}""")
-      ))
+
+      // check first 6 columns of the row, and then read the last column of the row separately
+      checkAnswer(
+        df.select(
+          "operatorId", "operatorName", "stateStoreName", "numPartitions", "minBatchId",
+          "maxBatchId"),
+        Seq(Row(0, "transformWithStateExec", "default", 5, 0, 1))
+      )
+      val operatorPropsJson = df.select("operatorProperties").collect().head.getString(0)
+      val operatorProperties = TransformWithStateOperatorProperties.fromJson(operatorPropsJson)
+      assert(operatorProperties.timeMode == "NoTime")
+      assert(operatorProperties.outputMode == "Update")
+      assert(operatorProperties.stateVariables.length == 1)
+      assert(operatorProperties.stateVariables.head.stateName == "countState")
     }
   }
 }
@@ -983,6 +1045,192 @@ class TransformWithStateSuite extends StateStoreMetricsTest
     }
   }
 }
+
+  test("test that different outputMode after query restart fails") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName,
+      SQLConf.SHUFFLE_PARTITIONS.key ->
+        TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) {
+      withTempDir { checkpointDir =>
+        val inputData = MemoryStream[String]
+        val result1 = inputData.toDS()
+          .groupByKey(x => x)
+          .transformWithState(new RunningCountStatefulProcessor(),
+            TimeMode.None(),
+            OutputMode.Update())
+
+        testStream(result1, OutputMode.Update())(
+          StartStream(checkpointLocation = checkpointDir.getCanonicalPath),
+          AddData(inputData, "a"),
+          CheckNewAnswer(("a", "1")),
+          StopStream
+        )
+        val result2 = inputData.toDS()
+          .groupByKey(x => x)
+          .transformWithState(new RunningCountStatefulProcessor(),
+            TimeMode.None(),
+            OutputMode.Append())
+
+        testStream(result2, OutputMode.Update())(
+          StartStream(checkpointLocation = checkpointDir.getCanonicalPath),
+          AddData(inputData, "a"),
+          ExpectFailure[StateStoreInvalidConfigAfterRestart] { e =>
+            checkError(
+              e.asInstanceOf[SparkUnsupportedOperationException],
+              errorClass = "STATE_STORE_INVALID_CONFIG_AFTER_RESTART",
+              parameters = Map(
+                "configName" -> "outputMode",
+                "oldConfig" -> "Update",
+                "newConfig" -> "Append")
+            )
+          }
+        )
+      }
+    }
+  }
+
+  test("test that changing between different state variable types fails") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName,
+      SQLConf.SHUFFLE_PARTITIONS.key ->
+        TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) {
+      withTempDir { checkpointDir =>
+        val inputData = MemoryStream[String]
+        val result = inputData.toDS()
+          .groupByKey(x => x)
+          .transformWithState(new RunningCountStatefulProcessor(),
+            TimeMode.None(),
+            OutputMode.Update())
+
+        testStream(result, OutputMode.Update())(
+          StartStream(checkpointLocation = checkpointDir.getCanonicalPath),
+          AddData(inputData, "a"),
+          CheckNewAnswer(("a", "1")),
+          StopStream
+        )
+        val result2 = inputData.toDS()
+          .groupByKey(x => x)
+          .transformWithState(new RunningCountListStatefulProcessor(),
+            TimeMode.None(),
+            OutputMode.Update())
+        testStream(result2, OutputMode.Update())(
+          StartStream(checkpointLocation = checkpointDir.getCanonicalPath),
+          AddData(inputData, "a"),
+          ExpectFailure[StateStoreInvalidVariableTypeChange] { t =>
+            checkError(
+              t.asInstanceOf[SparkUnsupportedOperationException],
+              errorClass = "STATE_STORE_INVALID_VARIABLE_TYPE_CHANGE",
+              parameters = Map(
+                "stateVarName" -> "countState",
+                "newType" -> "ListState",
+                "oldType" -> "ValueState")
+            )
+          }
+        )
+      }
+    }
+  }
+
test("test that different timeMode after query restart fails") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + withTempDir { checkpointDir => + val clock = new StreamManualClock + val inputData = MemoryStream[String] + val result1 = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessor(), + TimeMode.None(), + OutputMode.Update()) + + testStream(result1, OutputMode.Update())( + StartStream(checkpointLocation = checkpointDir.getCanonicalPath), + AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), + StopStream + ) + val result2 = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessor(), + TimeMode.ProcessingTime(), + OutputMode.Update()) + testStream(result2, OutputMode.Update())( + StartStream( + checkpointLocation = checkpointDir.getCanonicalPath, + trigger = Trigger.ProcessingTime("1 second"), + triggerClock = clock), + AddData(inputData, "a"), + AdvanceManualClock(1 * 1000), + ExpectFailure[StateStoreInvalidConfigAfterRestart] { e => + checkError( + e.asInstanceOf[SparkUnsupportedOperationException], + errorClass = "STATE_STORE_INVALID_CONFIG_AFTER_RESTART", + parameters = Map( + "configName" -> "timeMode", + "oldConfig" -> "NoTime", + "newConfig" -> "ProcessingTime") + ) + } + ) + } + } + } + + test("test that introducing TTL after restart fails query") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + withTempDir { checkpointDir => + val inputData = MemoryStream[String] + val clock = new StreamManualClock + val result = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessor(), + TimeMode.ProcessingTime(), + OutputMode.Update()) + + testStream(result, OutputMode.Update())( + StartStream( + trigger = Trigger.ProcessingTime("1 second"), + checkpointLocation = checkpointDir.getCanonicalPath, + triggerClock = clock), + AddData(inputData, "a"), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(("a", "1")), + AdvanceManualClock(1 * 1000), + StopStream + ) + val result2 = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessorWithTTL(), + TimeMode.ProcessingTime(), + OutputMode.Update()) + testStream(result2, OutputMode.Update())( + StartStream( + trigger = Trigger.ProcessingTime("1 second"), + checkpointLocation = checkpointDir.getCanonicalPath, + triggerClock = clock), + AddData(inputData, "a"), + AdvanceManualClock(1 * 1000), + ExpectFailure[StateStoreValueSchemaNotCompatible] { t => + checkError( + t.asInstanceOf[SparkUnsupportedOperationException], + errorClass = "STATE_STORE_VALUE_SCHEMA_NOT_COMPATIBLE", + parameters = Map( + "storedValueSchema" -> "StructType(StructField(value,LongType,false))", + "newValueSchema" -> + ("StructType(StructField(value,StructType(StructField(value,LongType,false))," + + "true),StructField(ttlExpirationMs,LongType,true))") + ) + ) + } + ) + } + } + } } class TransformWithStateValidationSuite extends StateStoreMetricsTest {