Skip to content

Commit e39f13b

Browse files
authored
Create LakeFS commits RDD directly without using an input format (#9657)
* Create LakeFS commits RDD directly without using an input format Garbage collection (and all other uses) use LakeFSContext.newRDD to create the "ranges RDD". Creating them explicitly with Spark operators means Spark can parallelize reading all metaranges and ranges. <h2>How much faster?</h2> I have a small repo with many small commits. I enabled GC for it. Here are summaries from two sample mark-only runs. <h3>Direct RDD (this code)</h3> Runtime: 2m33s ```json { "run_id": "g4uk6erfnfus73frbnqg", "success": true, "first_slice": "g5adr8f5pvec73cpia80", "start_time": "2025-11-10T10:34:37.245361091Z", "cutoff_time": "2025-11-10T04:34:37.243Z", "num_deleted_objects": 147942 } ``` <h3>File format RDD (previous code)</h3> Runtime: 3m52s ```json { "run_id": "g4uinaarakss73aoeel0", "success": true, "first_slice": "g5adr8f5pvec73cpia80", "start_time": "2025-11-10T12:15:11.097697745Z", "cutoff_time": "2025-11-10T06:15:11.096Z", "num_deleted_objects": 147942 } ``` <h3>Summary</h3> - The same number of objects were marked for deletion. - The _same_ objects were marked for deletion on both. - New code takes 0.65 the time of the old code. * scalafmt D'oh. * [bug] Delete file, close SSTableReader after iterating All errors during closing are _logged_ but do not fail the task: these are readonly objects, so bad closes can do no more than leak (on "reasonable" systems). Flagged by **Copilot**, hurrah for verifiable actionable suggestions! * Also parallelize directory listing Read objects in parallel: - from directory listing; - from commits Default parallelism is no good for either of these, because it is based on # of CPUs - and we want a _lot_ more. New configuration option `lakefs.job.range_read_parallelism` configures this parallelism. * scalafmt
1 parent 8232830 commit e39f13b

File tree

6 files changed

+111
-23
lines changed

6 files changed

+111
-23
lines changed

clients/spark/src/main/scala/io/treeverse/clients/LakeFSContext.scala

Lines changed: 83 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,16 @@ package io.treeverse.clients
22

33
import io.treeverse.lakefs.catalog.Entry
44
import org.apache.commons.lang3.StringUtils
5-
import org.apache.hadoop.conf.Configuration
5+
import org.apache.hadoop.fs.Path
66
import org.apache.hadoop.mapred.InvalidJobConfException
7-
import org.apache.spark.SparkContext
7+
import org.apache.spark.{SparkContext, TaskContext}
88
import org.apache.spark.rdd.RDD
99
import org.apache.spark.sql.types._
1010
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
11+
import org.apache.spark.util.SerializableConfiguration
12+
import org.slf4j.{Logger, LoggerFactory}
1113

14+
import java.io.File
1215
import java.util.concurrent.TimeUnit
1316

1417
object LakeFSJobParams {
@@ -61,6 +64,8 @@ class LakeFSJobParams private (
6164
}
6265

6366
object LakeFSContext {
67+
private val logger: Logger = LoggerFactory.getLogger(getClass.toString)
68+
6469
val LAKEFS_CONF_API_URL_KEY = "lakefs.api.url"
6570
val LAKEFS_CONF_API_ACCESS_KEY_KEY = "lakefs.api.access_key"
6671
val LAKEFS_CONF_API_SECRET_KEY_KEY = "lakefs.api.secret_key"
@@ -71,6 +76,9 @@ object LakeFSContext {
7176
val LAKEFS_CONF_JOB_COMMIT_IDS_KEY = "lakefs.job.commit_ids"
7277
val LAKEFS_CONF_JOB_SOURCE_NAME_KEY = "lakefs.job.source_name"
7378

79+
// Read parallelism. Defaults to default parallelism.
80+
val LAKEFS_CONF_JOB_RANGE_READ_PARALLELISM = "lakefs.job.range_read_parallelism"
81+
7482
val LAKEFS_CONF_GC_NUM_COMMIT_PARTITIONS = "lakefs.gc.commit.num_partitions"
7583
val LAKEFS_CONF_GC_NUM_RANGE_PARTITIONS = "lakefs.gc.range.num_partitions"
7684
val LAKEFS_CONF_GC_NUM_ADDRESS_PARTITIONS = "lakefs.gc.address.num_partitions"
@@ -108,15 +116,14 @@ object LakeFSContext {
108116
val DEFAULT_LAKEFS_CONF_GC_S3_MIN_BACKOFF_SECONDS = 1
109117
val DEFAULT_LAKEFS_CONF_GC_S3_MAX_BACKOFF_SECONDS = 120
110118

119+
val metarangeReaderGetter = SSTableReader.forMetaRange _
120+
111121
def newRDD(
112122
sc: SparkContext,
113123
params: LakeFSJobParams
114124
): RDD[(Array[Byte], WithIdentifier[Entry])] = {
115-
val inputFormatClass =
116-
if (params.commitIDs.nonEmpty) classOf[LakeFSCommitInputFormat]
117-
else classOf[LakeFSAllRangesInputFormat]
125+
val conf = sc.hadoopConfiguration
118126

119-
val conf = new Configuration(sc.hadoopConfiguration)
120127
conf.set(LAKEFS_CONF_JOB_REPO_NAME_KEY, params.repoName)
121128
conf.setStrings(LAKEFS_CONF_JOB_COMMIT_IDS_KEY, params.commitIDs.toArray: _*)
122129

@@ -131,12 +138,77 @@ object LakeFSContext {
131138
throw new InvalidJobConfException(s"$LAKEFS_CONF_API_SECRET_KEY_KEY must not be empty")
132139
}
133140
conf.set(LAKEFS_CONF_JOB_SOURCE_NAME_KEY, params.sourceName)
134-
sc.newAPIHadoopRDD(
135-
conf,
136-
inputFormatClass,
137-
classOf[Array[Byte]],
138-
classOf[WithIdentifier[Entry]]
141+
142+
val apiConf = APIConfigurations(
143+
conf.get(LAKEFS_CONF_API_URL_KEY),
144+
conf.get(LAKEFS_CONF_API_ACCESS_KEY_KEY),
145+
conf.get(LAKEFS_CONF_API_SECRET_KEY_KEY),
146+
conf.get(LAKEFS_CONF_API_CONNECTION_TIMEOUT_SEC_KEY),
147+
conf.get(LAKEFS_CONF_API_READ_TIMEOUT_SEC_KEY),
148+
conf.get(LAKEFS_CONF_JOB_SOURCE_NAME_KEY, "input_format")
139149
)
150+
val repoName = conf.get(LAKEFS_CONF_JOB_REPO_NAME_KEY)
151+
152+
// This can go to executors.
153+
val serializedConf = new SerializableConfiguration(conf)
154+
155+
val parallelism = conf.getInt(LAKEFS_CONF_JOB_RANGE_READ_PARALLELISM, sc.defaultParallelism)
156+
157+
// ApiClient is not serializable, so create a new one for each partition on its executor.
158+
// (If we called X.flatMap directly, we would fetch the client from the cache for each
159+
// range, which is a bit too much.)
160+
161+
// TODO(ariels): Unify with similar code in LakeFSInputFormat.getSplits
162+
val ranges = sc
163+
.parallelize(params.commitIDs.toSeq, parallelism)
164+
.mapPartitions(commits => {
165+
val apiClient = ApiClient.get(apiConf)
166+
val conf = serializedConf.value
167+
commits.flatMap(commitID => {
168+
val metaRangeURL = apiClient.getMetaRangeURL(repoName, commitID)
169+
if (metaRangeURL == "") {
170+
// a commit with no meta range is an empty commit.
171+
// this only happens for the first commit in the repository.
172+
None
173+
} else {
174+
val rangesReader = metarangeReaderGetter(conf, metaRangeURL, true)
175+
rangesReader
176+
.newIterator()
177+
.map(rd => new Range(new String(rd.id), rd.message.estimatedSize))
178+
}
179+
})
180+
})
181+
.distinct
182+
183+
ranges.mapPartitions(ranges => {
184+
val apiClient = ApiClient.get(apiConf)
185+
val conf = serializedConf.value
186+
ranges.flatMap((range: Range) => {
187+
val path = new Path(apiClient.getRangeURL(repoName, range.id))
188+
val fs = path.getFileSystem(conf)
189+
val localFile = File.createTempFile("lakefs.", ".range")
190+
191+
fs.copyToLocalFile(false, path, new Path(localFile.getAbsolutePath), true)
192+
val companion = Entry.messageCompanion
193+
// localFile owned by sstableReader which will delete it when closed.
194+
val sstableReader = new SSTableReader(localFile.getAbsolutePath, companion, true)
195+
Option(TaskContext.get()).foreach(_.addTaskCompletionListener((tc: TaskContext) => {
196+
try {
197+
sstableReader.close()
198+
} catch {
199+
case e: Exception => {
200+
logger.warn(s"close SSTable reader for $localFile (keep going): $e")
201+
}
202+
}
203+
tc
204+
}))
205+
// TODO(ariels): Do we need to validate that this reader is good? Assume _not_, this is
206+
// not InputFormat code so it should have slightly nicer error reports.
207+
sstableReader
208+
.newIterator()
209+
.map((entry) => (entry.key, new WithIdentifier(entry.id, entry.message, range.id)))
210+
})
211+
})
140212
}
141213

142214
/** Returns all entries in all ranges of the given commit, as an RDD.

clients/spark/src/main/scala/io/treeverse/clients/LakeFSInputFormat.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,10 @@ abstract class LakeFSBaseInputFormat extends InputFormat[Array[Byte], WithIdenti
169169
new EntryRecordReader(Entry.messageCompanion)
170170
}
171171
}
172-
private class Range(val id: String, val estimatedSize: Long) {
172+
173+
class Range(val id: String, val estimatedSize: Long) extends Serializable {
174+
// non-private so Spark will serialize it.
175+
173176
override def hashCode(): Int = {
174177
id.hashCode()
175178
}

clients/spark/src/main/scala/io/treeverse/clients/SSTableReader.scala

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import org.apache.hadoop.fs.Path
99
import org.apache.spark.TaskContext
1010
import scalapb.{GeneratedMessage, GeneratedMessageCompanion}
1111

12+
import org.slf4j.{Logger, LoggerFactory}
1213
import java.io.{ByteArrayInputStream, Closeable, DataInputStream, File}
1314

1415
class Item[T](val key: Array[Byte], val id: Array[Byte], val message: T)
@@ -101,6 +102,8 @@ class SSTableReader[Proto <: GeneratedMessage with scalapb.Message[Proto]] priva
101102
val companion: GeneratedMessageCompanion[Proto],
102103
val own: Boolean = true
103104
) extends Closeable {
105+
private val logger: Logger = LoggerFactory.getLogger(getClass.toString)
106+
104107
private val fp = new java.io.RandomAccessFile(file, "r")
105108
private val reader = new BlockReadableFile(fp)
106109

@@ -110,7 +113,13 @@ class SSTableReader[Proto <: GeneratedMessage with scalapb.Message[Proto]] priva
110113
def close(): Unit = {
111114
fp.close()
112115
if (own) {
113-
file.delete()
116+
try {
117+
file.delete()
118+
} catch {
119+
case e: Exception => {
120+
logger.warn(s"delete owned file ${file.getName()} (keep going): $e")
121+
}
122+
}
114123
}
115124
}
116125

clients/spark/src/main/scala/io/treeverse/gc/DataLister.scala

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@ import scala.collection.mutable.ListBuffer
1111
*/
1212
abstract class DataLister {
1313
@transient lazy val spark: SparkSession = SparkSession.active
14-
def listData(configMapper: ConfigMapper, path: Path): DataFrame
14+
def listData(configMapper: ConfigMapper, path: Path, parallelism: Int): DataFrame
1515
}
1616

1717
class NaiveDataLister extends DataLister {
18-
override def listData(configMapper: ConfigMapper, path: Path): DataFrame = {
18+
override def listData(configMapper: ConfigMapper, path: Path, parallelism: Int): DataFrame = {
1919
import spark.implicits._
2020
val fs = path.getFileSystem(configMapper.configuration)
2121
val dataIt = fs.listFiles(path, false)
@@ -24,7 +24,7 @@ class NaiveDataLister extends DataLister {
2424
val fileStatus = dataIt.next()
2525
dataList += ((fileStatus.getPath.getName, fileStatus.getModificationTime))
2626
}
27-
dataList.toDF("base_address", "last_modified")
27+
dataList.toDF("base_address", "last_modified").repartition(parallelism)
2828
}
2929
}
3030

@@ -47,7 +47,7 @@ class ParallelDataLister extends DataLister with Serializable {
4747
}
4848
}
4949

50-
override def listData(configMapper: ConfigMapper, path: Path): DataFrame = {
50+
override def listData(configMapper: ConfigMapper, path: Path, parallelism: Int): DataFrame = {
5151
import spark.implicits._
5252
val slices = listPath(configMapper, path)
5353
val objectsPath = if (path.toString.endsWith("/")) path.toString else path.toString + "/"
@@ -63,6 +63,7 @@ class ParallelDataLister extends DataLister with Serializable {
6363
.map(_.path)
6464
.toSeq
6565
.toDF("slice_id")
66+
.repartition(parallelism)
6667
.withColumn("udf", explode(objectsUDF(col("slice_id"))))
6768
.withColumn("base_address", concat(col("slice_id"), lit("/"), col("udf._1")))
6869
.withColumn("last_modified", col("udf._2"))

clients/spark/src/main/scala/io/treeverse/gc/GarbageCollection.scala

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import org.apache.commons.lang3.time.DateUtils
66
import org.apache.hadoop.fs.Path
77
import org.apache.spark.sql.functions._
88
import org.apache.spark.sql.{DataFrame, SparkSession}
9+
import org.apache.spark.storage.StorageLevel
910
import org.json4s.JsonDSL._
1011
import org.json4s._
1112
import org.json4s.native.JsonMethods._
@@ -43,6 +44,8 @@ object GarbageCollection {
4344
val sc = spark.sparkContext
4445
val oldDataPath = new Path(storageNamespace)
4546
val dataPath = new Path(storageNamespace, DATA_PREFIX)
47+
val parallelism =
48+
sc.hadoopConfiguration.getInt(LAKEFS_CONF_JOB_RANGE_READ_PARALLELISM, sc.defaultParallelism)
4649

4750
val configMapper = new ConfigMapper(
4851
sc.broadcast(
@@ -54,7 +57,7 @@ object GarbageCollection {
5457
)
5558
)
5659
// Read objects from data path (new repository structure)
57-
var dataDF = new ParallelDataLister().listData(configMapper, dataPath)
60+
var dataDF = new ParallelDataLister().listData(configMapper, dataPath, parallelism)
5861
dataDF = dataDF
5962
.withColumn(
6063
"address",
@@ -65,7 +68,7 @@ object GarbageCollection {
6568

6669
// TODO (niro): implement parallel lister for old repositories (https://github.com/treeverse/lakeFS/issues/4620)
6770
val oldDataDF = new NaiveDataLister()
68-
.listData(configMapper, oldDataPath)
71+
.listData(configMapper, oldDataPath, parallelism)
6972
.withColumn("address", col("base_address"))
7073
.filter(!col("address").isin(excludeFromOldData: _*))
7174
dataDF = dataDF.union(oldDataDF).filter(col("last_modified") < before.getTime)
@@ -195,7 +198,7 @@ object GarbageCollection {
195198
.repartition(dataDF.col("address"))
196199
.except(committedDF)
197200
.except(uncommittedDF)
198-
.cache()
201+
.persist(StorageLevel.MEMORY_AND_DISK)
199202

200203
committedDF.unpersist()
201204
uncommittedDF.unpersist()

clients/spark/src/test/scala/io/treeverse/gc/DataListerSpec.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ class ParallelDataListerSpec
5252
)
5353
)
5454
val df =
55-
new ParallelDataLister().listData(configMapper, path).sort("base_address")
55+
new ParallelDataLister().listData(configMapper, path, 3).sort("base_address")
5656
df.count should be(100)
5757
val slices =
5858
df.select(substring(col("base_address"), 0, 7).as("slice_id"))
@@ -81,7 +81,7 @@ class ParallelDataListerSpec
8181
)
8282
)
8383
val df =
84-
new ParallelDataLister().listData(configMapper, path).sort("base_address")
84+
new ParallelDataLister().listData(configMapper, path, 3).sort("base_address")
8585
df.count() should be(1)
8686
df.head.getString(0) should be(s"$sliceID/$filename")
8787
})
@@ -127,7 +127,7 @@ class NaiveDataListerSpec
127127
HadoopUtils.getHadoopConfigurationValues(spark.sparkContext.hadoopConfiguration)
128128
)
129129
)
130-
val df = new NaiveDataLister().listData(configMapper, path).sort("base_address")
130+
val df = new NaiveDataLister().listData(configMapper, path, 3).sort("base_address")
131131
df.count should be(10)
132132
df.sort("base_address").head.getString(0) should be("object01")
133133
df.head.getString(0) should be("object01")

0 commit comments

Comments
 (0)