diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 492a33c57461..723efbea715e 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -5493,6 +5493,11 @@ "message" : [ "Unsupported offset sequence version <version>. Please make sure the checkpoint is from a supported Spark version (Spark 4.0+)." ] + }, + "UNSUPPORTED_PROVIDER" : { + "message" : [ + "<provider> is not supported" + ] } }, "sqlState" : "55019" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala index 6af418e1ddc2..f49ced7a1c22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala @@ -41,7 +41,8 @@ import org.apache.spark.sql.execution.streaming.operators.stateful.transformwith import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.timers.TimerStateUtils import org.apache.spark.sql.execution.streaming.runtime.StreamingCheckpointConstants.DIR_NAME_STATE import org.apache.spark.sql.execution.streaming.runtime.StreamingQueryCheckpointMetadata -import org.apache.spark.sql.execution.streaming.state.{InMemoryStateSchemaProvider, KeyStateEncoderSpec, NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, StateSchemaCompatibilityChecker, StateSchemaMetadata, StateSchemaProvider, StateStore, StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreProviderId} +import org.apache.spark.sql.execution.streaming.state.{InMemoryStateSchemaProvider, KeyStateEncoderSpec, NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, RocksDBStateStoreProvider, StateSchemaCompatibilityChecker, StateSchemaMetadata, StateSchemaProvider, StateStore, StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreProviderId} +import org.apache.spark.sql.execution.streaming.state.OfflineStateRepartitionErrors import org.apache.spark.sql.execution.streaming.utils.StreamingUtils import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.streaming.TimeMode @@ -66,6 +67,14 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging val sourceOptions = StateSourceOptions.modifySourceOptions(hadoopConf, StateSourceOptions.apply(session, hadoopConf, properties)) val stateConf = buildStateStoreConf(sourceOptions.resolvedCpLocation, sourceOptions.batchId) + // We only support RocksDB because the repartition work that this option + // is built for only supports RocksDB + if (sourceOptions.internalOnlyReadAllColumnFamilies + && stateConf.providerClass != classOf[RocksDBStateStoreProvider].getName) { + throw OfflineStateRepartitionErrors.unsupportedStateStoreProviderError( + sourceOptions.resolvedCpLocation, + stateConf.providerClass) + } val stateStoreReaderInfo: StateStoreReaderInfo = getStoreMetadataAndRunChecks( sourceOptions) @@ -372,6 +381,7 @@ case class StateSourceOptions( stateVarName: Option[String], readRegisteredTimers: Boolean, flattenCollectionTypes: Boolean, + internalOnlyReadAllColumnFamilies: Boolean = false, startOperatorStateUniqueIds: Option[Array[Array[String]]] = None, endOperatorStateUniqueIds: Option[Array[Array[String]]] = None) { def stateCheckpointLocation: Path = new
Path(resolvedCpLocation, DIR_NAME_STATE) @@ -379,8 +389,9 @@ case class StateSourceOptions( override def toString: String = { var desc = s"StateSourceOptions(checkpointLocation=$resolvedCpLocation, batchId=$batchId, " + s"operatorId=$operatorId, storeName=$storeName, joinSide=$joinSide, " + - s"stateVarName=${stateVarName.getOrElse("None")}, +" + - s"flattenCollectionTypes=$flattenCollectionTypes" + s"stateVarName=${stateVarName.getOrElse("None")}, " + + s"flattenCollectionTypes=$flattenCollectionTypes, " + + s"internalOnlyReadAllColumnFamilies=$internalOnlyReadAllColumnFamilies" if (fromSnapshotOptions.isDefined) { desc += s", snapshotStartBatchId=${fromSnapshotOptions.get.snapshotStartBatchId}" desc += s", snapshotPartitionId=${fromSnapshotOptions.get.snapshotPartitionId}" @@ -407,6 +418,7 @@ object StateSourceOptions extends DataSourceOptions { val STATE_VAR_NAME = newOption("stateVarName") val READ_REGISTERED_TIMERS = newOption("readRegisteredTimers") val FLATTEN_COLLECTION_TYPES = newOption("flattenCollectionTypes") + val INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES = newOption("internalOnlyReadAllColumnFamilies") object JoinSideValues extends Enumeration { type JoinSideValues = Value @@ -478,6 +490,7 @@ object StateSourceOptions extends DataSourceOptions { s"Valid values are ${JoinSideValues.values.mkString(",")}") } + // Use storeName rather than joinSide to identify the specific join store if (joinSide != JoinSideValues.none && storeName != StateStoreId.DEFAULT_STORE_NAME) { throw StateDataSourceErrors.conflictOptions(Seq(JOIN_SIDE, STORE_NAME)) } @@ -492,6 +505,29 @@ object StateSourceOptions extends DataSourceOptions { val readChangeFeed = Option(options.get(READ_CHANGE_FEED)).exists(_.toBoolean) + val internalOnlyReadAllColumnFamilies = try { + Option(options.get(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES)).exists(_.toBoolean) + } catch { + case _: IllegalArgumentException => + throw StateDataSourceErrors.invalidOptionValue(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, + "Boolean value is expected") + } + + if (internalOnlyReadAllColumnFamilies && stateVarName.isDefined) { + throw StateDataSourceErrors.conflictOptions( + Seq(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, STATE_VAR_NAME)) + } + + if (internalOnlyReadAllColumnFamilies && joinSide != JoinSideValues.none) { + throw StateDataSourceErrors.conflictOptions( + Seq(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, JOIN_SIDE)) + } + + if (internalOnlyReadAllColumnFamilies && readChangeFeed) { + throw StateDataSourceErrors.conflictOptions( + Seq(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, READ_CHANGE_FEED)) + } + val changeStartBatchId = Option(options.get(CHANGE_START_BATCH_ID)).map(_.toLong) var changeEndBatchId = Option(options.get(CHANGE_END_BATCH_ID)).map(_.toLong) @@ -615,7 +651,7 @@ object StateSourceOptions extends DataSourceOptions { StateSourceOptions( resolvedCpLocation, batchId.get, operatorId, storeName, joinSide, readChangeFeed, fromSnapshotOptions, readChangeFeedOptions, - stateVarName, readRegisteredTimers, flattenCollectionTypes, + stateVarName, readRegisteredTimers, flattenCollectionTypes, internalOnlyReadAllColumnFamilies, startOperatorStateUniqueIds, endOperatorStateUniqueIds) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala index 619e374c00de..9fc3c081173f 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala @@ -49,7 +49,10 @@ class StatePartitionReaderFactory( override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { val stateStoreInputPartition = partition.asInstanceOf[StateStoreInputPartition] - if (stateStoreInputPartition.sourceOptions.readChangeFeed) { + if (stateStoreInputPartition.sourceOptions.internalOnlyReadAllColumnFamilies) { + new StatePartitionAllColumnFamiliesReader(storeConf, hadoopConf, + stateStoreInputPartition, schema, keyStateEncoderSpec, stateStoreColFamilySchemaOpt) + } else if (stateStoreInputPartition.sourceOptions.readChangeFeed) { new StateStoreChangeDataPartitionReader(storeConf, hadoopConf, stateStoreInputPartition, schema, keyStateEncoderSpec, stateVariableInfoOpt, stateStoreColFamilySchemaOpt, stateSchemaProviderOpt, joinColFamilyOpt) @@ -81,16 +84,22 @@ abstract class StatePartitionReaderBase( private val schemaForValueRow: StructType = StructType(Array(StructField("__dummy__", NullType))) - protected val keySchema = { + protected val keySchema : StructType = { if (SchemaUtil.checkVariableType(stateVariableInfoOpt, StateVariableType.MapState)) { SchemaUtil.getCompositeKeySchema(schema, partition.sourceOptions) + } else if (partition.sourceOptions.internalOnlyReadAllColumnFamilies) { + require(stateStoreColFamilySchemaOpt.isDefined) + stateStoreColFamilySchemaOpt.map(_.keySchema).get } else { SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType] } } - protected val valueSchema = if (stateVariableInfoOpt.isDefined) { + protected val valueSchema : StructType = if (stateVariableInfoOpt.isDefined) { schemaForValueRow + } else if (partition.sourceOptions.internalOnlyReadAllColumnFamilies) { + require(stateStoreColFamilySchemaOpt.isDefined) + stateStoreColFamilySchemaOpt.map(_.valueSchema).get } else { SchemaUtil.getSchemaAsDataType( schema, "value").asInstanceOf[StructType] @@ -237,6 +246,48 @@ class StatePartitionReader( } } +/** + * An implementation of [[StatePartitionReaderBase]] for reading all column families + * in binary format. This reader returns raw key and value bytes along with column family names. 
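+ * For illustration only (the option and column names are the ones introduced in this patch;
+ * checkpointDir is a placeholder path), such a read is issued through the state data source:
+ * {{{
+ *   val bytesDf = spark.read
+ *     .format("statestore")
+ *     .option(StateSourceOptions.PATH, checkpointDir)
+ *     .option(StateSourceOptions.INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, "true")
+ *     .load()
+ *   // resulting schema: partition_key, key_bytes, value_bytes, column_family_name
+ * }}}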
+ * We are returning key/value bytes because each column family can have a different schema. + * It will also return the partition key. + */ +class StatePartitionAllColumnFamiliesReader( + storeConf: StateStoreConf, + hadoopConf: SerializableConfiguration, + partition: StateStoreInputPartition, + schema: StructType, + keyStateEncoderSpec: KeyStateEncoderSpec, + stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema]) + extends StatePartitionReaderBase( + storeConf, + hadoopConf, partition, schema, + keyStateEncoderSpec, None, stateStoreColFamilySchemaOpt, None, None) { + + private lazy val store: ReadStateStore = { + assert(getStartStoreUniqueId == getEndStoreUniqueId, + "Start and end store unique IDs must be the same when reading all column families") + provider.getReadStore( + partition.sourceOptions.batchId + 1, + getStartStoreUniqueId + ) + } + + override lazy val iter: Iterator[InternalRow] = { + store + .iterator() + .map { pair => + SchemaUtil.unifyStateRowPairAsRawBytes( + (pair.key, pair.value), StateStore.DEFAULT_COL_FAMILY_NAME) + } + } + + override def close(): Unit = { + store.release() + super.close() + } +} + /** * An implementation of [[StatePartitionReaderBase]] for the readChangeFeed mode of State Data * Source. It reads the change of state over batches of a particular partition. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionWriter.scala new file mode 100644 index 000000000000..1cae11147278 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionWriter.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.v2.state + +import java.io.IOException + +import scala.collection.MapView +import scala.collection.immutable.HashMap + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil +import org.apache.spark.sql.execution.streaming.state.{KeyStateEncoderSpec, StateStore, StateStoreConf, StateStoreId, StateStoreProvider, StateStoreProviderId} +import org.apache.spark.sql.types.StructType + +/** + * A writer that can directly write binary data to the streaming state store.
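+ * For illustration only (a sketch mirroring the test suite in this patch; storeConf, hadoopConf,
+ * queryId, targetStateSourceOptions, columnFamilyToSchemaMap and keyStateEncoderSpec are assumed
+ * to be prepared by the caller, and offset/commit log handling is elided), the writer is driven
+ * once per partition of the reader output:
+ * {{{
+ *   val writeFunc: Iterator[Row] => Unit = { partition =>
+ *     val partitionInfo = new StateStoreInputPartition(
+ *       TaskContext.getPartitionId(), queryId, targetStateSourceOptions)
+ *     val writer = new StatePartitionAllColumnFamiliesWriter(
+ *       storeConf, hadoopConf, partitionInfo, columnFamilyToSchemaMap, keyStateEncoderSpec)
+ *     writer.put(partition)
+ *   }
+ *   bytesDf.foreachPartition(writeFunc)
+ * }}}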
+ * + * This writer expects input rows with the same schema produced by + * StatePartitionAllColumnFamiliesReader: + * (partition_key, key_bytes, value_bytes, column_family_name) + * + * The writer creates a fresh (empty) state store instance for the target commit version + * instead of loading previous partition data. After writing all rows for the partition, it will + * commit all changes as a snapshot + */ +class StatePartitionAllColumnFamiliesWriter( + storeConf: StateStoreConf, + hadoopConf: Configuration, + partition: StateStoreInputPartition, + columnFamilyToSchemaMap: HashMap[String, (StructType, StructType)], + keyStateEncoderSpec: KeyStateEncoderSpec) { + + private val (defaultKeySchema, defaultValueSchema) = { + columnFamilyToSchemaMap.getOrElse( + StateStore.DEFAULT_COL_FAMILY_NAME, + throw new IllegalArgumentException( + s"Column family ${StateStore.DEFAULT_COL_FAMILY_NAME} not found in schema map") + ) + } + + private val columnFamilyToKeySchemaLenMap: MapView[String, Int] = + columnFamilyToSchemaMap.view.mapValues(_._1.length) + private val columnFamilyToValueSchemaLenMap: MapView[String, Int] = + columnFamilyToSchemaMap.view.mapValues(_._2.length) + + private val rowConverter = { + val schema = SchemaUtil.getSourceSchema( + partition.sourceOptions, defaultKeySchema, defaultValueSchema, None, None) + CatalystTypeConverters.createToCatalystConverter(schema) + } + + protected lazy val provider: StateStoreProvider = { + val stateStoreId = StateStoreId(partition.sourceOptions.stateCheckpointLocation.toString, + partition.sourceOptions.operatorId, partition.partition, partition.sourceOptions.storeName) + val stateStoreProviderId = StateStoreProviderId(stateStoreId, partition.queryId) + + val provider = StateStoreProvider.createAndInit( + stateStoreProviderId, defaultKeySchema, defaultValueSchema, keyStateEncoderSpec, + useColumnFamilies = false, storeConf, hadoopConf, + useMultipleValuesPerKey = false, stateSchemaProvider = None) + provider + } + + private lazy val stateStore: StateStore = { + val stateStoreCkptId = if (storeConf.enableStateStoreCheckpointIds) { + Some(java.util.UUID.randomUUID.toString) + } else { + None + } + val version = partition.sourceOptions.batchId + 1 + // Create empty store to avoid loading old partition data during repartitioning + // Use loadEmpty=true to create a fresh state store without loading previous versions + // We create the empty store AT version, and the next commit will + // produce version + 1 + provider.getStore( + version, + stateStoreCkptId, + forceSnapshotOnCommit = true, + loadEmpty = true + ) + } + + // The function that writes and commits data to state store. 
It takes in rows with schema + // - partition_key, StructType + // - key_bytes, BinaryType + // - value_bytes, BinaryType + // - column_family_name, StringType + def put(partition: Iterator[Row]): Unit = { + partition.foreach(row => putRaw(row)) + stateStore.commit() + } + + private def putRaw(rawRecord: Row): Unit = { + val record = rowConverter(rawRecord).asInstanceOf[InternalRow] + // Validate record schema + if (record.numFields != 4) { + throw new IOException( + s"Invalid record schema: expected 4 fields (partition_key, key_bytes, value_bytes, " + + s"column_family_name), got ${record.numFields}") + } + + // Extract raw bytes and column family name from the record + val keyBytes = record.getBinary(1) + val valueBytes = record.getBinary(2) + val colFamilyName = record.getString(3) + + // Reconstruct UnsafeRow objects from the raw bytes + // The bytes are in UnsafeRow memory format from StatePartitionReaderAllColumnFamilies + val keyRow = new UnsafeRow(columnFamilyToKeySchemaLenMap(colFamilyName)) + keyRow.pointTo(keyBytes, keyBytes.length) + + val valueRow = new UnsafeRow(columnFamilyToValueSchemaLenMap(colFamilyName)) + valueRow.pointTo(valueBytes, valueBytes.length) + + // Use StateStore API which handles proper RocksDB encoding (version byte, checksums, etc.) + stateStore.put(keyRow, valueRow, colFamilyName) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala index 52df016791d4..44d83fc99b57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala @@ -28,7 +28,8 @@ import org.apache.spark.sql.execution.datasources.v2.state.{StateDataSourceError import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.{StateVariableType, TransformWithStateVariableInfo} import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.StateVariableType._ import org.apache.spark.sql.execution.streaming.state.{ReadStateStore, StateStoreColFamilySchema, UnsafeRowPair} -import org.apache.spark.sql.types.{ArrayType, DataType, IntegerType, LongType, MapType, StringType, StructType} +import org.apache.spark.sql.types.{ArrayType, BinaryType, DataType, IntegerType, LongType, MapType, StringType, StructType} +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.ArrayImplicits._ object SchemaUtil { @@ -60,6 +61,14 @@ object SchemaUtil { .add("key", keySchema) .add("value", valueSchema) .add("partition_id", IntegerType) + } else if (sourceOptions.internalOnlyReadAllColumnFamilies) { + new StructType() + // todo [SPARK-54443]: change keySchema to a more specific type after we + // can extract partition key from keySchema + .add("partition_key", keySchema) + .add("key_bytes", BinaryType) + .add("value_bytes", BinaryType) + .add("column_family_name", StringType) } else { new StructType() .add("key", keySchema) @@ -76,6 +85,26 @@ object SchemaUtil { row } + /** + * Returns an InternalRow representing + * 1. partitionKey + * 2. key in bytes + * 3. value in bytes + * 4. 
column family name + */ + def unifyStateRowPairAsRawBytes( + pair: (UnsafeRow, UnsafeRow), + colFamilyName: String): InternalRow = { + val row = new GenericInternalRow(4) + // todo [SPARK-54443]: change keySchema to more specific type after we + // can extract partition key from keySchema + row.update(0, pair._1) + row.update(1, pair._1.getBytes) + row.update(2, pair._2.getBytes) + row.update(3, UTF8String.fromString(colFamilyName)) + row + } + def unifyStateRowPairWithMultipleValues( pair: (UnsafeRow, GenericArrayData), partition: Int): InternalRow = { @@ -231,7 +260,11 @@ object SchemaUtil { "user_map_key" -> classOf[StructType], "user_map_value" -> classOf[StructType], "expiration_timestamp_ms" -> classOf[LongType], - "partition_id" -> classOf[IntegerType]) + "partition_id" -> classOf[IntegerType], + "partition_key" -> classOf[StructType], + "key_bytes" -> classOf[BinaryType], + "value_bytes" -> classOf[BinaryType], + "column_family_name" -> classOf[StringType]) val expectedFieldNames = if (transformWithStateVariableInfoOpt.isDefined) { val stateVarInfo = transformWithStateVariableInfoOpt.get @@ -272,6 +305,8 @@ object SchemaUtil { } } else if (sourceOptions.readChangeFeed) { Seq("batch_id", "change_type", "key", "value", "partition_id") + } else if (sourceOptions.internalOnlyReadAllColumnFamilies) { + Seq("partition_key", "key_bytes", "value_bytes", "column_family_name") } else { Seq("key", "value", "partition_id") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 9da75a9728dd..6ae9e89229c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -322,12 +322,17 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with override def getStore( version: Long, uniqueId: Option[String] = None, - forceSnapshotOnCommit: Boolean = false): StateStore = { + forceSnapshotOnCommit: Boolean = false, + loadEmpty: Boolean = false): StateStore = { if (uniqueId.isDefined) { throw StateStoreErrors.stateStoreCheckpointIdsNotSupported( "HDFSBackedStateStoreProvider does not support checkpointFormatVersion > 1 " + "but a state store checkpointID is passed in") } + if (loadEmpty) { + throw StateStoreErrors.unsupportedOperationException("getStore", + "loadEmpty parameter is not supported in HDFSBackedStateStoreProvider") + } val newMap = getLoadedMapForStore(version) logInfo(log"Retrieved version ${MDC(LogKeys.STATE_STORE_VERSION, version)} " + log"of ${MDC(LogKeys.STATE_STORE_PROVIDER, HDFSBackedStateStoreProvider.this)} " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionErrors.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionErrors.scala index 95b273826877..0e9b8ad8a63b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionErrors.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionErrors.scala @@ -85,6 +85,12 @@ object OfflineStateRepartitionErrors { version: Int): StateRepartitionInvalidCheckpointError = { new StateRepartitionUnsupportedOffsetSeqVersionError(checkpointLocation, version) } + + def 
unsupportedStateStoreProviderError( + checkpointLocation: String, + providerClass: String): StateRepartitionInvalidCheckpointError = { + new StateRepartitionUnsupportedProviderError(checkpointLocation, providerClass) + } } /** @@ -201,3 +207,11 @@ class StateRepartitionUnsupportedOffsetSeqVersionError( checkpointLocation, subClass = "UNSUPPORTED_OFFSET_SEQ_VERSION", messageParameters = Map("version" -> version.toString)) + +class StateRepartitionUnsupportedProviderError( + checkpointLocation: String, + provider: String) + extends StateRepartitionInvalidCheckpointError( + checkpointLocation, + subClass = "UNSUPPORTED_PROVIDER", + messageParameters = Map("provider" -> provider)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index 8a2ed6d9a529..024fc6b0241a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -731,6 +731,91 @@ class RocksDB( this } + /** + * Create an empty RocksDB state store at the specified version without loading previous data. + * + * This method is useful when state will be completely rewritten and + * does not need to load previous states + * + * @param targetVersion The version to initialize the empty store at (must be >= 0) + * @param stateStoreCkptId Optional checkpoint ID (required if checkpoint IDs are enabled) + * @param readOnly Whether to open the store in read-only mode + * @return A RocksDB instance with an empty state at the target version + */ + def loadEmpty( + targetVersion: Long, + stateStoreCkptId: Option[String] = None, + readOnly: Boolean = false): RocksDB = { + + assert(targetVersion >= 0, s"Target version must be >= 0, got $targetVersion") + recordedMetrics = None + loadMetrics.clear() + + logInfo(log"Creating empty store at version ${MDC(LogKeys.VERSION_NUM, targetVersion)} " + + log"with stateStoreCkptId: ${MDC(LogKeys.UUID, stateStoreCkptId.getOrElse(""))}") + + try { + closeDB(ignoreException = false) + + // Use version 0 logic to create empty directory with no SST files + val metadata = fileManager.loadCheckpointFromDfs(0, workingDir, rocksDBFileMapping, None) + + // Set version tracking to target version + loadedVersion = targetVersion + + // Handle checkpoint IDs if enabled + if (enableStateStoreCheckpointIds) { + require(stateStoreCkptId.isDefined, + "stateStoreCkptId must be defined when checkpoint IDs are enabled") + + loadedStateStoreCkptId = stateStoreCkptId + sessionStateStoreCkptId = Some(java.util.UUID.randomUUID.toString) + lastCommitBasedStateStoreCkptId = None + lastCommittedStateStoreCkptId = None + + // Clear lineage - targetVersion has no checkpoint, so no dependencies + lineageManager.clear() + } + + // Initialize maxVersion to target version + fileManager.setMaxSeenVersion(targetVersion) + + openLocalRocksDB(metadata) + + // Empty store has no keys + numKeysOnLoadedVersion = 0 + numInternalKeysOnLoadedVersion = 0 + + // Initialize changelog writer for next version with empty lineage + if (enableChangelogCheckpointing && !readOnly) { + changelogWriter.foreach(_.abort()) + + // Empty lineage since this is a fresh start with forced snapshot + changelogWriter = Some(fileManager.getChangeLogWriter( + targetVersion + 1, + useColumnFamilies, + sessionStateStoreCkptId, + if (enableStateStoreCheckpointIds) Some(Array.empty[LineageItem]) else None + )) + } + + 
logInfo(log"Created empty store at version ${MDC(LogKeys.VERSION_NUM, targetVersion)}") + } catch { + case t: Throwable => + loadedVersion = -1 // invalidate loaded data + if (enableStateStoreCheckpointIds) { + lastCommitBasedStateStoreCkptId = None + lastCommittedStateStoreCkptId = None + loadedStateStoreCkptId = None + sessionStateStoreCkptId = None + lineageManager.clear() + } + throw t + } + + this + } + /** * Load from the start snapshot version and apply all the changelog records to reach the * end version. Note that this will copy all the necessary files from DFS to local disk as needed, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 96652ffe6fd7..a8b8ff727246 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -726,6 +726,7 @@ private[sql] class RocksDBStateStoreProvider * @param readOnly Whether to open the store in read-only mode * @param existingStore Optional existing store to reuse instead of creating a new one * @param forceSnapshotOnCommit Whether to force a snapshot upload on commit + * @param loadEmpty If true, creates an empty store at this version without loading previous data * @return The loaded state store */ private def loadStateStore( @@ -733,7 +734,8 @@ private[sql] class RocksDBStateStoreProvider uniqueId: Option[String] = None, readOnly: Boolean, existingStore: Option[RocksDBStateStore] = None, - forceSnapshotOnCommit: Boolean = false): StateStore = { + forceSnapshotOnCommit: Boolean = false, + loadEmpty: Boolean = false): StateStore = { var acquiredStamp: Option[Long] = None var storeLoaded = false try { @@ -762,10 +764,18 @@ private[sql] class RocksDBStateStoreProvider Some(s) } - rocksDB.load( - version, - stateStoreCkptId = if (storeConf.enableStateStoreCheckpointIds) uniqueId else None, - readOnly = readOnly) + // Load RocksDB: either empty or from existing checkpoints + if (loadEmpty) { + rocksDB.loadEmpty( + targetVersion = version, + stateStoreCkptId = if (storeConf.enableStateStoreCheckpointIds) uniqueId else None, + readOnly = readOnly) + } else { + rocksDB.load( + version, + stateStoreCkptId = if (storeConf.enableStateStoreCheckpointIds) uniqueId else None, + readOnly = readOnly) + } // Create or reuse store instance val store = existingStore match { @@ -806,12 +816,14 @@ private[sql] class RocksDBStateStoreProvider override def getStore( version: Long, uniqueId: Option[String] = None, - forceSnapshotOnCommit: Boolean = false): StateStore = { + forceSnapshotOnCommit: Boolean = false, + loadEmpty: Boolean = false): StateStore = { loadStateStore( version, uniqueId, readOnly = false, - forceSnapshotOnCommit = forceSnapshotOnCommit + forceSnapshotOnCommit = if (loadEmpty) true else forceSnapshotOnCommit, + loadEmpty = loadEmpty ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 43b95766882f..bd6b4bede84b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -665,11 +665,13 @@ trait StateStoreProvider { /** * Return an 
instance of [[StateStore]] representing state data of the given version. * If `stateStoreCkptId` is provided, the instance also needs to match the ID. + * If `loadEmpty` is true, creates an empty store at this version without loading previous data. * */ def getStore( version: Long, stateStoreCkptId: Option[String] = None, - forceSnapshotOnCommit: Boolean = false): StateStore + forceSnapshotOnCommit: Boolean = false, + loadEmpty: Boolean = false): StateStore /** * Return an instance of [[ReadStateStore]] representing state data of the given version diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTestBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTestBase.scala index 64d005c719b7..b66408bb7d69 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTestBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTestBase.scala @@ -78,7 +78,7 @@ trait StateDataSourceTestBase extends StreamTest with StateStoreMetricsTest { ) } - private def getCompositeKeyStreamingAggregationQuery( + protected def getCompositeKeyStreamingAggregationQuery( inputData: MemoryStream[Int]): Dataset[(Int, String, Long, Long, Int, Int)] = { inputData.toDF() .selectExpr("value", "value % 2 AS groupKey", @@ -140,7 +140,7 @@ trait StateDataSourceTestBase extends StreamTest with StateStoreMetricsTest { ) } - private def getLargeDataStreamingAggregationQuery( + protected def getLargeDataStreamingAggregationQuery( inputData: MemoryStream[Int]): Dataset[(Int, Long, Long, Int, Int)] = { inputData.toDF() .selectExpr("value", "value % 10 AS groupKey") @@ -179,7 +179,7 @@ trait StateDataSourceTestBase extends StreamTest with StateStoreMetricsTest { ) } - private def getDropDuplicatesQuery(inputData: MemoryStream[Int]): Dataset[Long] = { + protected def getDropDuplicatesQuery(inputData: MemoryStream[Int]): Dataset[Long] = { inputData.toDS() .withColumn("eventTime", timestamp_seconds($"value")) .withWatermark("eventTime", "10 seconds") @@ -204,7 +204,7 @@ trait StateDataSourceTestBase extends StreamTest with StateStoreMetricsTest { ) } - private def getDropDuplicatesQueryWithColumnSpecified( + protected def getDropDuplicatesQueryWithColumnSpecified( inputData: MemoryStream[(String, Int)]): Dataset[(String, Int)] = { inputData.toDS() .selectExpr("_1 AS col1", "_2 AS col2") @@ -256,7 +256,7 @@ trait StateDataSourceTestBase extends StreamTest with StateStoreMetricsTest { ) } - private def getDropDuplicatesWithinWatermarkQuery( + protected def getDropDuplicatesWithinWatermarkQuery( inputData: MemoryStream[(String, Int)]): DataFrame = { inputData.toDS() .withColumn("eventTime", timestamp_seconds($"_2")) @@ -293,7 +293,7 @@ trait StateDataSourceTestBase extends StreamTest with StateStoreMetricsTest { ) } - private def getFlatMapGroupsWithStateQuery( + protected def getFlatMapGroupsWithStateQuery( inputData: MemoryStream[(String, Long)]): Dataset[(String, Int, Long, Boolean)] = { // scalastyle:off line.size.limit // This test code is borrowed from Sessionization example, with modification a bit to run with testStream @@ -405,8 +405,7 @@ trait StateDataSourceTestBase extends StreamTest with StateStoreMetricsTest { col("rightId"), col("rightTime").cast("int")) } - protected def runSessionWindowAggregationQuery(checkpointRoot: String): Unit = { - val input = MemoryStream[(String, Long)] + protected def 
getSessionWindowAggregationQuery(input: MemoryStream[(String, Long)]): DataFrame = { val sessionWindow = session_window($"eventTime", "10 seconds") val events = input.toDF() @@ -415,13 +414,17 @@ trait StateDataSourceTestBase extends StreamTest with StateStoreMetricsTest { .withWatermark("eventTime", "30 seconds") .selectExpr("explode(split(value, ' ')) AS sessionId", "eventTime") - val streamingDf = events + events .groupBy(sessionWindow as Symbol("session"), $"sessionId") .agg(count("*").as("numEvents")) .selectExpr("sessionId", "CAST(session.start AS LONG)", "CAST(session.end AS LONG)", "CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs", "numEvents") + } + protected def runSessionWindowAggregationQuery(checkpointRoot: String): Unit = { + val input = MemoryStream[(String, Long)] + val streamingDf = getSessionWindowAggregationQuery(input) testStream(streamingDf, OutputMode.Complete())( StartStream(checkpointLocation = checkpointRoot), AddData(input, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala new file mode 100644 index 000000000000..c4b59b149b96 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala @@ -0,0 +1,536 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.v2.state + +import java.nio.ByteOrder +import java.util.Arrays + +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.execution.streaming.runtime.MemoryStream +import org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, RocksDBStateStoreProvider, StateRepartitionUnsupportedProviderError, StateStore} +import org.apache.spark.sql.functions.{count, sum} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.types.{BooleanType, IntegerType, LongType, NullType, StructField, StructType, TimestampType} + +/** + * Note: This extends StateDataSourceTestBase to access + * helper methods like runDropDuplicatesQuery without inheriting all predefined tests. 
+ */ +class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase { + + import testImplicits._ + + override def beforeAll(): Unit = { + super.beforeAll() + spark.conf.set(SQLConf.STATE_STORE_PROVIDER_CLASS.key, + classOf[RocksDBStateStoreProvider].getName) + } + + private def getNormalReadDf( + checkpointDir: String, + storeName: Option[String] = Option.empty[String]): DataFrame = { + spark.read + .format("statestore") + .option(StateSourceOptions.PATH, checkpointDir) + .option(StateSourceOptions.STORE_NAME, storeName.orNull) + .load() + .selectExpr("partition_id", "key", "value") + } + + private def getBytesReadDf( + checkpointDir: String, + storeName: Option[String] = Option.empty[String]): DataFrame = { + spark.read + .format("statestore") + .option(StateSourceOptions.PATH, checkpointDir) + .option(StateSourceOptions.INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, "true") + .option(StateSourceOptions.STORE_NAME, storeName.orNull) + .load() + } + + /** + * Validates the schema and column families of the bytes read DataFrame. + */ + private def validateBytesReadDfSchema(df: DataFrame): Unit = { + // Verify schema + val schema = df.schema + assert(schema.fieldNames === Array( + "partition_key", "key_bytes", "value_bytes", "column_family_name")) + assert(schema("partition_key").dataType.typeName === "struct") + assert(schema("key_bytes").dataType.typeName === "binary") + assert(schema("value_bytes").dataType.typeName === "binary") + assert(schema("column_family_name").dataType.typeName === "string") + } + + /** + * Compares normal read data with bytes read data for a specific column family. + * Converts normal rows to bytes then compares with bytes read. + */ + private def compareNormalAndBytesData( + normalDf: Array[Row], + bytesDf: Array[Row], + columnFamily: String, + keySchema: StructType, + valueSchema: StructType): Unit = { + + // Filter bytes data for the specified column family and extract raw bytes directly + val filteredBytesData = bytesDf.filter { row => + row.getString(3) == columnFamily + } + + // Verify same number of rows + assert(filteredBytesData.length == normalDf.length, + s"Row count mismatch for column family '$columnFamily': " + + s"normal read has ${normalDf.length} rows, " + + s"bytes read has ${filteredBytesData.length} rows") + + // Create projections to convert Row to UnsafeRow bytes + val keyProjection = UnsafeProjection.create(keySchema) + val valueProjection = UnsafeProjection.create(valueSchema) + + // Create converters to convert external Row types to internal Catalyst types + val keyConverter = CatalystTypeConverters.createToCatalystConverter(keySchema) + val valueConverter = CatalystTypeConverters.createToCatalystConverter(valueSchema) + + // Convert normal data to bytes + val normalAsBytes = normalDf.toSeq.map { row => + val key = row.getStruct(1) + val value = if (row.isNullAt(2)) null else row.getStruct(2) + + // Convert key to InternalRow, then to UnsafeRow, then get bytes + val keyInternalRow = keyConverter(key).asInstanceOf[InternalRow] + val keyUnsafeRow = keyProjection(keyInternalRow) + // IMPORTANT: Must clone the bytes array since getBytes() returns a reference + // that may be overwritten by subsequent UnsafeRow operations + val keyBytes = keyUnsafeRow.getBytes.clone() + + // Convert value to bytes + val valueBytes = if (value == null) { + Array.empty[Byte] + } else { + val valueInternalRow = valueConverter(value).asInstanceOf[InternalRow] + val valueUnsafeRow = valueProjection(valueInternalRow) + // IMPORTANT: Must clone the bytes 
array since getBytes() returns a reference + // that may be overwritten by subsequent UnsafeRow operations + valueUnsafeRow.getBytes.clone() + } + + (keyBytes, valueBytes) + } + + // Extract raw bytes from bytes read data (no deserialization/reserialization) + val bytesAsBytes = filteredBytesData.map { row => + val keyBytes = row.getAs[Array[Byte]](1) + val valueBytes = row.getAs[Array[Byte]](2) + (keyBytes, valueBytes) + } + + // Sort both for comparison (since Set equality doesn't work well with byte arrays) + val normalSorted = normalAsBytes.sortBy(x => (x._1.mkString(","), x._2.mkString(","))) + val bytesSorted = bytesAsBytes.sortBy(x => (x._1.mkString(","), x._2.mkString(","))) + + assert(normalSorted.length == bytesSorted.length, + s"Size mismatch: normal has ${normalSorted.length}, bytes has ${bytesSorted.length}") + + // Compare each pair + normalSorted.zip(bytesSorted).zipWithIndex.foreach { + case (((normalKey, normalValue), (bytesKey, bytesValue)), idx) => + assert(Arrays.equals(normalKey, bytesKey), + s"Key mismatch at index $idx:\n" + + s" Normal: ${normalKey.mkString("[", ",", "]")}\n" + + s" Bytes: ${bytesKey.mkString("[", ",", "]")}") + assert(Arrays.equals(normalValue, bytesValue), + s"Value mismatch at index $idx:\n" + + s" Normal: ${normalValue.mkString("[", ",", "]")}\n" + + s" Bytes: ${bytesValue.mkString("[", ",", "]")}") + } + } + + // Run all tests with both changelog checkpointing enabled and disabled + Seq(true, false).foreach { changelogCheckpointingEnabled => + val testSuffix = if (changelogCheckpointingEnabled) { + "with changelog checkpointing" + } else { + "without changelog checkpointing" + } + + def testWithChangelogConfig(testName: String)(testFun: => Unit): Unit = { + test(s"$testName ($testSuffix)") { + withSQLConf( + "spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled" -> + changelogCheckpointingEnabled.toString) { + testFun + } + } + } + + testWithChangelogConfig("SPARK-54388: simple aggregation state ver 1") { + withSQLConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> "1") { + withTempDir { tempDir => + runLargeDataStreamingAggregationQuery(tempDir.getAbsolutePath) + + val keySchema = StructType(Array(StructField("groupKey", IntegerType, nullable = false))) + // State version 1 includes key columns in the value + val valueSchema = StructType(Array( + StructField("groupKey", IntegerType, nullable = false), + StructField("count", LongType, nullable = false), + StructField("sum", LongType, nullable = false), + StructField("max", IntegerType, nullable = false), + StructField("min", IntegerType, nullable = false) + )) + + val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect() + val bytesDf = getBytesReadDf(tempDir.getAbsolutePath) + + validateBytesReadDfSchema(bytesDf) + compareNormalAndBytesData(normalData, bytesDf.collect(), "default", keySchema, valueSchema) + } + } + } + + testWithChangelogConfig("SPARK-54388: simple aggregation state ver 2") { + withSQLConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> "2") { + withTempDir { tempDir => + runLargeDataStreamingAggregationQuery(tempDir.getAbsolutePath) + + val keySchema = StructType(Array(StructField("groupKey", IntegerType, nullable = false))) + val valueSchema = StructType(Array( + StructField("count", LongType, nullable = false), + StructField("sum", LongType, nullable = false), + StructField("max", IntegerType, nullable = false), + StructField("min", IntegerType, nullable = false) + )) + + val normalData = 
getNormalReadDf(tempDir.getAbsolutePath).collect() + val bytesDf = getBytesReadDf(tempDir.getAbsolutePath) + + validateBytesReadDfSchema(bytesDf) + compareNormalAndBytesData(normalData, bytesDf.collect(), "default", keySchema, valueSchema) + } + } + } + + testWithChangelogConfig("SPARK-54388: composite key aggregation state ver 1") { + withSQLConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> "1") { + withTempDir { tempDir => + runCompositeKeyStreamingAggregationQuery(tempDir.getAbsolutePath) + + val keySchema = StructType(Array( + StructField("groupKey", IntegerType, nullable = false), + StructField("fruit", org.apache.spark.sql.types.StringType, nullable = true) + )) + // State version 1 includes key columns in the value + val valueSchema = StructType(Array( + StructField("groupKey", IntegerType, nullable = false), + StructField("fruit", org.apache.spark.sql.types.StringType, nullable = true), + StructField("count", LongType, nullable = false), + StructField("sum", LongType, nullable = false), + StructField("max", IntegerType, nullable = false), + StructField("min", IntegerType, nullable = false) + )) + + val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect() + val bytesDf = getBytesReadDf(tempDir.getAbsolutePath) + + validateBytesReadDfSchema(bytesDf) + compareNormalAndBytesData(normalData, bytesDf.collect(), "default", keySchema, valueSchema) + } + } + } + + testWithChangelogConfig("SPARK-54388: composite key aggregation state ver 2") { + withSQLConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> "2") { + withTempDir { tempDir => + runCompositeKeyStreamingAggregationQuery(tempDir.getAbsolutePath) + + val keySchema = StructType(Array( + StructField("groupKey", IntegerType, nullable = false), + StructField("fruit", org.apache.spark.sql.types.StringType, nullable = true) + )) + val valueSchema = StructType(Array( + StructField("count", LongType, nullable = false), + StructField("sum", LongType, nullable = false), + StructField("max", IntegerType, nullable = false), + StructField("min", IntegerType, nullable = false) + )) + + val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect() + val bytesDf = getBytesReadDf(tempDir.getAbsolutePath) + + validateBytesReadDfSchema(bytesDf) + compareNormalAndBytesData(normalData, bytesDf.collect(), "default", keySchema, valueSchema) + } + } + } + + testWithChangelogConfig("SPARK-54388: dropDuplicates validation") { + withTempDir { tempDir => + runDropDuplicatesQuery(tempDir.getAbsolutePath) + + val keySchema = StructType(Array( + StructField("value", IntegerType, nullable = false), + StructField("eventTime", org.apache.spark.sql.types.TimestampType) + )) + val valueSchema = StructType(Array( + StructField("__dummy__", NullType, nullable = true) + )) + + val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect() + val bytesDf = getBytesReadDf(tempDir.getAbsolutePath) + + validateBytesReadDfSchema(bytesDf) + compareNormalAndBytesData(normalData, bytesDf.collect(), "default", keySchema, valueSchema) + } + } + + testWithChangelogConfig("SPARK-54388: dropDuplicates with column specified") { + withTempDir { tempDir => + runDropDuplicatesQueryWithColumnSpecified(tempDir.getAbsolutePath) + + val keySchema = StructType(Array( + StructField("col1", org.apache.spark.sql.types.StringType, nullable = true) + )) + val valueSchema = StructType(Array( + StructField("__dummy__", NullType, nullable = true) + )) + + val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect() + val bytesDf = 
getBytesReadDf(tempDir.getAbsolutePath) + + validateBytesReadDfSchema(bytesDf) + compareNormalAndBytesData(normalData, bytesDf.collect(), "default", keySchema, valueSchema) + } + } + + testWithChangelogConfig("SPARK-54388: dropDuplicatesWithinWatermark") { + withTempDir { tempDir => + runDropDuplicatesWithinWatermarkQuery(tempDir.getAbsolutePath) + + val keySchema = StructType(Array( + StructField("_1", org.apache.spark.sql.types.StringType, nullable = true) + )) + val valueSchema = StructType(Array( + StructField("expiresAtMicros", LongType, nullable = false) + )) + + val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect() + val bytesDf = getBytesReadDf(tempDir.getAbsolutePath) + + validateBytesReadDfSchema(bytesDf) + compareNormalAndBytesData(normalData, bytesDf.collect(), "default", keySchema, valueSchema) + } + } + + testWithChangelogConfig("SPARK-54388: session window aggregation") { + withTempDir { tempDir => + runSessionWindowAggregationQuery(tempDir.getAbsolutePath) + + val keySchema = StructType(Array( + StructField("sessionId", org.apache.spark.sql.types.StringType, nullable = false), + StructField("sessionStartTime", + org.apache.spark.sql.types.TimestampType, nullable = false) + )) + val valueSchema = StructType(Array( + StructField("session_window", org.apache.spark.sql.types.StructType(Array( + StructField("start", org.apache.spark.sql.types.TimestampType), + StructField("end", org.apache.spark.sql.types.TimestampType) + )), nullable = false), + StructField("sessionId", org.apache.spark.sql.types.StringType, nullable = false), + StructField("count", LongType, nullable = false) + )) + + val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect() + val bytesDf = getBytesReadDf(tempDir.getAbsolutePath) + + validateBytesReadDfSchema(bytesDf) + compareNormalAndBytesData(normalData, bytesDf.collect(), "default", keySchema, valueSchema) + } + } + + testWithChangelogConfig("SPARK-54388: flatMapGroupsWithState, state ver 1") { + // Skip this test on big endian platforms + assume(java.nio.ByteOrder.nativeOrder().equals(java.nio.ByteOrder.LITTLE_ENDIAN)) + withSQLConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key -> "1") { + withTempDir { tempDir => + assume(ByteOrder.nativeOrder().equals(ByteOrder.LITTLE_ENDIAN)) + runFlatMapGroupsWithStateQuery(tempDir.getAbsolutePath) + + val keySchema = StructType(Array( + StructField("value", org.apache.spark.sql.types.StringType, nullable = true) + )) + val valueSchema = StructType(Array( + StructField("numEvents", IntegerType, nullable = false), + StructField("startTimestampMs", LongType, nullable = false), + StructField("endTimestampMs", LongType, nullable = false), + StructField("timeoutTimestamp", IntegerType, nullable = false) + )) + + val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect() + val bytesDf = getBytesReadDf(tempDir.getAbsolutePath) + + validateBytesReadDfSchema(bytesDf) + compareNormalAndBytesData( + normalData, bytesDf.collect(), "default", keySchema, valueSchema) + } + } + } + + testWithChangelogConfig("SPARK-54388: flatMapGroupsWithState, state ver 2") { + withSQLConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key -> "2") { + withTempDir { tempDir => + runFlatMapGroupsWithStateQuery(tempDir.getAbsolutePath) + + val keySchema = StructType(Array( + StructField("value", org.apache.spark.sql.types.StringType, nullable = true) + )) + val valueSchema = StructType(Array( + StructField("groupState", org.apache.spark.sql.types.StructType(Array( + StructField("numEvents", IntegerType, 
nullable = false), + StructField("startTimestampMs", LongType, nullable = false), + StructField("endTimestampMs", LongType, nullable = false) + )), nullable = false), + StructField("timeoutTimestamp", LongType, nullable = false) + )) + + val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect() + val bytesDf = getBytesReadDf(tempDir.getAbsolutePath) + + validateBytesReadDfSchema(bytesDf) + compareNormalAndBytesData( + normalData, bytesDf.collect(), "default", keySchema, valueSchema) + } + } + } + + def testStreamStreamJoin(stateVersion: Int): Unit = { + withSQLConf(SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION.key -> stateVersion.toString) { + withTempDir { tempDir => + runStreamStreamJoinQuery(tempDir.getAbsolutePath) + + Seq("right-keyToNumValues", "left-keyToNumValues").foreach(storeName => { + val stateReaderForRight = getNormalReadDf( + tempDir.getAbsolutePath, Option(storeName)) + val stateBytesDfForRight = getBytesReadDf( + tempDir.getAbsolutePath, Option(storeName)) + + val keyToNumValuesKeySchema = StructType(Array( + StructField("key", IntegerType) + )) + val keyToNumValueValueSchema = StructType(Array( + StructField("value", LongType) + )) + + validateBytesReadDfSchema(stateBytesDfForRight) + compareNormalAndBytesData( + stateReaderForRight.collect(), + stateBytesDfForRight.collect(), + StateStore.DEFAULT_COL_FAMILY_NAME, + keyToNumValuesKeySchema, + keyToNumValueValueSchema) + }) + + Seq("right-keyWithIndexToValue", "left-keyWithIndexToValue").foreach(storeName => { + val stateReaderForRight = getNormalReadDf( + tempDir.getAbsolutePath, Option(storeName)) + val stateBytesDfForRight = getBytesReadDf( + tempDir.getAbsolutePath, Option(storeName)) + + val keyToNumValuesKeySchema = StructType(Array( + StructField("key", IntegerType, nullable = false), + StructField("index", LongType) + )) + val keyToNumValueValueSchema = if (stateVersion == 2) { + StructType(Array( + StructField("value", IntegerType, nullable = false), + StructField("time", TimestampType, nullable = false), + StructField("matched", BooleanType) + )) + } else { + StructType(Array( + StructField("value", IntegerType, nullable = false), + StructField("time", TimestampType, nullable = false) + )) + } + + validateBytesReadDfSchema(stateBytesDfForRight) + compareNormalAndBytesData( + stateReaderForRight.collect(), + stateBytesDfForRight.collect(), + StateStore.DEFAULT_COL_FAMILY_NAME, + keyToNumValuesKeySchema, + keyToNumValueValueSchema) + }) + } + } + } + + testWithChangelogConfig("stream-stream join, state ver 1") { + testStreamStreamJoin(1) + } + + testWithChangelogConfig("stream-stream join, state ver 2") { + testStreamStreamJoin(2) + } + } // End of foreach loop for changelog checkpointing dimension + + test("internalOnlyReadAllColumnFamilies should fail with HDFS-backed state store") { + withTempDir { tempDir => + withSQLConf( + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[HDFSBackedStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> "2") { + + val inputData = MemoryStream[Int] + val aggregated = inputData.toDF() + .selectExpr("value", "value % 10 AS groupKey") + .groupBy($"groupKey") + .agg( + count("*").as("cnt"), + sum("value").as("sum") + ) + .as[(Int, Long, Long)] + + testStream(aggregated, OutputMode.Update)( + StartStream(checkpointLocation = tempDir.getAbsolutePath), + AddData(inputData, 0 until 1: _*), + CheckLastBatch((0, 1, 0)), + StopStream + ) + + checkError( + exception = intercept[StateRepartitionUnsupportedProviderError] { + spark.read + .format("statestore") + 
.option(StateSourceOptions.PATH, tempDir.getAbsolutePath) + .option(StateSourceOptions.INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, "true") + .load() + .collect() + }, + condition = "STATE_REPARTITION_INVALID_CHECKPOINT.UNSUPPORTED_PROVIDER", + parameters = Map( + "checkpointLocation" -> s".*${tempDir.getAbsolutePath}", + "provider" -> classOf[HDFSBackedStateStoreProvider].getName + ), + matchPVals = true + ) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesWriterSuite.scala new file mode 100644 index 000000000000..a720120cc05c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesWriterSuite.scala @@ -0,0 +1,661 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.v2.state + +import java.util.UUID + +import scala.collection.immutable.HashMap + +import org.apache.spark.TaskContext +import org.apache.spark.sql.Row +import org.apache.spark.sql.execution.streaming.runtime.{MemoryStream, StreamingQueryCheckpointMetadata} +import org.apache.spark.sql.execution.streaming.state.{KeyStateEncoderSpec, NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, RocksDBStateStoreProvider, StateStore, StateStoreConf, StateStoreId} +import org.apache.spark.sql.execution.streaming.utils.StreamingUtils +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.streaming.{OutputMode, Trigger} +import org.apache.spark.sql.streaming.util.StreamManualClock +import org.apache.spark.sql.types.{BooleanType, IntegerType, LongType, NullType, StructField, StructType, TimestampType} +import org.apache.spark.util.SerializableConfiguration + +/** + * Test suite for StatePartitionAllColumnFamiliesWriter. + * Tests the writer's ability to correctly write raw bytes read from + * StatePartitionAllColumnFamiliesReader to a state store without loading previous versions. + */ +class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase { + import testImplicits._ + + override def beforeAll(): Unit = { + super.beforeAll() + spark.conf.set(SQLConf.STATE_STORE_PROVIDER_CLASS.key, + classOf[RocksDBStateStoreProvider].getName) + } + + /** + * Common helper method to perform round-trip test: read state bytes from source, + * write to target, and verify target matches source. 
+ * + * @param sourceDir Source checkpoint directory + * @param targetDir Target checkpoint directory + * @param keySchema Key schema for the state store + * @param valueSchema Value schema for the state store + * @param keyStateEncoderSpec Key state encoder spec + * @param storeName Optional store name (for stream-stream join which has multiple stores) + */ + private def performRoundTripTest( + sourceDir: String, + targetDir: String, + keySchema: StructType, + valueSchema: StructType, + keyStateEncoderSpec: KeyStateEncoderSpec, + storeName: Option[String] = None): Unit = { + + // Step 1: Read original state using normal reader (for comparison later) + val sourceReader = spark.read + .format("statestore") + .option(StateSourceOptions.PATH, sourceDir) + val sourceNormalData = (storeName match { + case Some(name) => sourceReader.option(StateSourceOptions.STORE_NAME, name) + case None => sourceReader + }).load() + .selectExpr("key", "value", "partition_id") + .collect() + + // Step 2: Read from source using AllColumnFamiliesReader (raw bytes) + val sourceBytesReader = spark.read + .format("statestore") + .option(StateSourceOptions.PATH, sourceDir) + .option(StateSourceOptions.INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, "true") + val sourceBytesData = (storeName match { + case Some(name) => sourceBytesReader.option(StateSourceOptions.STORE_NAME, name) + case None => sourceBytesReader + }).load() + + // Verify schema of raw bytes + val schema = sourceBytesData.schema + assert(schema.fieldNames === Array( + "partition_key", "key_bytes", "value_bytes", "column_family_name")) + + // Step 3: Write raw bytes to target checkpoint location + val hadoopConf = spark.sessionState.newHadoopConf() + val targetCpLocation = StreamingUtils.resolvedCheckpointLocation( + hadoopConf, targetDir) + val targetCheckpointMetadata = new StreamingQueryCheckpointMetadata( + spark, targetCpLocation) + val lastBatch = targetCheckpointMetadata.commitLog.getLatestBatchId().get + val targetOffsetSeq = targetCheckpointMetadata.offsetLog.get(lastBatch).get + targetCheckpointMetadata.offsetLog.add(lastBatch + 1, targetOffsetSeq) + + // Create column family to schema map + val columnFamilyToSchemaMap = HashMap( + StateStore.DEFAULT_COL_FAMILY_NAME -> (keySchema, valueSchema) + ) + + // Create StateSourceOptions for the target checkpoint + val targetStateSourceOptions = StateSourceOptions( + resolvedCpLocation = targetCpLocation, + batchId = lastBatch, + operatorId = 0, + storeName = storeName.getOrElse(StateStoreId.DEFAULT_STORE_NAME), + joinSide = StateSourceOptions.JoinSideValues.none, + readChangeFeed = false, + fromSnapshotOptions = None, + readChangeFeedOptions = None, + stateVarName = None, + readRegisteredTimers = false, + flattenCollectionTypes = true, + internalOnlyReadAllColumnFamilies = true + ) + + val storeConf: StateStoreConf = StateStoreConf(SQLConf.get) + val serializableHadoopConf = new SerializableConfiguration(hadoopConf) + val queryId = UUID.randomUUID() + + // Define the partition processing function + val putPartitionFunc: Iterator[Row] => Unit = partition => { + val partitionInfo = new StateStoreInputPartition( + TaskContext.getPartitionId(), queryId, targetStateSourceOptions + ) + val allCFWriter = new StatePartitionAllColumnFamiliesWriter( + storeConf, serializableHadoopConf.value, partitionInfo, + columnFamilyToSchemaMap, keyStateEncoderSpec + ) + allCFWriter.put(partition) + } + + // Write raw bytes to target using foreachPartition + sourceBytesData.foreachPartition(putPartitionFunc) + + // Commit to 
commitLog + val latestCommit = targetCheckpointMetadata.commitLog.get(lastBatch).get + targetCheckpointMetadata.commitLog.add(lastBatch + 1, latestCommit) + + // Step 4: Read from target using normal reader + val targetReader = spark.read + .format("statestore") + .option(StateSourceOptions.PATH, targetDir) + val targetNormalData = (storeName match { + case Some(name) => targetReader.option(StateSourceOptions.STORE_NAME, name) + case None => targetReader + }).load() + .selectExpr("key", "value", "partition_id") + .collect() + + // Step 5: Verify data matches + assert(sourceNormalData.length == targetNormalData.length, + s"Row count mismatch: source=${sourceNormalData.length}, " + + s"target=${targetNormalData.length}") + + // Sort and compare row by row + val sourceSorted = sourceNormalData.sortBy(_.toString) + val targetSorted = targetNormalData.sortBy(_.toString) + + sourceSorted.zip(targetSorted).zipWithIndex.foreach { + case ((sourceRow, targetRow), idx) => + assert(sourceRow == targetRow, + s"Row mismatch at index $idx:\n" + + s" Source: $sourceRow\n" + + s" Target: $targetRow") + } + } + + /** + * Helper method to test SPARK-54420 read and write with different state format versions + * for simple aggregation (single grouping key). + * @param stateVersion The state format version (1 or 2) + */ + private def testRoundTripForAggrStateVersion(stateVersion: Int): Unit = { + withSQLConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> stateVersion.toString, + SQLConf.SHUFFLE_PARTITIONS.key -> "2") { + withTempDir { sourceDir => + withTempDir { targetDir => + // Step 1: Create state by running a streaming aggregation + runLargeDataStreamingAggregationQuery(sourceDir.getAbsolutePath) + val inputData: MemoryStream[Int] = MemoryStream[Int] + val aggregated = getLargeDataStreamingAggregationQuery(inputData) + + // add dummy data to target source to test writer won't load previous store + testStream(aggregated, OutputMode.Update)( + StartStream(checkpointLocation = targetDir.getAbsolutePath), + // batch 0 + AddData(inputData, 0 until 2: _*), + CheckLastBatch( + (0, 1, 0, 0, 0), // 0 + (1, 1, 1, 1, 1) // 1 + ), + // batch 1 + AddData(inputData, 0 until 2: _*), + CheckLastBatch( + (0, 2, 0, 0, 0), // 0 + (1, 2, 2, 1, 1) // 1 + ) + ) + + // Step 2: Define schemas based on state version + val keySchema = StructType(Array( + StructField("groupKey", IntegerType, nullable = false))) + val valueSchema = if (stateVersion == 1) { + // State version 1 includes key columns in the value + StructType(Array( + StructField("groupKey", IntegerType, nullable = false), + StructField("count", LongType, nullable = false), + StructField("sum", LongType, nullable = false), + StructField("max", IntegerType, nullable = false), + StructField("min", IntegerType, nullable = false) + )) + } else { + // State version 2 excludes key columns from the value + StructType(Array( + StructField("count", LongType, nullable = false), + StructField("sum", LongType, nullable = false), + StructField("max", IntegerType, nullable = false), + StructField("min", IntegerType, nullable = false) + )) + } + + // Create key state encoder spec (no prefix key for simple aggregation) + val keyStateEncoderSpec = NoPrefixKeyStateEncoderSpec(keySchema) + + // Perform round-trip test using common helper + performRoundTripTest( + sourceDir.getAbsolutePath, + targetDir.getAbsolutePath, + keySchema, + valueSchema, + keyStateEncoderSpec + ) + } + } + } + } + + /** + * Helper method to test SPARK-54420 read and write with different state format 
versions + * for composite key aggregation (multiple grouping keys). + * @param stateVersion The state format version (1 or 2) + */ + private def testCompositeKeyRoundTripForStateVersion(stateVersion: Int): Unit = { + withSQLConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> stateVersion.toString, + SQLConf.SHUFFLE_PARTITIONS.key -> "2") { + withTempDir { sourceDir => + withTempDir { targetDir => + // Step 1: Create state by running a composite key streaming aggregation + runCompositeKeyStreamingAggregationQuery(sourceDir.getAbsolutePath) + val inputData: MemoryStream[Int] = MemoryStream[Int] + val aggregated = getCompositeKeyStreamingAggregationQuery(inputData) + + // add dummy data to target source to test writer won't load previous store + testStream(aggregated, OutputMode.Update)( + StartStream(checkpointLocation = targetDir.getAbsolutePath), + // batch 0 + AddData(inputData, 0, 1), + CheckLastBatch( + (0, "Apple", 1, 0, 0, 0), + (1, "Banana", 1, 1, 1, 1) + ) + ) + + // Step 2: Define schemas based on state version for composite key + val keySchema = StructType(Array( + StructField("groupKey", IntegerType, nullable = false), + StructField("fruit", org.apache.spark.sql.types.StringType, nullable = true) + )) + val valueSchema = if (stateVersion == 1) { + // State version 1 includes key columns in the value + StructType(Array( + StructField("groupKey", IntegerType, nullable = false), + StructField("fruit", org.apache.spark.sql.types.StringType, nullable = true), + StructField("count", LongType, nullable = false), + StructField("sum", LongType, nullable = false), + StructField("max", IntegerType, nullable = false), + StructField("min", IntegerType, nullable = false) + )) + } else { + // State version 2 excludes key columns from the value + StructType(Array( + StructField("count", LongType, nullable = false), + StructField("sum", LongType, nullable = false), + StructField("max", IntegerType, nullable = false), + StructField("min", IntegerType, nullable = false) + )) + } + + // Create key state encoder spec (no prefix key for composite key aggregation) + val keyStateEncoderSpec = NoPrefixKeyStateEncoderSpec(keySchema) + + // Perform round-trip test using common helper + performRoundTripTest( + sourceDir.getAbsolutePath, + targetDir.getAbsolutePath, + keySchema, + valueSchema, + keyStateEncoderSpec + ) + } + } + } + } + + /** + * Helper method to test round-trip for stream-stream join with different versions. 
+ */ + private def testStreamStreamJoinRoundTrip(stateVersion: Int): Unit = { + withSQLConf(SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION.key -> stateVersion.toString) { + withTempDir { sourceDir => + withTempDir { targetDir => + // Step 1: Create state by running stream-stream join + runStreamStreamJoinQuery(sourceDir.getAbsolutePath) + + // Create dummy data in target + val inputData: MemoryStream[(Int, Long)] = MemoryStream[(Int, Long)] + val query = getStreamStreamJoinQuery(inputData) + testStream(query)( + StartStream(checkpointLocation = targetDir.getAbsolutePath), + AddData(inputData, (1, 1L)), + CheckNewAnswer() + ) + + // Step 2: Test all 4 state stores created by stream-stream join + // Test keyToNumValues stores (both left and right) + Seq("left-keyToNumValues", "right-keyToNumValues").foreach { storeName => + val keySchema = StructType(Array( + StructField("key", IntegerType) + )) + val valueSchema = StructType(Array( + StructField("value", LongType) + )) + val keyStateEncoderSpec = NoPrefixKeyStateEncoderSpec(keySchema) + + // Perform round-trip test using common helper + performRoundTripTest( + sourceDir.getAbsolutePath, + targetDir.getAbsolutePath, + keySchema, + valueSchema, + keyStateEncoderSpec, + storeName = Some(storeName) + ) + } + + // Test keyWithIndexToValue stores (both left and right) + Seq("left-keyWithIndexToValue", "right-keyWithIndexToValue").foreach { storeName => + val keySchema = StructType(Array( + StructField("key", IntegerType, nullable = false), + StructField("index", LongType) + )) + val valueSchema = if (stateVersion == 2) { + StructType(Array( + StructField("value", IntegerType, nullable = false), + StructField("time", TimestampType, nullable = false), + StructField("matched", BooleanType) + )) + } else { + StructType(Array( + StructField("value", IntegerType, nullable = false), + StructField("time", TimestampType, nullable = false) + )) + } + val keyStateEncoderSpec = NoPrefixKeyStateEncoderSpec(keySchema) + + // Perform round-trip test using common helper + performRoundTripTest( + sourceDir.getAbsolutePath, + targetDir.getAbsolutePath, + keySchema, + valueSchema, + keyStateEncoderSpec, + storeName = Some(storeName) + ) + } + } + } + } + } + + /** + * Helper method to test round-trip for flatMapGroupsWithState with different versions. 
+ */ + private def testFlatMapGroupsWithStateRoundTrip(stateVersion: Int): Unit = { + // Skip this test on big endian platforms (version 1 only) + if (stateVersion == 1) { + assume(java.nio.ByteOrder.nativeOrder().equals(java.nio.ByteOrder.LITTLE_ENDIAN)) + } + + withSQLConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key -> stateVersion.toString) { + withTempDir { sourceDir => + withTempDir { targetDir => + // Step 1: Create state by running flatMapGroupsWithState + runFlatMapGroupsWithStateQuery(sourceDir.getAbsolutePath) + + // Create dummy data in target + val clock = new StreamManualClock + val inputData: MemoryStream[(String, Long)] = MemoryStream[(String, Long)] + val query = getFlatMapGroupsWithStateQuery(inputData) + testStream(query, OutputMode.Update)( + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock, + checkpointLocation = targetDir.getAbsolutePath), + AddData(inputData, ("a", 1L)), + AdvanceManualClock(1 * 1000), + CheckLastBatch(("a", 1, 0, false)) + ) + + // Step 2: Define schemas for flatMapGroupsWithState + val keySchema = StructType(Array( + StructField("value", org.apache.spark.sql.types.StringType, nullable = true) + )) + val valueSchema = if (stateVersion == 1) { + StructType(Array( + StructField("numEvents", IntegerType, nullable = false), + StructField("startTimestampMs", LongType, nullable = false), + StructField("endTimestampMs", LongType, nullable = false), + StructField("timeoutTimestamp", IntegerType, nullable = false) + )) + } else { + StructType(Array( + StructField("groupState", org.apache.spark.sql.types.StructType(Array( + StructField("numEvents", IntegerType, nullable = false), + StructField("startTimestampMs", LongType, nullable = false), + StructField("endTimestampMs", LongType, nullable = false) + )), nullable = false), + StructField("timeoutTimestamp", LongType, nullable = false) + )) + } + val keyStateEncoderSpec = NoPrefixKeyStateEncoderSpec(keySchema) + + // Perform round-trip test using common helper + performRoundTripTest( + sourceDir.getAbsolutePath, + targetDir.getAbsolutePath, + keySchema, + valueSchema, + keyStateEncoderSpec + ) + } + } + } + } + + // Run all tests with both changelog checkpointing enabled and disabled + Seq(true, false).foreach { changelogCheckpointingEnabled => + val testSuffix = if (changelogCheckpointingEnabled) { + "with changelog checkpointing" + } else { + "without changelog checkpointing" + } + + def testWithChangelogConfig(testName: String)(testFun: => Unit): Unit = { + test(s"$testName ($testSuffix)") { + withSQLConf( + "spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled" -> + changelogCheckpointingEnabled.toString) { + testFun + } + } + } + + testWithChangelogConfig("SPARK-54420: aggregation state ver 1") { + testRoundTripForAggrStateVersion(1) + } + + testWithChangelogConfig("SPARK-54420: aggregation state ver 2") { + testRoundTripForAggrStateVersion(2) + } + + testWithChangelogConfig("SPARK-54420: composite key aggregation state ver 1") { + testCompositeKeyRoundTripForStateVersion(1) + } + + testWithChangelogConfig("SPARK-54420: composite key aggregation state ver 2") { + testCompositeKeyRoundTripForStateVersion(2) + } + + testWithChangelogConfig("SPARK-54420: dropDuplicatesWithinWatermark") { + withTempDir { sourceDir => + withTempDir { targetDir => + // Step 1: Create state by running dropDuplicatesWithinWatermark + runDropDuplicatesWithinWatermarkQuery(sourceDir.getAbsolutePath) + + // Create dummy data in target + val inputData: MemoryStream[(String, Int)] = 
MemoryStream[(String, Int)] + val deduped = getDropDuplicatesWithinWatermarkQuery(inputData) + testStream(deduped, OutputMode.Append)( + StartStream(checkpointLocation = targetDir.getAbsolutePath), + AddData(inputData, ("a", 1)), + CheckAnswer(("a", 1)) + ) + + // Step 2: Define schemas for dropDuplicatesWithinWatermark + val keySchema = StructType(Array( + StructField("_1", org.apache.spark.sql.types.StringType, nullable = true) + )) + val valueSchema = StructType(Array( + StructField("expiresAtMicros", LongType, nullable = false) + )) + val keyStateEncoderSpec = NoPrefixKeyStateEncoderSpec(keySchema) + + // Perform round-trip test using common helper + performRoundTripTest( + sourceDir.getAbsolutePath, + targetDir.getAbsolutePath, + keySchema, + valueSchema, + keyStateEncoderSpec + ) + } + } + } + + testWithChangelogConfig("SPARK-54420: dropDuplicates with column specified") { + withTempDir { sourceDir => + withTempDir { targetDir => + // Step 1: Create state by running dropDuplicates with column + runDropDuplicatesQueryWithColumnSpecified(sourceDir.getAbsolutePath) + + // Create dummy data in target + val inputData: MemoryStream[(String, Int)] = MemoryStream[(String, Int)] + val deduped = getDropDuplicatesQueryWithColumnSpecified(inputData) + testStream(deduped, OutputMode.Append)( + StartStream(checkpointLocation = targetDir.getAbsolutePath), + AddData(inputData, ("a", 1)), + CheckAnswer(("a", 1)) + ) + + // Step 2: Define schemas for dropDuplicates with column specified + val keySchema = StructType(Array( + StructField("col1", org.apache.spark.sql.types.StringType, nullable = true) + )) + val valueSchema = StructType(Array( + StructField("__dummy__", NullType, nullable = true) + )) + val keyStateEncoderSpec = NoPrefixKeyStateEncoderSpec(keySchema) + + // Perform round-trip test using common helper + performRoundTripTest( + sourceDir.getAbsolutePath, + targetDir.getAbsolutePath, + keySchema, + valueSchema, + keyStateEncoderSpec + ) + } + } + } + + testWithChangelogConfig("SPARK-54420: session window aggregation") { + withTempDir { sourceDir => + withTempDir { targetDir => + // Step 1: Create state by running session window aggregation + runSessionWindowAggregationQuery(sourceDir.getAbsolutePath) + + // Create dummy data in target + val inputData: MemoryStream[(String, Long)] = MemoryStream[(String, Long)] + val aggregated = getSessionWindowAggregationQuery(inputData) + testStream(aggregated, OutputMode.Complete())( + StartStream(checkpointLocation = targetDir.getAbsolutePath), + AddData(inputData, ("a", 40L)), + CheckNewAnswer( + ("a", 40, 50, 10, 1) + ), + StopStream + ) + + // Step 2: Define schemas for session window aggregation + val keySchema = StructType(Array( + StructField("sessionId", org.apache.spark.sql.types.StringType, nullable = false), + StructField("sessionStartTime", + org.apache.spark.sql.types.TimestampType, nullable = false) + )) + val valueSchema = StructType(Array( + StructField("session_window", org.apache.spark.sql.types.StructType(Array( + StructField("start", org.apache.spark.sql.types.TimestampType), + StructField("end", org.apache.spark.sql.types.TimestampType) + )), nullable = false), + StructField("sessionId", org.apache.spark.sql.types.StringType, nullable = false), + StructField("count", LongType, nullable = false) + )) + // Session window aggregation uses prefix key scanning where sessionId is the prefix + val keyStateEncoderSpec = PrefixKeyScanStateEncoderSpec(keySchema, 1) + + // Perform round-trip test using common helper + performRoundTripTest( + 
sourceDir.getAbsolutePath, + targetDir.getAbsolutePath, + keySchema, + valueSchema, + keyStateEncoderSpec + ) + } + } + } + + testWithChangelogConfig("SPARK-54420: dropDuplicates") { + withSQLConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> "2", + SQLConf.SHUFFLE_PARTITIONS.key -> "2") { + withTempDir { sourceDir => + withTempDir { targetDir => + + // Step 1: Create state by running a streaming aggregation + runDropDuplicatesQuery(sourceDir.getAbsolutePath) + val inputData: MemoryStream[Int] = MemoryStream[Int] + val stream = getDropDuplicatesQuery(inputData) + testStream(stream, OutputMode.Append)( + StartStream(checkpointLocation = targetDir.getAbsolutePath), + AddData(inputData, (1 to 5).flatMap(_ => (10 to 15)): _*), + CheckAnswer(10 to 15: _*), + assertNumStateRows(total = 6, updated = 6) + ) + + // Step 2: Define schemas for dropDuplicates (state version 2) + val keySchema = StructType(Array( + StructField("value", IntegerType, nullable = false), + StructField("eventTime", org.apache.spark.sql.types.TimestampType) + )) + val valueSchema = StructType(Array( + StructField("__dummy__", NullType) + )) + val keyStateEncoderSpec = NoPrefixKeyStateEncoderSpec(keySchema) + + // Perform round-trip test using common helper + performRoundTripTest( + sourceDir.getAbsolutePath, + targetDir.getAbsolutePath, + keySchema, + valueSchema, + keyStateEncoderSpec + ) + } + } + } + } + + testWithChangelogConfig("SPARK-54420: flatMapGroupsWithState state ver 1") { + testFlatMapGroupsWithStateRoundTrip(1) + } + + testWithChangelogConfig("SPARK-54420: flatMapGroupsWithState state ver 2") { + testFlatMapGroupsWithStateRoundTrip(2) + } + + testWithChangelogConfig("SPARK-54420: stream-stream join state ver 1") { + testStreamStreamJoinRoundTrip(1) + } + + testWithChangelogConfig("SPARK-54420: stream-stream join state ver 2") { + testStreamStreamJoinRoundTrip(2) + } + } // End of foreach loop for changelog checkpointing dimension +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala index 2afb3e69c4e4..5f1adadc30a3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala @@ -200,8 +200,10 @@ class CkptIdCollectingStateStoreProviderWrapper extends StateStoreProvider { override def getStore( version: Long, stateStoreCkptId: Option[String] = None, - forceSnapshotOnCommit: Boolean = false): StateStore = { - val innerStateStore = innerProvider.getStore(version, stateStoreCkptId, forceSnapshotOnCommit) + forceSnapshotOnCommit: Boolean = false, + loadEmpty: Boolean = false): StateStore = { + val innerStateStore = innerProvider.getStore(version, stateStoreCkptId, + forceSnapshotOnCommit, loadEmpty) CkptIdCollectingStateStoreWrapper(innerStateStore) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala index ca72f7033118..45a19ef9734a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala @@ -3942,6 +3942,84 @@ class RocksDBSuite 
extends AlsoTestWithRocksDBFeatures with SharedSparkSession }} } + testWithStateStoreCheckpointIdsAndChangelogEnabled( + "SPARK-54420: loadEmpty creates empty store at specified version") { + enableStateStoreCheckpointIds => + val remoteDir = Utils.createTempDir().toString + new File(remoteDir).delete() + val versionToUniqueId = new mutable.HashMap[Long, String]() + + withDB(remoteDir, + enableStateStoreCheckpointIds = enableStateStoreCheckpointIds, + versionToUniqueId = versionToUniqueId) { db => + // Put initial data first + val version = 0 + db.load(version, versionToUniqueId.get(0)) + db.put("a", "1") + val (version1, _) = db.commit() + assert(toStr(db.get("a")) === "1") + + db.loadEmpty(version1, versionToUniqueId.get(1)) + + // Store is empty after loadEmpty; add data and commit - should produce version 2 + assert(db.get("a") === null) + assert(iterator(db).isEmpty) + db.put("b", "2") + val (version2, _) = db.commit(forceSnapshot = true) + assert(version2 === version1 + 1) + assert(toStr(db.get("b")) === "2") + + db.put("c", "3") + assert(toStr(db.get("b")) === "2") + assert(toStr(db.get("c")) === "3") + val (version3, _) = db.commit(forceSnapshot = true) + assert(version3 === version2 + 1) + } + + // Verify we can reload the committed version + withDB(remoteDir, version = 3, + enableStateStoreCheckpointIds = enableStateStoreCheckpointIds, + versionToUniqueId = versionToUniqueId) { db => + assert(toStr(db.get("c")) === "3") + assert(db.iterator().map(toStr).toSet === Set(("b", "2"), ("c", "3"))) + } + } + + testWithStateStoreCheckpointIdsAndChangelogEnabled("SPARK-54420: loadEmpty at version 0") { + enableStateStoreCheckpointIds => + val remoteDir = Utils.createTempDir().toString + new File(remoteDir).delete() + val versionToUniqueId = new mutable.HashMap[Long, String]() + + withDB(remoteDir, + enableStateStoreCheckpointIds = enableStateStoreCheckpointIds, + versionToUniqueId = versionToUniqueId) { db => + // Create empty store at version 0 + val ckptId = if (enableStateStoreCheckpointIds) { + Some(java.util.UUID.randomUUID.toString) + } else { + None + } + + db.loadEmpty(0, ckptId) + + // Verify store is empty + assert(iterator(db).isEmpty) + + // Add data and commit - should produce version 1 + db.put("a", "1") + val (newVersion, _) = db.commit(true) + assert(newVersion === 1) + } + + // Verify we can reload version 1 + withDB(remoteDir, version = 1, + enableStateStoreCheckpointIds = enableStateStoreCheckpointIds, + versionToUniqueId = versionToUniqueId) { db => + assert(toStr(db.get("a")) === "1") + } + } + test("SPARK-44639: Use Java tmp dir instead of configured local dirs on Yarn") { val conf = new Configuration() conf.set(StreamExecution.RUN_ID_KEY, UUID.randomUUID().toString) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index a997ead74097..5dc067a06a52 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -95,7 +95,8 @@ class SignalingStateStoreProvider extends StateStoreProvider with Logging { override def getStore( version: Long, uniqueId: Option[String], - forceSnapshotOnCommit: Boolean = false): StateStore = null + forceSnapshotOnCommit: Boolean = false, + loadEmpty: Boolean = false): StateStore = null /** * Simulates a maintenance operation that blocks until a signal is received. 
@@ -175,7 +176,8 @@ class FakeStateStoreProviderTracksCloseThread extends StateStoreProvider { override def getStore( version: Long, uniqueId: Option[String], - forceSnapshotOnCommit: Boolean = false): StateStore = null + forceSnapshotOnCommit: Boolean = false, + loadEmpty: Boolean = false): StateStore = null override def doMaintenance(): Unit = {} } @@ -247,7 +249,8 @@ class FakeStateStoreProviderWithMaintenanceError extends StateStoreProvider { override def getStore( version: Long, uniqueId: Option[String], - forceSnapshotOnCommit: Boolean = false): StateStore = null + forceSnapshotOnCommit: Boolean = false, + loadEmpty: Boolean = false): StateStore = null override def doMaintenance(): Unit = { Thread.currentThread.setUncaughtExceptionHandler(exceptionHandler) @@ -1438,6 +1441,25 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] "HDFSBackedStateStoreProvider does not support checkpointFormatVersion > 1")) } + test("SPARK-54420: HDFSBackedStateStoreProvider does not support load empty store") { + val provider = new HDFSBackedStateStoreProvider() + provider.init( + StateStoreId(newDir(), Random.nextInt(), 0), + keySchema, + valueSchema, + NoPrefixKeyStateEncoderSpec(keySchema), + useColumnFamilies = false, + new StateStoreConf(), + new Configuration()) + + val e = intercept[StateStoreUnsupportedOperationException] { + provider.getStore(0, loadEmpty = true) + } + assert(e.getMessage.contains( + "getStore operation not supported with loadEmpty parameter is not supported " + + "in HDFSBackedStateStoreProvider")) + } + test("Auto snapshot repair") { withSQLConf( SQLConf.STREAMING_CHECKPOINT_FILE_CHECKSUM_ENABLED.key -> false.toString, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index cc1d0bc1ed17..b4fd41a6b550 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -1463,7 +1463,8 @@ class TestStateStoreProvider extends StateStoreProvider { override def getStore( version: Long, stateStoreCkptId: Option[String] = None, - forceSnapshotOnCommit: Boolean = false): StateStore = null + forceSnapshotOnCommit: Boolean = false, + loadEmpty: Boolean = false): StateStore = null } /** A fake source that throws `ThrowingExceptionInCreateSource.exception` in `createSource` */
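
For orientation only (not part of the patch): the sketch below shows how the pieces tested above are meant to compose, mirroring the flow of performRoundTripTest in StatePartitionAllColumnFamiliesWriterSuite. Raw key/value bytes for every column family are read from a source checkpoint with the internal-only option, and each partition is rewritten into a target checkpoint through StatePartitionAllColumnFamiliesWriter, which relies on getStore(..., loadEmpty = true) so the previous store version is never loaded. The object name RepartitionSketch, the method copyState, and its parameter names are placeholders; the constructor and option names are assumed to be exactly those used in the test code above, and the offset/commit-log bookkeeping done by the test is omitted here.

package org.apache.spark.sql.execution.datasources.v2.state // sketch assumes access to these internal classes

import java.util.UUID

import scala.collection.immutable.HashMap

import org.apache.spark.TaskContext
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.execution.streaming.state.{KeyStateEncoderSpec, StateStoreConf}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.SerializableConfiguration

object RepartitionSketch {
  def copyState(
      spark: SparkSession,
      sourceCheckpoint: String,
      targetOptions: StateSourceOptions, // built with internalOnlyReadAllColumnFamilies = true
      schemasByColumnFamily: HashMap[String, (StructType, StructType)],
      encoderSpec: KeyStateEncoderSpec): Unit = {

    // Read every column family of the source store as raw bytes:
    // (partition_key, key_bytes, value_bytes, column_family_name).
    val rawBytes = spark.read
      .format("statestore")
      .option(StateSourceOptions.PATH, sourceCheckpoint)
      .option(StateSourceOptions.INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, "true")
      .load()

    val storeConf = StateStoreConf(SQLConf.get)
    val hadoopConf = new SerializableConfiguration(spark.sessionState.newHadoopConf())
    val queryId = UUID.randomUUID()

    // Write each partition's bytes into the target checkpoint. The writer builds a
    // fresh store for the partition instead of loading the previous version.
    val writePartition: Iterator[Row] => Unit = rows => {
      val partition = new StateStoreInputPartition(
        TaskContext.getPartitionId(), queryId, targetOptions)
      val writer = new StatePartitionAllColumnFamiliesWriter(
        storeConf, hadoopConf.value, partition, schemasByColumnFamily, encoderSpec)
      writer.put(rows)
    }
    rawBytes.foreachPartition(writePartition)
  }
}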