diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala index ac15456f0c3d4..70ae8068a03a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala @@ -84,6 +84,8 @@ case class ParquetPartitionReaderFactory( private val datetimeRebaseModeInRead = options.datetimeRebaseModeInRead private val int96RebaseModeInRead = options.int96RebaseModeInRead + private val parquetReaderCallback = new ParquetReaderCallback() + private def getFooter(file: PartitionedFile): ParquetMetadata = { val conf = broadcastedConf.value.value if (aggregation.isDefined || enableVectorizedReader) { @@ -309,7 +311,8 @@ case class ParquetPartitionReaderFactory( reader, readDataSchema) val iter = new RecordReaderIterator(readerWithRowIndexes) // SPARK-23457 Register a task completion listener before `initialization`. - taskContext.foreach(_.addTaskCompletionListener[Unit](_ => iter.close())) + parquetReaderCallback.advanceFile(iter) + taskContext.foreach(parquetReaderCallback.initIfNotAlready) readerWithRowIndexes } @@ -337,8 +340,39 @@ case class ParquetPartitionReaderFactory( capacity) val iter = new RecordReaderIterator(vectorizedReader) // SPARK-23457 Register a task completion listener before `initialization`. - taskContext.foreach(_.addTaskCompletionListener[Unit](_ => iter.close())) + parquetReaderCallback.advanceFile(iter) + taskContext.foreach(parquetReaderCallback.initIfNotAlready) logDebug(s"Appending $partitionSchema $partitionValues") vectorizedReader } } + +/** + * A callback class to handle the cleanup of Parquet readers. 
+ *
+ * Ensures that the Parquet reader for the current file is closed when the task completes,
+ * while registering the task-completion listener itself only once per task across multiple files.
+ */
+private class ParquetReaderCallback extends Serializable {
+  private var init: Boolean = false
+  private var iter: RecordReaderIterator[_] = null
+
+  def initIfNotAlready(taskContext: TaskContext): Unit = {
+    if (!init) {
+      taskContext.addTaskCompletionListener[Unit](_ => closeCurrent())
+      init = true
+    }
+  }
+
+  def advanceFile(iter: RecordReaderIterator[_]): Unit = {
+    closeCurrent()
+
+    this.iter = iter
+  }
+
+  def closeCurrent(): Unit = {
+    if (iter != null) {
+      iter.close()
+    }
+  }
+}