From 0d3f75ed99c4b0e54e8a7ba42ed93fc8382b76a7 Mon Sep 17 00:00:00 2001
From: Tengfei Huang
Date: Tue, 2 Dec 2025 00:33:19 +0800
Subject: [PATCH 1/2] Rollback succeeding shuffle map stages when shuffle checksum mismatch detected

---
 .../apache/spark/scheduler/DAGScheduler.scala | 223 +++++++++++++-----
 .../org/apache/spark/scheduler/Stage.scala    |  16 ++
 .../spark/scheduler/DAGSchedulerSuite.scala   |  88 ++++++-
 3 files changed, 270 insertions(+), 57 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 7c8bea31334b..292d95aceb12 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1560,42 +1560,27 @@ private[spark] class DAGScheduler(
     // `findMissingPartitions()` returns all partitions every time.
     stage match {
       case sms: ShuffleMapStage if !sms.isAvailable =>
-        val needFullStageRetry = if (sms.shuffleDep.checksumMismatchFullRetryEnabled) {
-          // When the parents of this stage are indeterminate (e.g., some parents are not
-          // checkpointed and checksum mismatches are detected), the output data of the parents
-          // may have changed due to task retries. For correctness reason, we need to
-          // retry all tasks of the current stage. The legacy way of using current stage's
-          // deterministic level to trigger full stage retry is not accurate.
-          stage.isParentIndeterminate
-        } else {
-          if (stage.isIndeterminate) {
-            // already executed at least once
-            if (sms.getNextAttemptId > 0) {
-              // While we previously validated possible rollbacks during the handling of a FetchFailure,
-              // where we were fetching from an indeterminate source map stages, this later check
-              // covers additional cases like recalculating an indeterminate stage after an executor
-              // loss. Moreover, because this check occurs later in the process, if a result stage task
-              // has successfully completed, we can detect this and abort the job, as rolling back a
-              // result stage is not possible.
-              val stagesToRollback = collectSucceedingStages(sms)
-              abortStageWithInvalidRollBack(stagesToRollback)
-              // stages which cannot be rolled back were aborted which leads to removing the
-              // the dependant job(s) from the active jobs set
-              val numActiveJobsWithStageAfterRollback =
-                activeJobs.count(job => stagesToRollback.contains(job.finalStage))
-              if (numActiveJobsWithStageAfterRollback == 0) {
-                logInfo(log"All jobs depending on the indeterminate stage " +
-                  log"(${MDC(STAGE_ID, stage.id)}) were aborted so this stage is not needed anymore.")
-                return
-              }
-            }
-            true
-          } else {
-            false
-          }
-        }
-
-        if (needFullStageRetry) {
+        if (!sms.shuffleDep.checksumMismatchFullRetryEnabled && stage.isIndeterminate) {
+          // already executed at least once
+          if (sms.getNextAttemptId > 0) {
+            // While we previously validated possible rollbacks during the handling of a FetchFailure,
+            // where we were fetching from an indeterminate source map stages, this later check
+            // covers additional cases like recalculating an indeterminate stage after an executor
+            // loss. Moreover, because this check occurs later in the process, if a result stage task
+            // has successfully completed, we can detect this and abort the job, as rolling back a
+            // result stage is not possible.
+            val stagesToRollback = collectSucceedingStages(sms)
+            abortStagesUnableToRollback(stagesToRollback)
+            // Stages which cannot be rolled back were aborted, which leads to removing the
+            // dependent job(s) from the active jobs set.
+            val numActiveJobsWithStageAfterRollback =
+              activeJobs.count(job => stagesToRollback.contains(job.finalStage))
+            if (numActiveJobsWithStageAfterRollback == 0) {
+              logInfo(log"All jobs depending on the indeterminate stage " +
+                log"(${MDC(STAGE_ID, stage.id)}) were aborted so this stage is not needed anymore.")
+              return
+            }
+          }
           mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId)
           sms.shuffleDep.newShuffleMergeState()
         }
@@ -1913,16 +1898,127 @@ private[spark] class DAGScheduler(
 
   /**
    * If a map stage is non-deterministic, the map tasks of the stage may return different result
-   * when re-try. To make sure data correctness, we need to re-try all the tasks of its succeeding
-   * stages, as the input data may be changed after the map tasks are re-tried. For stages where
-   * rollback and retry all tasks are not possible, we will need to abort the stages.
+   * when re-tried. To make sure data correctness, we need to clean up shuffle outputs so that the
+   * succeeding stages will be resubmitted and re-try all of their tasks, as the input data may be
+   * changed after the map tasks are re-tried. For stages where rolling back and retrying all tasks
+   * is not possible, we will need to abort the stages.
+   */
+  private[scheduler] def rollbackSucceedingStages(mapStage: ShuffleMapStage): Unit = {
+    val stagesToRollback = collectSucceedingStages(mapStage).filterNot(_ == mapStage)
+    val stagesCanRollback = abortStagesUnableToRollback(stagesToRollback)
+    // Stages which cannot be rolled back were aborted, which leads to removing the
+    // dependent job(s) from the active jobs set; there could be no active jobs
+    // left depending on the indeterminate stage and hence no need to roll back any stages.
+    val numActiveJobsWithStageAfterRollback =
+      activeJobs.count(job => stagesToRollback.contains(job.finalStage))
+    if (numActiveJobsWithStageAfterRollback == 0) {
+      logInfo(log"All jobs depending on the indeterminate stage " +
+        log"(${MDC(STAGE_ID, mapStage.id)}) were aborted.")
+    } else {
+      // Mark the rollback attempt to identify older attempts which could consume inconsistent
+      // data; the results from these attempts should be ignored.
+      // Roll back the running stages first to avoid triggering more fetch failures.
+      stagesToRollback.toSeq.sortBy(!runningStages.contains(_)).foreach {
+        case sms: ShuffleMapStage =>
+          rollbackShuffleMapStage(sms, "rolling back due to indeterminate " +
+            s"output of shuffle map stage $mapStage")
+          sms.markAsRollingBack()
+
+        case rs: ResultStage =>
+          rs.markAsRollingBack()
+      }
+
+      logInfo(log"The shuffle map stage ${MDC(STAGE, mapStage)} with indeterminate output " +
+        log"was retried, we will roll back and rerun its succeeding " +
+        log"stages: ${MDC(STAGES, stagesCanRollback)}")
+    }
+  }
+
+  /**
+   * Roll back the given shuffle map stage:
+   * 1. If the stage is running, cancel the stage and kill all running tasks. Clean up the shuffle
+   *    output and resubmit the stage if it has not exceeded the max retries.
+   * 2. If the stage is not running but has generated output, clean up the shuffle output to
+   *    ensure the stage will be fully re-executed.
+   *
+   * @param sms the shuffle map stage to roll back
+   * @param reason the reason for rolling back
+   */
+  private def rollbackShuffleMapStage(sms: ShuffleMapStage, reason: String): Unit = {
+    logInfo(log"Rolling back ${MDC(STAGE, sms)} due to indeterminate rollback")
+    val clearShuffle = if (runningStages.contains(sms)) {
+      logInfo(log"Stage ${MDC(STAGE, sms)} is running, marking it as failed and " +
+        log"resubmitting it if allowed")
+      cancelStageAndTryResubmit(sms, reason)
+    } else {
+      true
+    }
+
+    // Clean up shuffle outputs in case the stage is not aborted to ensure the stage
+    // will be re-executed.
+    if (clearShuffle) {
+      logInfo(log"Cleaning up shuffle for stage ${MDC(STAGE, sms)} to ensure re-execution")
+      mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId)
+      sms.shuffleDep.newShuffleMergeState()
+    }
+  }
+
+  /**
+   * Cancel the given running shuffle map stage, killing all running tasks, and resubmit it if it
+   * doesn't exceed the max retries.
+   *
+   * @param stage the stage to cancel and resubmit
+   * @param reason the reason for the operation
+   * @return true if the stage is successfully cancelled and resubmitted, otherwise false
    */
-  private[scheduler] def abortUnrollbackableStages(mapStage: ShuffleMapStage): Unit = {
-    val stagesToRollback = collectSucceedingStages(mapStage)
-    val rollingBackStages = abortStageWithInvalidRollBack(stagesToRollback)
-    logInfo(log"The shuffle map stage ${MDC(SHUFFLE_ID, mapStage)} with indeterminate output " +
-      log"was failed, we will roll back and rerun below stages which include itself and all its " +
-      log"indeterminate child stages: ${MDC(STAGES, rollingBackStages)}")
+  private def cancelStageAndTryResubmit(stage: ShuffleMapStage, reason: String): Boolean = {
+    assert(runningStages.contains(stage), "stage must be running to be cancelled and resubmitted")
+    try {
+      // killAllTaskAttempts will fail if a SchedulerBackend does not implement killTask.
+      val job = jobIdToActiveJob.get(stage.firstJobId)
+      val shouldInterrupt = job.exists(j => shouldInterruptTaskThread(j))
+      taskScheduler.killAllTaskAttempts(stage.id, shouldInterrupt, reason)
+    } catch {
+      case e: UnsupportedOperationException =>
+        logWarning(log"Could not kill all tasks for stage ${MDC(STAGE_ID, stage.id)}", e)
+        abortStage(stage, "Rollback failed due to: Not able to kill running tasks for stage " +
+          s"$stage (${stage.name})", Some(e))
+        return false
+    }
+
+    stage.failedAttemptIds.add(stage.latestInfo.attemptNumber())
+    val shouldAbortStage = stage.failedAttemptIds.size >= maxConsecutiveStageAttempts ||
+      disallowStageRetryForTest
+    markStageAsFinished(stage, Some(reason), willRetry = !shouldAbortStage)
+
+    if (shouldAbortStage) {
+      val abortMessage = if (disallowStageRetryForTest) {
+        "Stage will not be retried due to testing config. Most recent failure " +
+          s"reason: $reason"
+      } else {
+        s"$stage (${stage.name}) has failed the maximum allowable number of " +
+          s"times: $maxConsecutiveStageAttempts. Most recent failure reason: $reason"
+      }
+      abortStage(stage, s"rollback failed due to: $abortMessage", None)
+    } else {
+      // In case multiple task failures are triggered for a single stage attempt, ensure we only
+      // resubmit the failed stage once.
+      val noResubmitEnqueued = !failedStages.contains(stage)
+      failedStages += stage
+      if (noResubmitEnqueued) {
+        logInfo(log"Resubmitting ${MDC(FAILED_STAGE, stage)} " +
+          log"(${MDC(FAILED_STAGE_NAME, stage.name)}) due to rollback.")
+        messageScheduler.schedule(
+          new Runnable {
+            override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
+          },
+          DAGScheduler.RESUBMIT_TIMEOUT,
+          TimeUnit.MILLISECONDS
+        )
+      }
+    }
+
+    !shouldAbortStage
   }
 
   /**
@@ -1990,7 +2086,21 @@ private[spark] class DAGScheduler(
     // tasks complete, they still count and we can mark the corresponding partitions as
     // finished if the stage is determinate. Here we notify the task scheduler to skip running
    // tasks for the same partition to save resource.
-    if (!stage.isIndeterminate && task.stageAttemptId < stage.latestInfo.attemptNumber()) {
+    def stageWithChecksumMismatchFullRetryEnabled(stage: Stage): Boolean = {
+      stage match {
+        case s: ShuffleMapStage => s.shuffleDep.checksumMismatchFullRetryEnabled
+        case _ => stage.parents.exists(stageWithChecksumMismatchFullRetryEnabled)
+      }
+    }
+
+    // Ignore task completion for old attempt of indeterminate stage
+    val ignoreOldTaskAttempts = if (stageWithChecksumMismatchFullRetryEnabled(stage)) {
+      stage.maxAttemptIdToIgnore.exists(_ >= task.stageAttemptId)
+    } else {
+      stage.isIndeterminate && task.stageAttemptId < stage.latestInfo.attemptNumber()
+    }
+
+    if (!ignoreOldTaskAttempts && task.stageAttemptId < stage.latestInfo.attemptNumber()) {
       taskScheduler.notifyPartitionCompletion(stageId, task.partitionId)
     }
 
@@ -2045,10 +2155,7 @@ private[spark] class DAGScheduler(
       case smt: ShuffleMapTask =>
         val shuffleStage = stage.asInstanceOf[ShuffleMapStage]
-        // Ignore task completion for old attempt of indeterminate stage
-        val ignoreIndeterminate = stage.isIndeterminate &&
-          task.stageAttemptId < stage.latestInfo.attemptNumber()
-        if (!ignoreIndeterminate) {
+        if (!ignoreOldTaskAttempts) {
           shuffleStage.pendingPartitions -= task.partitionId
           val status = event.result.asInstanceOf[MapStatus]
           val execId = status.location.executorId
@@ -2077,7 +2184,7 @@ private[spark] class DAGScheduler(
             shuffleStage.maxChecksumMismatchedId = smt.stageAttemptId
             if (shuffleStage.shuffleDep.checksumMismatchFullRetryEnabled &&
                 shuffleStage.isStageIndeterminate) {
-              abortUnrollbackableStages(shuffleStage)
+              rollbackSucceedingStages(shuffleStage)
            }
           }
         }
@@ -2206,7 +2313,11 @@ private[spark] class DAGScheduler(
             // guaranteed to be determinate, so the input data of the reducers will not change
             // even if the map tasks are re-tried.
             if (mapStage.isIndeterminate && !mapStage.shuffleDep.checksumMismatchFullRetryEnabled) {
-              abortUnrollbackableStages(mapStage)
+              val stagesToRollback = collectSucceedingStages(mapStage)
+              val rollingBackStages = abortStagesUnableToRollback(stagesToRollback)
+              logInfo(log"The shuffle map stage ${MDC(STAGE, mapStage)} with indeterminate output " +
+                log"was failed, we will roll back and rerun below stages which include itself and all " +
+                log"its indeterminate child stages: ${MDC(STAGES, rollingBackStages)}")
             }
 
             // We expect one executor failure to trigger many FetchFailures in rapid succession,
@@ -2342,9 +2453,13 @@ private[spark] class DAGScheduler(
           if (noResubmitEnqueued) {
             logInfo(log"Resubmitting ${MDC(FAILED_STAGE, failedStage)} " +
               log"(${MDC(FAILED_STAGE_NAME, failedStage.name)}) due to barrier stage failure.")
-            messageScheduler.schedule(new Runnable {
-              override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
-            }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS)
+            messageScheduler.schedule(
+              new Runnable {
+                override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
+              },
+              DAGScheduler.RESUBMIT_TIMEOUT,
+              TimeUnit.MILLISECONDS
+            )
           }
         }
       }
@@ -2396,7 +2511,7 @@ private[spark] class DAGScheduler(
    * @param stagesToRollback stages to roll back
    * @return Shuffle map stages which need and can be rolled back
    */
-  private def abortStageWithInvalidRollBack(stagesToRollback: HashSet[Stage]): HashSet[Stage] = {
+  private def abortStagesUnableToRollback(stagesToRollback: HashSet[Stage]): HashSet[Stage] = {
 
     def generateErrorMessage(stage: Stage): String = {
       "A shuffle map stage with indeterminate output was failed and retried. " +
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index 9bf604e9a83c..d8aaea013ee6 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -84,6 +84,14 @@ private[scheduler] abstract class Stage(
    */
   private[scheduler] var maxChecksumMismatchedId: Int = nextAttemptId
 
+  /**
+   * The max attempt id for which we should ignore task results of this stage, indicating that
+   * checksum mismatches have been detected on some ancestor stages. This stage is probably also
+   * indeterminate, so we need to avoid completing the stage and the job with incorrect results
+   * by ignoring the task output from previous attempts, which might have consumed inconsistent data.
+   */
+  private[scheduler] var maxAttemptIdToIgnore: Option[Int] = None
+
   val name: String = callSite.shortForm
   val details: String = callSite.longForm
 
@@ -108,6 +116,14 @@ private[scheduler] abstract class Stage(
     failedAttemptIds.clear()
   }
 
+  /** Mark the latest attempt as rolled back. */
+  private[scheduler] def markAsRollingBack(): Unit = {
+    // Only if the stage has been submitted
+    if (getNextAttemptId > 0) {
+      maxAttemptIdToIgnore = Some(latestInfo.attemptNumber())
+    }
+  }
+
   /** Creates a new attempt for this stage by creating a new StageInfo with a new attempt ID. */
   def makeNewStageAttempt(
       numPartitionsToCompute: Int,
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 6ec0ea320eaa..3751f9b9aa3c 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -3421,11 +3421,12 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
       stageId: Int,
       shuffleId: Int,
       numTasks: Int = 2,
-      checksumVal: Long = 0): Unit = {
+      checksumVal: Long = 0,
+      stageAttemptId: Int = 1): Unit = {
     assert(taskSets(taskSetIndex).stageId == stageId)
-    assert(taskSets(taskSetIndex).stageAttemptId == 1)
+    assert(taskSets(taskSetIndex).stageAttemptId == stageAttemptId)
     assert(taskSets(taskSetIndex).tasks.length == numTasks)
-    completeShuffleMapStageSuccessfully(stageId, 1, 2, checksumVal = checksumVal)
+    completeShuffleMapStageSuccessfully(stageId, stageAttemptId, 2, checksumVal = checksumVal)
     assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq.empty))
   }
 
@@ -3835,6 +3836,87 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
     }
   }
 
+  test("SPARK-54556: ensure rollback of all the succeeding stages and ignore stale task results " +
+    "when shuffle checksum mismatch detected") {
+    /**
+     * Construct the following RDD graph:
+     *
+     *    ShuffleMapRdd1 (Indeterminate)
+     *       /         \
+     *  ShuffleMapRdd2  \
+     *       /          |
+     *  ShuffleMapRdd3  |
+     *       \          |
+     *        FinalRdd
+     *
+     * While executing the result stage, a shuffle fetch failed on shuffle1, leading to an
+     * executor loss and the loss of some map output of shuffle2.
+     * Both stage 0 and stage 2 will be submitted.
+     * Checksum mismatch is detected when retrying stage 0.
+     * The retried task of stage 2 completes but its result should be ignored.
+     */
+    val shuffleMapRdd1 = new MyRDD(sc, 2, Nil)
+    val shuffleDep1 = new ShuffleDependency(
+      shuffleMapRdd1,
+      new HashPartitioner(2),
+      checksumMismatchFullRetryEnabled = true)
+    val shuffleId1 = shuffleDep1.shuffleId
+
+    val shuffleMapRdd2 = new MyRDD(sc, 2, List(shuffleDep1), tracker = mapOutputTracker)
+    val shuffleDep2 = new ShuffleDependency(
+      shuffleMapRdd2,
+      new HashPartitioner(2),
+      checksumMismatchFullRetryEnabled = true)
+    val shuffleId2 = shuffleDep2.shuffleId
+
+    val shuffleMapRdd3 = new MyRDD(sc, 2, List(shuffleDep2), tracker = mapOutputTracker)
+    val shuffleDep3 = new ShuffleDependency(
+      shuffleMapRdd3,
+      new HashPartitioner(2),
+      checksumMismatchFullRetryEnabled = true)
+    val shuffleId3 = shuffleDep3.shuffleId
+
+    val finalRdd = new MyRDD(sc, 2, List(shuffleDep1, shuffleDep3), tracker = mapOutputTracker)
+
+    // Submit the job and complete the shuffle map stages.
+    submit(finalRdd, Array(0, 1))
+    completeShuffleMapStageSuccessfully(
+      0, 0, 2, Seq("hostA", "hostB"), checksumVal = 100)
+    completeShuffleMapStageSuccessfully(
+      1, 0, 2, Seq("hostC", "hostD"), checksumVal = 200)
+    completeShuffleMapStageSuccessfully(
+      2, 0, 2, Seq("hostB", "hostC"), checksumVal = 300)
+    assert(mapOutputTracker.findMissingPartitions(shuffleId1) === Some(Seq.empty))
+    assert(mapOutputTracker.findMissingPartitions(shuffleId2) === Some(Seq.empty))
+    assert(mapOutputTracker.findMissingPartitions(shuffleId3) === Some(Seq.empty))
+
+    // The first task of result stage 3 failed with FetchFailed.
+    runEvent(makeCompletionEvent(
+      taskSets(3).tasks(0),
+      FetchFailed(makeBlockManagerId("hostB"), shuffleId1, 0L, 0, 0, "ignored"),
+      null))
+    assert(mapOutputTracker.findMissingPartitions(shuffleId3).nonEmpty)
+
+    // Check status for all failedStages.
+    val failedStages = scheduler.failedStages.toSeq
+    assert(failedStages.map(_.id) === Seq(0, 3))
+    scheduler.resubmitFailedStages()
+    // Check status for runningStages.
+    assert(scheduler.runningStages.map(_.id) === Set(0, 2))
+
+    // Complete the re-attempt of shuffle map stage 0 (shuffleId1) with a different checksum.
+    completeShuffleMapStageSuccessfully(0, 1, 2, checksumVal = 101)
+    completeShuffleMapStageSuccessfully(2, 1, 2, checksumVal = 300)
+    // The result of stage 2 should be ignored.
+    assert(mapOutputTracker.getNumAvailableOutputs(shuffleId3) === 0)
+    scheduler.resubmitFailedStages()
+    assert(scheduler.runningStages.map(_.id) === Set(1))
+
+    checkAndCompleteRetryStage(6, 1, shuffleId2, 2, checksumVal = 201)
+    checkAndCompleteRetryStage(7, 2, shuffleId3, 2, checksumVal = 301, stageAttemptId = 2)
+    completeAndCheckAnswer(taskSets(8), Seq((Success, 11), (Success, 12)), Map(0 -> 11, 1 -> 12))
+  }
+
   test("SPARK-27164: RDD.countApprox on empty RDDs schedules jobs which never complete") {
     val latch = new CountDownLatch(1)
     val jobListener = new SparkListener {

From 78b0d9b956e8c35ba92a54ddb2592feaa39d5203 Mon Sep 17 00:00:00 2001
From: Tengfei Huang
Date: Tue, 2 Dec 2025 09:51:11 +0800
Subject: [PATCH 2/2] Revert unnecessary changes

---
 .../org/apache/spark/scheduler/DAGScheduler.scala | 10 +++-------
 1 file changed, 3 insertions(+), 7 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 292d95aceb12..c95912db4aa9 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -2453,13 +2453,9 @@ private[spark] class DAGScheduler(
           if (noResubmitEnqueued) {
             logInfo(log"Resubmitting ${MDC(FAILED_STAGE, failedStage)} " +
               log"(${MDC(FAILED_STAGE_NAME, failedStage.name)}) due to barrier stage failure.")
-            messageScheduler.schedule(
-              new Runnable {
-                override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
-              },
-              DAGScheduler.RESUBMIT_TIMEOUT,
-              TimeUnit.MILLISECONDS
-            )
+            messageScheduler.schedule(new Runnable {
+              override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
+            }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS)
           }
         }
       }