
Commit ae24efe

ericm-db authored and haoyangeng-db committed
[SPARK-51596][SS] Fix concurrent StateStoreProvider maintenance and closing
### What changes were proposed in this pull request?

Moves the unload operation away from the task thread into the maintenance thread. To ensure unloading still occurs ASAP (rather than potentially waiting for the maintenance interval), as was introduced by https://issues.apache.org/jira/browse/SPARK-33827, we immediately trigger a maintenance thread to do the unload. This has the extra benefit that unloading other providers doesn't block the task thread. To capitalize on this, unload() should not hold the loadedProviders lock the entire time (which would block other task threads), but instead release it once it has deleted the unloading providers from the map, and then close the providers without the lock held.

### Why are the changes needed?

Currently, both the task thread and the maintenance thread can call unload() on a provider. This leads to a race condition where the maintenance thread could be conducting maintenance while the task thread is closing the provider, leading to unexpected behavior.

### Does this PR introduce any user-facing change?

No

### How was this patch tested?

Added unit test.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes apache#51565 from ericm-db/maint-changes.

Authored-by: Eric Marnadi <[email protected]>
Signed-off-by: Anish Shrigondekar <[email protected]>
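For illustration, here is a minimal, self-contained sketch of the locking pattern described above: remove the provider from the map while holding the lock, then do maintenance and close it on a maintenance thread with the lock released. The names (`UnloadPatternSketch`, `FakeProvider`, `ProviderId`) are hypothetical placeholders, not the actual Spark classes touched by this commit.

```scala
import java.util.concurrent.Executors
import scala.collection.mutable

object UnloadPatternSketch {
  final case class ProviderId(operatorId: Long, partitionId: Int)

  final class FakeProvider(val id: ProviderId) {
    def doMaintenance(): Unit = println(s"maintenance on $id")
    def close(): Unit = println(s"closed $id")
  }

  private val loadedProviders = mutable.HashMap.empty[ProviderId, FakeProvider]
  private val maintenancePool = Executors.newFixedThreadPool(2)

  // Task thread: remove from the map under the lock, then hand the provider to the
  // maintenance pool. close() never runs on the task thread and never under the lock.
  def unloadAsync(id: ProviderId): Unit = {
    val removed = loadedProviders.synchronized { loadedProviders.remove(id) }
    removed.foreach { provider =>
      maintenancePool.execute(() => {
        provider.doMaintenance() // maintenance first, since the store may still have recent data
        provider.close()         // then close, outside the loadedProviders lock
      })
    }
  }

  def main(args: Array[String]): Unit = {
    val id = ProviderId(0L, 0)
    loadedProviders.synchronized { loadedProviders.put(id, new FakeProvider(id)) }
    unloadAsync(id)
    maintenancePool.shutdown()
  }
}
```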
1 parent 0228de7 commit ae24efe

File tree

4 files changed: +607 -53 lines


sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 10 additions & 0 deletions
@@ -2436,6 +2436,13 @@ object SQLConf {
       .timeConf(TimeUnit.SECONDS)
       .createWithDefault(300L)
 
+  val STATE_STORE_MAINTENANCE_PROCESSING_TIMEOUT =
+    buildConf("spark.sql.streaming.stateStore.maintenanceProcessingTimeout")
+      .internal()
+      .doc("Timeout in seconds to wait for maintenance to process this partition.")
+      .timeConf(TimeUnit.SECONDS)
+      .createWithDefault(30L)
+
   val STATE_SCHEMA_CHECK_ENABLED =
     buildConf("spark.sql.streaming.stateStore.stateSchemaCheck")
       .doc("When true, Spark will validate the state schema against schema on existing state and " +
@@ -6343,6 +6350,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
 
   def stateStoreMaintenanceShutdownTimeout: Long = getConf(STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT)
 
+  def stateStoreMaintenanceProcessingTimeout: Long =
+    getConf(STATE_STORE_MAINTENANCE_PROCESSING_TIMEOUT)
+
   def stateStoreMinDeltasForSnapshot: Int = getConf(STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT)
 
   def stateStoreFormatValidationEnabled: Boolean = getConf(STATE_STORE_FORMAT_VALIDATION_ENABLED)
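A hedged usage sketch (not part of the diff): since the conf is declared with `.timeConf(TimeUnit.SECONDS)`, it can be overridden with a plain number of seconds. It is marked `.internal()`, so overriding it is mostly useful in tests; the `spark` session value below is assumed to exist in the surrounding code.

```scala
// Make task-thread-triggered maintenance give up waiting for a busy partition sooner (assumed test setup).
spark.conf.set("spark.sql.streaming.stateStore.maintenanceProcessingTimeout", "10")
```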

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala

Lines changed: 208 additions & 52 deletions
@@ -18,10 +18,11 @@
 package org.apache.spark.sql.execution.streaming.state
 
 import java.util.UUID
-import java.util.concurrent.{ScheduledFuture, TimeUnit}
+import java.util.concurrent.{ConcurrentLinkedQueue, ScheduledFuture, TimeUnit}
 import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
 import scala.util.control.NonFatal
 
 import org.apache.hadoop.conf.Configuration
@@ -31,13 +32,14 @@ import org.json4s.JsonAST.JValue
 import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods.{compact, render}
 
-import org.apache.spark.{SparkContext, SparkEnv, SparkException}
+import org.apache.spark.{SparkContext, SparkEnv, SparkException, TaskContext}
 import org.apache.spark.internal.{Logging, LogKeys, MDC}
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.catalyst.util.UnsafeRowUtils
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.execution.streaming.{StatefulOperatorStateInfo, StreamExecution}
+import org.apache.spark.sql.execution.streaming.state.MaintenanceTaskType._
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.{NextIterator, ThreadUtils, Utils}
 
@@ -53,6 +55,14 @@ object StateStoreEncoding {
   case object Avro extends StateStoreEncoding
 }
 
+sealed trait MaintenanceTaskType
+
+object MaintenanceTaskType {
+  case object FromUnloadedProvidersQueue extends MaintenanceTaskType
+  case object FromTaskThread extends MaintenanceTaskType
+  case object FromLoadedProviders extends MaintenanceTaskType
+}
+
 /**
  * Base trait for a versioned key-value store which provides read operations. Each instance of a
  * `ReadStateStore` represents a specific version of state data, and such instances are created
@@ -554,7 +564,11 @@ trait StateStoreProvider {
    */
   def stateStoreId: StateStoreId
 
-  /** Called when the provider instance is unloaded from the executor */
+  /**
+   * Called when the provider instance is unloaded from the executor
+   * WARNING: IF PROVIDER FROM [[StateStore.loadedProviders]],
+   * CLOSE MUST ONLY BE CALLED FROM MAINTENANCE THREAD!
+   */
   def close(): Unit
 
   /**
@@ -843,6 +857,9 @@ object StateStore extends Logging {
 
   private val maintenanceThreadPoolLock = new Object
 
+  private val unloadedProvidersToClose =
+    new ConcurrentLinkedQueue[(StateStoreProviderId, StateStoreProvider)]
+
   // This set is to keep track of the partitions that are queued
   // for maintenance or currently have maintenance running on them
   // to prevent the same partition from being processed concurrently.
@@ -1012,7 +1029,21 @@ object StateStore extends Logging {
       if (!storeConf.unloadOnCommit) {
         val otherProviderIds = loadedProviders.keys.filter(_ != storeProviderId).toSeq
         val providerIdsToUnload = reportActiveStoreInstance(storeProviderId, otherProviderIds)
-        providerIdsToUnload.foreach(unload(_))
+        val taskContextIdLogLine = Option(TaskContext.get()).map { tc =>
+          log"taskId=${MDC(LogKeys.TASK_ID, tc.taskAttemptId())}"
+        }.getOrElse(log"")
+        providerIdsToUnload.foreach(id => {
+          loadedProviders.remove(id).foreach( provider => {
+            // Trigger maintenance thread to immediately do maintenance on and close the provider.
+            // Doing maintenance first allows us to do maintenance for a constantly-moving state
+            // store.
+            logInfo(log"Submitted maintenance from task thread to close " +
+              log"provider=${MDC(LogKeys.STATE_STORE_PROVIDER_ID, id)}." + taskContextIdLogLine +
+              log"Removed provider from loadedProviders")
+            submitMaintenanceWorkForProvider(
+              id, provider, storeConf, MaintenanceTaskType.FromTaskThread)
+          })
+        })
       }
 
       provider
@@ -1029,14 +1060,30 @@ object StateStore extends Logging {
     }
   }
 
-  /** Unload a state store provider */
-  def unload(storeProviderId: StateStoreProviderId): Unit = loadedProviders.synchronized {
-    loadedProviders.remove(storeProviderId).foreach(_.close())
+  /**
+   * Unload a state store provider.
+   * If alreadyRemovedFromLoadedProviders is None, provider will be
+   * removed from loadedProviders and closed.
+   * If alreadyRemovedFromLoadedProviders is Some, provider will be closed
+   * using passed in provider.
+   * WARNING: CAN ONLY BE CALLED FROM MAINTENANCE THREAD!
+   */
+  def removeFromLoadedProvidersAndClose(
+      storeProviderId: StateStoreProviderId,
+      alreadyRemovedProvider: Option[StateStoreProvider] = None): Unit = {
+    val providerToClose = alreadyRemovedProvider.orElse {
+      loadedProviders.synchronized {
+        loadedProviders.remove(storeProviderId)
+      }
+    }
+    providerToClose.foreach { provider =>
+      provider.close()
+    }
   }
 
   /** Unload all state store providers: unit test purpose */
   private[sql] def unloadAll(): Unit = loadedProviders.synchronized {
-    loadedProviders.keySet.foreach { key => unload(key) }
+    loadedProviders.keySet.foreach { key => removeFromLoadedProvidersAndClose(key) }
     loadedProviders.clear()
   }
 
@@ -1075,7 +1122,7 @@ object StateStore extends Logging {
 
   /** Unload and stop all state store providers */
   def stop(): Unit = loadedProviders.synchronized {
-    loadedProviders.keySet.foreach { key => unload(key) }
+    loadedProviders.keySet.foreach { key => removeFromLoadedProvidersAndClose(key) }
     loadedProviders.clear()
     _coordRef = null
     stopMaintenanceTask()
@@ -1090,7 +1137,7 @@ object StateStore extends Logging {
     if (SparkEnv.get != null && !isMaintenanceRunning && !storeConf.unloadOnCommit) {
       maintenanceTask = new MaintenanceTask(
         storeConf.maintenanceInterval,
-        task = { doMaintenance() }
+        task = { doMaintenance(storeConf) }
       )
       maintenanceThreadPool = new MaintenanceThreadPool(numMaintenanceThreads,
         maintenanceShutdownTimeout)
@@ -1099,6 +1146,27 @@ object StateStore extends Logging {
     }
   }
 
+  // Wait until this partition can be processed
+  private def awaitProcessThisPartition(
+      id: StateStoreProviderId,
+      timeoutMs: Long): Boolean = maintenanceThreadPoolLock synchronized {
+    val startTime = System.currentTimeMillis()
+    val endTime = startTime + timeoutMs
+
+    // If immediate processing fails, wait with timeout
+    var canProcessThisPartition = processThisPartition(id)
+    while (!canProcessThisPartition && System.currentTimeMillis() < endTime) {
+      maintenanceThreadPoolLock.wait(timeoutMs)
+      canProcessThisPartition = processThisPartition(id)
+    }
+    val elapsedTime = System.currentTimeMillis() - startTime
+    logInfo(log"Waited for ${MDC(LogKeys.TOTAL_TIME, elapsedTime)} ms to be able to process " +
+      log"maintenance for partition ${MDC(LogKeys.STATE_STORE_PROVIDER_ID, id)}")
+    canProcessThisPartition
+  }
+
+  private def doMaintenance(): Unit = doMaintenance(StateStoreConf.empty)
+
   private def processThisPartition(id: StateStoreProviderId): Boolean = {
     maintenanceThreadPoolLock.synchronized {
       if (!maintenancePartitions.contains(id)) {
@@ -1114,56 +1182,42 @@ object StateStore extends Logging {
    * Execute background maintenance task in all the loaded store providers if they are still
    * the active instances according to the coordinator.
    */
-  private def doMaintenance(): Unit = {
+  private def doMaintenance(storeConf: StateStoreConf): Unit = {
     logDebug("Doing maintenance")
     if (SparkEnv.get == null) {
       throw new IllegalStateException("SparkEnv not active, cannot do maintenance on StateStores")
     }
+
+    // Providers that couldn't be processed now and need to be added back to the queue
+    val providersToRequeue = new ArrayBuffer[(StateStoreProviderId, StateStoreProvider)]()
+
+    // unloadedProvidersToClose are StateStoreProviders that have been removed from
+    // loadedProviders, and can now be processed for maintenance. This queue contains
+    // providers for which we weren't able to process for maintenance on the previous iteration
+    while (!unloadedProvidersToClose.isEmpty) {
+      val (providerId, provider) = unloadedProvidersToClose.poll()
+
+      if (processThisPartition(providerId)) {
+        submitMaintenanceWorkForProvider(
+          providerId, provider, storeConf, MaintenanceTaskType.FromUnloadedProvidersQueue)
+      } else {
+        providersToRequeue += ((providerId, provider))
+      }
+    }
+
+    if (providersToRequeue.nonEmpty) {
+      logInfo(log"Had to requeue ${MDC(LogKeys.SIZE, providersToRequeue.size)} providers " +
+        log"for maintenance in doMaintenance")
+    }
+
+    providersToRequeue.foreach(unloadedProvidersToClose.offer)
+
     loadedProviders.synchronized {
       loadedProviders.toSeq
     }.foreach { case (id, provider) =>
      if (processThisPartition(id)) {
-        maintenanceThreadPool.execute(() => {
-          val startTime = System.currentTimeMillis()
-          try {
-            provider.doMaintenance()
-            if (!verifyIfStoreInstanceActive(id)) {
-              unload(id)
-              logInfo(log"Unloaded ${MDC(LogKeys.STATE_STORE_PROVIDER, provider)}")
-            }
-          } catch {
-            case NonFatal(e) =>
-              logWarning(log"Error managing ${MDC(LogKeys.STATE_STORE_PROVIDER, provider)}, " +
-                log"unloading state store provider", e)
-              // When we get a non-fatal exception, we just unload the provider.
-              //
-              // By not bubbling the exception to the maintenance task thread or the query execution
-              // thread, it's possible for a maintenance thread pool task to continue failing on
-              // the same partition. Additionally, if there is some global issue that will cause
-              // all maintenance thread pool tasks to fail, then bubbling the exception and
-              // stopping the pool is faster than waiting for all tasks to see the same exception.
-              //
-              // However, we assume that repeated failures on the same partition and global issues
-              // are rare. The benefit to unloading just the partition with an exception is that
-              // transient issues on a given provider do not affect any other providers; so, in
-              // most cases, this should be a more performant solution.
-              unload(id)
-          } finally {
-            val duration = System.currentTimeMillis() - startTime
-            val logMsg =
-              log"Finished maintenance task for " +
-              log"provider=${MDC(LogKeys.STATE_STORE_PROVIDER_ID, id)}" +
-              log" in elapsed_time=${MDC(LogKeys.TIME_UNITS, duration)}\n"
-            if (duration > 5000) {
-              logInfo(logMsg)
-            } else {
-              logDebug(logMsg)
-            }
-            maintenanceThreadPoolLock.synchronized {
-              maintenancePartitions.remove(id)
-            }
-          }
-        })
+        submitMaintenanceWorkForProvider(
+          id, provider, storeConf, MaintenanceTaskType.FromLoadedProviders)
      } else {
        logInfo(log"Not processing partition ${MDC(LogKeys.PARTITION_ID, id)} " +
          log"for maintenance because it is currently " +
@@ -1172,6 +1226,108 @@ object StateStore extends Logging {
     }
   }
 
+  /**
+   * Submits maintenance work for a provider to the maintenance thread pool.
+   *
+   * @param id The StateStore provider ID to perform maintenance on
+   * @param provider The StateStore provider instance
+   */
+  private def submitMaintenanceWorkForProvider(
+      id: StateStoreProviderId,
+      provider: StateStoreProvider,
+      storeConf: StateStoreConf,
+      source: MaintenanceTaskType = FromLoadedProviders): Unit = {
+    maintenanceThreadPool.execute(() => {
+      val startTime = System.currentTimeMillis()
+      // Determine if we can process this partition based on the source
+      val canProcessThisPartition = source match {
+        case FromTaskThread =>
+          // Provider from task thread needs to wait for lock
+          // We potentially need to wait for ongoing maintenance to finish processing
+          // this partition
+          val timeoutMs = storeConf.stateStoreMaintenanceProcessingTimeout * 1000
+          val ableToProcessNow = awaitProcessThisPartition(id, timeoutMs)
+          if (!ableToProcessNow) {
+            // Add to queue for later processing if we can't process now
+            // This will be resubmitted for maintenance later by the background maintenance task
+            unloadedProvidersToClose.add((id, provider))
+          }
+          ableToProcessNow
+
+        case FromUnloadedProvidersQueue =>
+          // Provider from queue can be processed immediately
+          // (we've already removed it from loadedProviders)
+          true
+
+        case FromLoadedProviders =>
+          // Provider from loadedProviders can be processed immediately
+          // as it's in maintenancePartitions
+          true
+      }
+
+      if (canProcessThisPartition) {
+        val awaitingPartitionDuration = System.currentTimeMillis() - startTime
+        try {
+          provider.doMaintenance()
+          // Handle unloading based on source
+          source match {
+            case FromTaskThread | FromUnloadedProvidersQueue =>
+              // Provider already removed from loadedProviders, just close it
+              removeFromLoadedProvidersAndClose(id, Some(provider))
+
+            case FromLoadedProviders =>
+              // Check if provider should be unloaded
+              if (!verifyIfStoreInstanceActive(id)) {
+                removeFromLoadedProvidersAndClose(id)
+              }
+          }
+          logInfo(log"Unloaded ${MDC(LogKeys.STATE_STORE_PROVIDER_ID, id)}")
+        } catch {
+          case NonFatal(e) =>
+            logWarning(log"Error doing maintenance on provider:" +
+              log" ${MDC(LogKeys.STATE_STORE_PROVIDER_ID, id)}. " +
+              log"Could not unload state store provider", e)
+            // When we get a non-fatal exception, we just unload the provider.
+            //
+            // By not bubbling the exception to the maintenance task thread or the query execution
+            // thread, it's possible for a maintenance thread pool task to continue failing on
+            // the same partition. Additionally, if there is some global issue that will cause
+            // all maintenance thread pool tasks to fail, then bubbling the exception and
+            // stopping the pool is faster than waiting for all tasks to see the same exception.
+            //
+            // However, we assume that repeated failures on the same partition and global issues
+            // are rare. The benefit to unloading just the partition with an exception is that
+            // transient issues on a given provider do not affect any other providers; so, in
+            // most cases, this should be a more performant solution.
+            source match {
+              case FromTaskThread | FromUnloadedProvidersQueue =>
+                removeFromLoadedProvidersAndClose(id, Some(provider))
+
+              case FromLoadedProviders =>
+                removeFromLoadedProvidersAndClose(id)
+            }
+        } finally {
+          val duration = System.currentTimeMillis() - startTime
+          val logMsg =
+            log"Finished maintenance task for " +
+            log"provider=${MDC(LogKeys.STATE_STORE_PROVIDER_ID, id)}" +
+            log" in elapsed_time=${MDC(LogKeys.TIME_UNITS, duration)}" +
+            log" and awaiting_partition_time=" +
+            log"${MDC(LogKeys.TIME_UNITS, awaitingPartitionDuration)}\n"
+          if (duration > 5000) {
+            logInfo(logMsg)
+          } else {
+            logDebug(logMsg)
+          }
+          maintenanceThreadPoolLock.synchronized {
+            maintenancePartitions.remove(id)
+            maintenanceThreadPoolLock.notifyAll()
+          }
+        }
+      }
+    })
+  }
+
   private def reportActiveStoreInstance(
       storeProviderId: StateStoreProviderId,
       otherProviderIds: Seq[StateStoreProviderId]): Seq[StateStoreProviderId] = {
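For readers unfamiliar with the wait/notifyAll handshake that `awaitProcessThisPartition` and the `finally` block above rely on, here is a standalone sketch of the same coordination pattern. The names (`PartitionGateSketch`, `busyPartitions`) are assumed for illustration; this is not the Spark code itself.

```scala
import scala.collection.mutable

object PartitionGateSketch {
  private val lock = new Object
  private val busyPartitions = mutable.Set.empty[Int]

  // Returns true if the partition was claimed within timeoutMs, re-checking after each wakeup.
  def awaitClaim(partition: Int, timeoutMs: Long): Boolean = lock.synchronized {
    val deadline = System.currentTimeMillis() + timeoutMs
    var claimed = busyPartitions.add(partition) // true if nobody else holds the partition
    while (!claimed && System.currentTimeMillis() < deadline) {
      lock.wait(timeoutMs)                      // woken early by release(), or times out
      claimed = busyPartitions.add(partition)
    }
    claimed
  }

  // Called in a finally block by the thread that finished working on the partition.
  def release(partition: Int): Unit = lock.synchronized {
    busyPartitions.remove(partition)
    lock.notifyAll()                            // let any waiting claimer re-check
  }
}
```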

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala

Lines changed: 2 additions & 0 deletions
@@ -40,6 +40,8 @@ class StateStoreConf(
    */
   val stateStoreMaintenanceShutdownTimeout: Long = sqlConf.stateStoreMaintenanceShutdownTimeout
 
+  val stateStoreMaintenanceProcessingTimeout: Long = sqlConf.stateStoreMaintenanceProcessingTimeout
+
   /**
    * Minimum number of delta files in a chain after which HDFSBackedStateStore will
    * consider generating a snapshot.
