5 changes: 5 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -5493,6 +5493,11 @@
"message" : [
"Unsupported offset sequence version <version>. Please make sure the checkpoint is from a supported Spark version (Spark 4.0+)."
]
},
"UNSUPPORTED_PROVIDER" : {
"message" : [
"<provider> is not supported"
]
}
},
"sqlState" : "55019"
@@ -41,7 +41,8 @@ import org.apache.spark.sql.execution.streaming.operators.stateful.transformwith
import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.timers.TimerStateUtils
import org.apache.spark.sql.execution.streaming.runtime.StreamingCheckpointConstants.DIR_NAME_STATE
import org.apache.spark.sql.execution.streaming.runtime.StreamingQueryCheckpointMetadata
import org.apache.spark.sql.execution.streaming.state.{InMemoryStateSchemaProvider, KeyStateEncoderSpec, NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, StateSchemaCompatibilityChecker, StateSchemaMetadata, StateSchemaProvider, StateStore, StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreProviderId}
import org.apache.spark.sql.execution.streaming.state.{InMemoryStateSchemaProvider, KeyStateEncoderSpec, NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, RocksDBStateStoreProvider, StateSchemaCompatibilityChecker, StateSchemaMetadata, StateSchemaProvider, StateStore, StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreProviderId}
import org.apache.spark.sql.execution.streaming.state.OfflineStateRepartitionErrors
import org.apache.spark.sql.execution.streaming.utils.StreamingUtils
import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.streaming.TimeMode
@@ -66,6 +67,14 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
val sourceOptions = StateSourceOptions.modifySourceOptions(hadoopConf,
StateSourceOptions.apply(session, hadoopConf, properties))
val stateConf = buildStateStoreConf(sourceOptions.resolvedCpLocation, sourceOptions.batchId)
// Only RocksDB is supported because the offline state repartitioning work that
// this option is built for only supports RocksDB.
if (sourceOptions.internalOnlyReadAllColumnFamilies
&& stateConf.providerClass != classOf[RocksDBStateStoreProvider].getName) {
throw OfflineStateRepartitionErrors.unsupportedStateStoreProviderError(
sourceOptions.resolvedCpLocation,
stateConf.providerClass)
}
val stateStoreReaderInfo: StateStoreReaderInfo = getStoreMetadataAndRunChecks(
sourceOptions)

@@ -372,15 +381,17 @@ case class StateSourceOptions(
stateVarName: Option[String],
readRegisteredTimers: Boolean,
flattenCollectionTypes: Boolean,
internalOnlyReadAllColumnFamilies: Boolean = false,
startOperatorStateUniqueIds: Option[Array[Array[String]]] = None,
endOperatorStateUniqueIds: Option[Array[Array[String]]] = None) {
def stateCheckpointLocation: Path = new Path(resolvedCpLocation, DIR_NAME_STATE)

override def toString: String = {
var desc = s"StateSourceOptions(checkpointLocation=$resolvedCpLocation, batchId=$batchId, " +
s"operatorId=$operatorId, storeName=$storeName, joinSide=$joinSide, " +
s"stateVarName=${stateVarName.getOrElse("None")}, +" +
s"flattenCollectionTypes=$flattenCollectionTypes"
s"stateVarName=${stateVarName.getOrElse("None")}, " +
s"flattenCollectionTypes=$flattenCollectionTypes, " +
s"internalOnlyReadAllColumnFamilies=$internalOnlyReadAllColumnFamilies"
if (fromSnapshotOptions.isDefined) {
desc += s", snapshotStartBatchId=${fromSnapshotOptions.get.snapshotStartBatchId}"
desc += s", snapshotPartitionId=${fromSnapshotOptions.get.snapshotPartitionId}"
@@ -407,6 +418,7 @@ object StateSourceOptions extends DataSourceOptions {
val STATE_VAR_NAME = newOption("stateVarName")
val READ_REGISTERED_TIMERS = newOption("readRegisteredTimers")
val FLATTEN_COLLECTION_TYPES = newOption("flattenCollectionTypes")
val INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES = newOption("internalOnlyReadAllColumnFamilies")

object JoinSideValues extends Enumeration {
type JoinSideValues = Value
@@ -478,6 +490,7 @@ object StateSourceOptions extends DataSourceOptions {
s"Valid values are ${JoinSideValues.values.mkString(",")}")
}

// Use storeName rather than joinSide to identify the specific join store
if (joinSide != JoinSideValues.none && storeName != StateStoreId.DEFAULT_STORE_NAME) {
throw StateDataSourceErrors.conflictOptions(Seq(JOIN_SIDE, STORE_NAME))
}
@@ -492,6 +505,29 @@ object StateSourceOptions extends DataSourceOptions {

val readChangeFeed = Option(options.get(READ_CHANGE_FEED)).exists(_.toBoolean)

val internalOnlyReadAllColumnFamilies = try {
Option(options.get(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES)).exists(_.toBoolean)
} catch {
case _: IllegalArgumentException =>
throw StateDataSourceErrors.invalidOptionValue(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES,
"Boolean value is expected")
}

if (internalOnlyReadAllColumnFamilies && stateVarName.isDefined) {
throw StateDataSourceErrors.conflictOptions(
Seq(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, STATE_VAR_NAME))
}

if (internalOnlyReadAllColumnFamilies && joinSide != JoinSideValues.none) {
throw StateDataSourceErrors.conflictOptions(
Seq(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, JOIN_SIDE))
}

if (internalOnlyReadAllColumnFamilies && readChangeFeed) {
throw StateDataSourceErrors.conflictOptions(
Seq(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, READ_CHANGE_FEED))
}

val changeStartBatchId = Option(options.get(CHANGE_START_BATCH_ID)).map(_.toLong)
var changeEndBatchId = Option(options.get(CHANGE_END_BATCH_ID)).map(_.toLong)

@@ -615,7 +651,7 @@ object StateSourceOptions extends DataSourceOptions {
StateSourceOptions(
resolvedCpLocation, batchId.get, operatorId, storeName, joinSide,
readChangeFeed, fromSnapshotOptions, readChangeFeedOptions,
stateVarName, readRegisteredTimers, flattenCollectionTypes,
stateVarName, readRegisteredTimers, flattenCollectionTypes, internalOnlyReadAllColumnFamilies,
startOperatorStateUniqueIds, endOperatorStateUniqueIds)
}

@@ -49,7 +49,10 @@ class StatePartitionReaderFactory(

override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
val stateStoreInputPartition = partition.asInstanceOf[StateStoreInputPartition]
if (stateStoreInputPartition.sourceOptions.readChangeFeed) {
if (stateStoreInputPartition.sourceOptions.internalOnlyReadAllColumnFamilies) {
new StatePartitionAllColumnFamiliesReader(storeConf, hadoopConf,
stateStoreInputPartition, schema, keyStateEncoderSpec, stateStoreColFamilySchemaOpt)
} else if (stateStoreInputPartition.sourceOptions.readChangeFeed) {
new StateStoreChangeDataPartitionReader(storeConf, hadoopConf,
stateStoreInputPartition, schema, keyStateEncoderSpec, stateVariableInfoOpt,
stateStoreColFamilySchemaOpt, stateSchemaProviderOpt, joinColFamilyOpt)
@@ -81,16 +84,22 @@ abstract class StatePartitionReaderBase(
private val schemaForValueRow: StructType =
StructType(Array(StructField("__dummy__", NullType)))

protected val keySchema = {
protected val keySchema: StructType = {
if (SchemaUtil.checkVariableType(stateVariableInfoOpt, StateVariableType.MapState)) {
SchemaUtil.getCompositeKeySchema(schema, partition.sourceOptions)
} else if (partition.sourceOptions.internalOnlyReadAllColumnFamilies) {
require(stateStoreColFamilySchemaOpt.isDefined)
stateStoreColFamilySchemaOpt.map(_.keySchema).get
} else {
SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType]
}
}

protected val valueSchema = if (stateVariableInfoOpt.isDefined) {
protected val valueSchema: StructType = if (stateVariableInfoOpt.isDefined) {
schemaForValueRow
} else if (partition.sourceOptions.internalOnlyReadAllColumnFamilies) {
require(stateStoreColFamilySchemaOpt.isDefined)
stateStoreColFamilySchemaOpt.map(_.valueSchema).get
} else {
SchemaUtil.getSchemaAsDataType(
schema, "value").asInstanceOf[StructType]
@@ -237,6 +246,48 @@ class StatePartitionReader(
}
}

/**
* An implementation of [[StatePartitionReaderBase]] that reads all column families
* in binary format. This reader returns raw key and value bytes along with the
* column family name; raw bytes are used because each column family can have a
* different schema. The partition key is also included in each output row.
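*
* A rough usage sketch (assuming the data source's registered short name is "statestore";
* the option and output column names are the ones introduced in this change):
* {{{
*   spark.read
*     .format("statestore")
*     .option("internalOnlyReadAllColumnFamilies", "true")
*     .load(checkpointLocation)
*   // => rows of (partition_key, key_bytes, value_bytes, column_family_name)
* }}}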
*/
class StatePartitionAllColumnFamiliesReader(
storeConf: StateStoreConf,
hadoopConf: SerializableConfiguration,
partition: StateStoreInputPartition,
schema: StructType,
keyStateEncoderSpec: KeyStateEncoderSpec,
stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema])
extends StatePartitionReaderBase(
storeConf,
hadoopConf, partition, schema,
keyStateEncoderSpec, None, stateStoreColFamilySchemaOpt, None, None) {

private lazy val store: ReadStateStore = {
assert(getStartStoreUniqueId == getEndStoreUniqueId,
"Start and end store unique IDs must be the same when reading all column families")
provider.getReadStore(
partition.sourceOptions.batchId + 1,
getStartStoreUniqueId
)
}

override lazy val iter: Iterator[InternalRow] = {
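// Emit each key/value pair as raw UnsafeRow bytes along with the column family name.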
store
.iterator()
.map { pair =>
SchemaUtil.unifyStateRowPairAsRawBytes(
(pair.key, pair.value), StateStore.DEFAULT_COL_FAMILY_NAME)
}
}

override def close(): Unit = {
store.release()
super.close()
}
}

/**
* An implementation of [[StatePartitionReaderBase]] for the readChangeFeed mode of State Data
* Source. It reads the change of state over batches of a particular partition.
@@ -0,0 +1,136 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.datasources.v2.state

import java.io.IOException

import scala.collection.MapView
import scala.collection.immutable.HashMap

import org.apache.hadoop.conf.Configuration

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
import org.apache.spark.sql.execution.streaming.state.{KeyStateEncoderSpec, StateStore, StateStoreConf, StateStoreId, StateStoreProvider, StateStoreProviderId}
import org.apache.spark.sql.types.StructType

/**
* A writer that can directly write binary data to the streaming state store.
*
* This writer expects input rows with the same schema as the one produced by
* [[StatePartitionAllColumnFamiliesReader]]:
* (partition_key, key_bytes, value_bytes, column_family_name)
*
* The writer creates a fresh (empty) state store instance for the target commit version
* instead of loading previous partition data. After writing all rows for the partition, it
* commits all changes as a single snapshot.
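*
* A rough usage sketch (the per-partition driver code below is illustrative, not part of this class):
* {{{
*   val writer = new StatePartitionAllColumnFamiliesWriter(
*     storeConf, hadoopConf, partition, columnFamilyToSchemaMap, keyStateEncoderSpec)
*   writer.put(rowsForThisPartition)  // writes every row, then commits a snapshot
* }}}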
*/
class StatePartitionAllColumnFamiliesWriter(
storeConf: StateStoreConf,
hadoopConf: Configuration,
partition: StateStoreInputPartition,
columnFamilyToSchemaMap: HashMap[String, (StructType, StructType)],
keyStateEncoderSpec: KeyStateEncoderSpec) {

private val (defaultKeySchema, defaultValueSchema) = {
columnFamilyToSchemaMap.getOrElse(
StateStore.DEFAULT_COL_FAMILY_NAME,
throw new IllegalArgumentException(
s"Column family ${StateStore.DEFAULT_COL_FAMILY_NAME} not found in schema map")
)
}

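// Number of schema fields per column family, needed to rebuild UnsafeRows from the raw bytes.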
private val columnFamilyToKeySchemaLenMap: MapView[String, Int] =
columnFamilyToSchemaMap.view.mapValues(_._1.length)
private val columnFamilyToValueSchemaLenMap: MapView[String, Int] =
columnFamilyToSchemaMap.view.mapValues(_._2.length)

private val rowConverter = {
val schema = SchemaUtil.getSourceSchema(
partition.sourceOptions, defaultKeySchema, defaultValueSchema, None, None)
CatalystTypeConverters.createToCatalystConverter(schema)
}

protected lazy val provider: StateStoreProvider = {
val stateStoreId = StateStoreId(partition.sourceOptions.stateCheckpointLocation.toString,
partition.sourceOptions.operatorId, partition.partition, partition.sourceOptions.storeName)
val stateStoreProviderId = StateStoreProviderId(stateStoreId, partition.queryId)

val provider = StateStoreProvider.createAndInit(
stateStoreProviderId, defaultKeySchema, defaultValueSchema, keyStateEncoderSpec,
useColumnFamilies = false, storeConf, hadoopConf,
useMultipleValuesPerKey = false, stateSchemaProvider = None)
provider
}

private lazy val stateStore: StateStore = {
val stateStoreCkptId = if (storeConf.enableStateStoreCheckpointIds) {
Some(java.util.UUID.randomUUID.toString)
} else {
None
}
val version = partition.sourceOptions.batchId + 1
// Create an empty store (loadEmpty = true) so no previous partition data is loaded
// during repartitioning. The empty store is opened AT `version`, and the next commit
// produces version + 1.
provider.getStore(
version,
stateStoreCkptId,
forceSnapshotOnCommit = true,
loadEmpty = true
)
}

// Writes and commits data to the state store. It takes in rows with the schema:
// - partition_key, StructType
// - key_bytes, BinaryType
// - value_bytes, BinaryType
// - column_family_name, StringType
def put(partition: Iterator[Row]): Unit = {
partition.foreach(row => putRaw(row))
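// A single commit per partition; with forceSnapshotOnCommit = true above, this commit
// is written out as a snapshot.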
stateStore.commit()
}

private def putRaw(rawRecord: Row): Unit = {
val record = rowConverter(rawRecord).asInstanceOf[InternalRow]
// Validate record schema
if (record.numFields != 4) {
throw new IOException(
s"Invalid record schema: expected 4 fields (partition_key, key_bytes, value_bytes, " +
s"column_family_name), got ${record.numFields}")
}

// Extract raw bytes and column family name from the record
val keyBytes = record.getBinary(1)
val valueBytes = record.getBinary(2)
val colFamilyName = record.getString(3)

// Reconstruct UnsafeRow objects from the raw bytes
// The bytes are in UnsafeRow memory format, as produced by StatePartitionAllColumnFamiliesReader
val keyRow = new UnsafeRow(columnFamilyToKeySchemaLenMap(colFamilyName))
keyRow.pointTo(keyBytes, keyBytes.length)

val valueRow = new UnsafeRow(columnFamilyToValueSchemaLenMap(colFamilyName))
valueRow.pointTo(valueBytes, valueBytes.length)

// Use StateStore API which handles proper RocksDB encoding (version byte, checksums, etc.)
stateStore.put(keyRow, valueRow, colFamilyName)
}
}