diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownVariants.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownVariants.java
index ff82e71bfd58..0fe47683bb0b 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownVariants.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownVariants.java
@@ -20,7 +20,7 @@
import org.apache.spark.annotation.Evolving;
/**
- * A mix-in interface for {@link Scan}. Data sources can implement this interface to
+ * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to
* support pushing down variant field access operations to the data source.
*
* When variant columns are accessed with specific field extractions (e.g., variant_get),
@@ -31,16 +31,17 @@
*
* - Optimizer analyzes the query plan and identifies variant field accesses
* - Optimizer calls {@link #pushVariantAccess} with the access information
- * - Data source validates and stores the variant access information
+ * - ScanBuilder validates and stores the variant access information
* - Optimizer retrieves pushed information via {@link #pushedVariantAccess}
- * - Data source uses the information to optimize reading in {@link #readSchema()}
+ * - ScanBuilder builds a {@link Scan} with the variant access information
+ * - Scan uses the information to optimize reading in {@link Scan#readSchema()}
*   and readers
*
*
* @since 4.1.0
*/
@Evolving
-public interface SupportsPushDownVariants extends Scan {
+public interface SupportsPushDownVariants extends ScanBuilder {
/**
* Pushes down variant field access information to the data source.
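To make the contract concrete, here is a minimal connector-side sketch (not part of this patch) of a ScanBuilder opting into the interface; the class name, the `dataSchema` parameter, and the anonymous Scan are hypothetical, and only the variant-related surface is shown.

```scala
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownVariants, VariantAccessInfo}
import org.apache.spark.sql.types.StructType

// Hypothetical connector ScanBuilder; `dataSchema` is the table's full read schema.
class MyVariantScanBuilder(dataSchema: StructType)
  extends ScanBuilder with SupportsPushDownVariants {

  private var pushed: Array[VariantAccessInfo] = Array.empty

  override def pushVariantAccess(variantAccessInfo: Array[VariantAccessInfo]): Boolean = {
    // Accept the requested accesses; returning false leaves the plan untouched.
    if (variantAccessInfo.isEmpty) return false
    pushed = variantAccessInfo
    true
  }

  override def pushedVariantAccess(): Array[VariantAccessInfo] = pushed

  override def build(): Scan = new Scan {
    // The built Scan must reflect what was pushed: each pushed variant column is read
    // as the struct of requested paths instead of the full variant value.
    override def readSchema(): StructType = {
      val extracted = pushed.map(i => i.columnName() -> i.extractedSchema()).toMap
      StructType(dataSchema.map { f =>
        extracted.get(f.name).map(s => f.copy(dataType = s)).getOrElse(f)
      })
    }
  }
}
```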
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala
index 2cf1a5e9b8cd..7ac50d738f04 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala
@@ -26,10 +26,8 @@ import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project, Subquery}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns
-import org.apache.spark.sql.connector.read.{SupportsPushDownVariants, VariantAccessInfo}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
-import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -281,11 +279,6 @@ object PushVariantIntoScan extends Rule[LogicalPlan] {
relation @ LogicalRelationWithTable(
hadoopFsRelation@HadoopFsRelation(_, _, _, _, _: ParquetFileFormat, _), _)) =>
rewritePlan(p, projectList, filters, relation, hadoopFsRelation)
-
- case p@PhysicalOperation(projectList, filters,
- scanRelation @ DataSourceV2ScanRelation(
- relation, scan: SupportsPushDownVariants, output, _, _)) =>
- rewritePlanV2(p, projectList, filters, scanRelation, scan)
}
}
@@ -333,112 +326,10 @@ object PushVariantIntoScan extends Rule[LogicalPlan] {
hadoopFsRelation.sparkSession)
val newRelation = relation.copy(relation = newHadoopFsRelation, output = newOutput.toIndexedSeq)
- buildFilterAndProject(newRelation, projectList, filters, variants, attributeMap)
- }
-
- // DataSource V2 rewrite method using SupportsPushDownVariants API
- // Key differences from V1 implementation:
- // 1. V2 uses DataSourceV2ScanRelation instead of LogicalRelation
- // 2. Uses SupportsPushDownVariants API instead of directly manipulating scan
- // 3. Schema is already resolved in scanRelation.output (no need for relation.resolve())
- // 4. Scan rebuilding is handled by the scan implementation via the API
- // Data sources like Delta and Iceberg can implement this API to support variant pushdown.
- private def rewritePlanV2(
- originalPlan: LogicalPlan,
- projectList: Seq[NamedExpression],
- filters: Seq[Expression],
- scanRelation: DataSourceV2ScanRelation,
- scan: SupportsPushDownVariants): LogicalPlan = {
- val variants = new VariantInRelation
-
- // Extract schema attributes from V2 scan relation
- val schemaAttributes = scanRelation.output
-
- // Construct schema for default value resolution
- val structSchema = StructType(schemaAttributes.map(a =>
- StructField(a.name, a.dataType, a.nullable, a.metadata)))
-
- val defaultValues = ResolveDefaultColumns.existenceDefaultValues(structSchema)
-
- // Add variant fields from the V2 scan schema
- for ((a, defaultValue) <- schemaAttributes.zip(defaultValues)) {
- variants.addVariantFields(a.exprId, a.dataType, defaultValue, Nil)
- }
- if (variants.mapping.isEmpty) return originalPlan
-
- // Collect requested fields from project list and filters
- projectList.foreach(variants.collectRequestedFields)
- filters.foreach(variants.collectRequestedFields)
-
- // If no variant columns remain after collection, return original plan
- if (variants.mapping.forall(_._2.isEmpty)) return originalPlan
-
- // Build VariantAccessInfo array for the API
- val variantAccessInfoArray = schemaAttributes.flatMap { attr =>
- variants.mapping.get(attr.exprId).flatMap(_.get(Nil)).map { fields =>
- // Build extracted schema for this variant column
- val extractedFields = fields.toArray.sortBy(_._2).map { case (field, ordinal) =>
- StructField(ordinal.toString, field.targetType, metadata = field.path.toMetadata)
- }
- val extractedSchema = if (extractedFields.isEmpty) {
- // Add placeholder field to avoid empty struct
- val placeholder = VariantMetadata("$.__placeholder_field__",
- failOnError = false, timeZoneId = "UTC")
- StructType(Array(StructField("0", BooleanType, metadata = placeholder.toMetadata)))
- } else {
- StructType(extractedFields)
- }
- new VariantAccessInfo(attr.name, extractedSchema)
- }
- }.toArray
-
- // Call the API to push down variant access
- if (variantAccessInfoArray.isEmpty) return originalPlan
-
- val pushed = scan.pushVariantAccess(variantAccessInfoArray)
- if (!pushed) return originalPlan
-
- // Get what was actually pushed
- val pushedVariantAccess = scan.pushedVariantAccess()
- if (pushedVariantAccess.isEmpty) return originalPlan
-
- // Build new attribute mapping based on pushed variant access
- val pushedColumnNames = pushedVariantAccess.map(_.columnName()).toSet
- val attributeMap = schemaAttributes.map { a =>
- if (pushedColumnNames.contains(a.name) && variants.mapping.get(a.exprId).exists(_.nonEmpty)) {
- val newType = variants.rewriteType(a.exprId, a.dataType, Nil)
- val newAttr = AttributeReference(a.name, newType, a.nullable, a.metadata)(
- qualifier = a.qualifier)
- (a.exprId, newAttr)
- } else {
- (a.exprId, a)
- }
- }.toMap
-
- val newOutput = scanRelation.output.map(a => attributeMap.getOrElse(a.exprId, a))
-
- // The scan implementation should have updated its readSchema() based on the pushed info
- // We just need to create a new scan relation with the updated output
- val newScanRelation = scanRelation.copy(
- output = newOutput
- )
-
- buildFilterAndProject(newScanRelation, projectList, filters, variants, attributeMap)
- }
-
- /**
- * Build the final Project(Filter(relation)) plan with rewritten expressions.
- */
- private def buildFilterAndProject(
- relation: LogicalPlan,
- projectList: Seq[NamedExpression],
- filters: Seq[Expression],
- variants: VariantInRelation,
- attributeMap: Map[ExprId, AttributeReference]): LogicalPlan = {
val withFilter = if (filters.nonEmpty) {
- Filter(filters.map(variants.rewriteExpr(_, attributeMap)).reduce(And), relation)
+ Filter(filters.map(variants.rewriteExpr(_, attributeMap)).reduce(And), newRelation)
} else {
- relation
+ newRelation
}
val newProjectList = projectList.map { e =>
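As an end-to-end illustration of what this rule (and its V2 counterpart in V2ScanRelationPushDown) targets, a sketch assuming a SparkSession `spark` and a Parquet table `t` with a variant column `v`; the table and column names are hypothetical.

```scala
import org.apache.spark.sql.internal.SQLConf

// Illustrative only; `spark` and table `t` (with a VARIANT column `v`) are assumed to exist.
spark.conf.set(SQLConf.PUSH_VARIANT_INTO_SCAN.key, "true")
val df = spark.sql(
  """SELECT variant_get(v, '$.a', 'int') AS a
    |FROM t
    |WHERE variant_get(v, '$.b', 'string') = 'x'""".stripMargin)
// Once this rule fires, the scan should read `v` as a struct of the requested paths and
// the variant_get calls become struct field accesses inside the Project/Filter built above.
df.explain(true)
```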
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index 31a98e1ff96c..963551e01548 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -22,7 +22,7 @@ import java.util.Locale
import scala.collection.mutable
import org.apache.spark.internal.LogKeys.{AGGREGATE_FUNCTIONS, COLUMN_NAMES, GROUP_BY_EXPRS, JOIN_CONDITION, JOIN_TYPE, POST_SCAN_FILTERS, PUSHED_FILTERS, RELATION_NAME, RELATION_OUTPUT}
-import org.apache.spark.sql.catalyst.expressions.{aggregate, Alias, And, Attribute, AttributeMap, AttributeReference, AttributeSet, Cast, Expression, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions.{aggregate, Alias, And, Attribute, AttributeMap, AttributeReference, AttributeSet, Cast, Expression, ExprId, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.optimizer.CollapseProject
import org.apache.spark.sql.catalyst.planning.{PhysicalOperation, ScanOperation}
@@ -32,10 +32,10 @@ import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.connector.expressions.{SortOrder => V2SortOrder}
import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Avg, Count, CountStar, Max, Min, Sum}
import org.apache.spark.sql.connector.expressions.filter.Predicate
-import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownJoin, V1Scan}
-import org.apache.spark.sql.execution.datasources.DataSourceStrategy
+import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownJoin, SupportsPushDownVariants, V1Scan, VariantAccessInfo}
+import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, VariantInRelation, VariantMetadata}
import org.apache.spark.sql.sources
-import org.apache.spark.sql.types.{DataType, DecimalType, IntegerType, StructType}
+import org.apache.spark.sql.types.{BooleanType, DataType, DecimalType, IntegerType, StructField, StructType}
import org.apache.spark.sql.util.SchemaUtils._
import org.apache.spark.util.ArrayImplicits._
@@ -49,9 +49,11 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
pushDownFilters,
pushDownJoin,
pushDownAggregates,
+ pushDownVariants,
pushDownLimitAndOffset,
buildScanWithPushedAggregate,
buildScanWithPushedJoin,
+ buildScanWithPushedVariants,
pruneColumns)
pushdownRules.foldLeft(plan) { (newPlan, pushDownRule) =>
@@ -318,6 +320,97 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
case agg: Aggregate => rewriteAggregate(agg)
}
+ def pushDownVariants(plan: LogicalPlan): LogicalPlan = plan.transformDown {
+ case p@PhysicalOperation(projectList, filters, sHolder @ ScanBuilderHolder(_, _,
+ builder: SupportsPushDownVariants))
+ if conf.getConf(org.apache.spark.sql.internal.SQLConf.PUSH_VARIANT_INTO_SCAN) =>
+ pushVariantAccess(p, projectList, filters, sHolder, builder)
+ }
+
+ private def pushVariantAccess(
+ originalPlan: LogicalPlan,
+ projectList: Seq[NamedExpression],
+ filters: Seq[Expression],
+ sHolder: ScanBuilderHolder,
+ builder: SupportsPushDownVariants): LogicalPlan = {
+ val variants = new VariantInRelation
+
+ // Extract schema attributes from scan builder holder
+ val schemaAttributes = sHolder.output
+
+ // Construct schema for default value resolution
+ val structSchema = StructType(schemaAttributes.map(a =>
+ StructField(a.name, a.dataType, a.nullable, a.metadata)))
+
+ val defaultValues = org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.
+ existenceDefaultValues(structSchema)
+
+ // Add variant fields from the V2 scan schema
+ for ((a, defaultValue) <- schemaAttributes.zip(defaultValues)) {
+ variants.addVariantFields(a.exprId, a.dataType, defaultValue, Nil)
+ }
+ if (variants.mapping.isEmpty) return originalPlan
+
+ // Collect requested fields from project list and filters
+ projectList.foreach(variants.collectRequestedFields)
+ filters.foreach(variants.collectRequestedFields)
+
+ // If no variant columns remain after collection, return original plan
+ if (variants.mapping.forall(_._2.isEmpty)) return originalPlan
+
+ // Build VariantAccessInfo array for the API
+ val variantAccessInfoArray = schemaAttributes.flatMap { attr =>
+ variants.mapping.get(attr.exprId).flatMap(_.get(Nil)).map { fields =>
+ // Build extracted schema for this variant column
+ val extractedFields = fields.toArray.sortBy(_._2).map { case (field, ordinal) =>
+ StructField(ordinal.toString, field.targetType, metadata = field.path.toMetadata)
+ }
+ val extractedSchema = if (extractedFields.isEmpty) {
+ // Add placeholder field to avoid empty struct
+ val placeholder = VariantMetadata("$.__placeholder_field__",
+ failOnError = false, timeZoneId = "UTC")
+ StructType(Array(StructField("0", BooleanType, metadata = placeholder.toMetadata)))
+ } else {
+ StructType(extractedFields)
+ }
+ new VariantAccessInfo(attr.name, extractedSchema)
+ }
+ }.toArray
+
+ // Call the API to push down variant access
+ if (variantAccessInfoArray.isEmpty) return originalPlan
+
+ val pushed = builder.pushVariantAccess(variantAccessInfoArray)
+ if (!pushed) return originalPlan
+
+ // Get what was actually pushed
+ val pushedVariantAccess = builder.pushedVariantAccess()
+ if (pushedVariantAccess.isEmpty) return originalPlan
+
+ // Build new attribute mapping based on pushed variant access
+ val pushedColumnNames = pushedVariantAccess.map(_.columnName()).toSet
+ val attributeMap = schemaAttributes.map { a =>
+ if (pushedColumnNames.contains(a.name) && variants.mapping.get(a.exprId).exists(_.nonEmpty)) {
+ val newType = variants.rewriteType(a.exprId, a.dataType, Nil)
+ val newAttr = AttributeReference(a.name, newType, a.nullable, a.metadata)(
+ qualifier = a.qualifier)
+ (a.exprId, newAttr)
+ } else {
+ (a.exprId, a.asInstanceOf[AttributeReference])
+ }
+ }.toMap
+
+ val newOutput = sHolder.output.map(a => attributeMap.getOrElse(a.exprId, a))
+
+ // Store the transformation info on the holder for later use
+ sHolder.pushedVariants = Some(variants)
+ sHolder.pushedVariantAttributeMap = attributeMap
+ sHolder.output = newOutput
+
+ // Return the original plan unchanged - transformation happens in buildScanWithPushedVariants
+ originalPlan
+ }
+
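To make the construction above concrete, here is a standalone sketch of the VariantAccessInfo built when the query requests a single path `$.a` as int from a variant column `v`. The `failOnError`/`timeZoneId` values actually come from the original variant_get expression and the session time zone, so the ones below are placeholders.

```scala
import org.apache.spark.sql.connector.read.VariantAccessInfo
import org.apache.spark.sql.execution.datasources.VariantMetadata
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

// Field "0" is the ordinal assigned above; the requested path and cast options travel in
// the field metadata via VariantMetadata, so the reader knows which extraction to evaluate.
val pathMeta = VariantMetadata("$.a", failOnError = true, timeZoneId = "UTC")
val extractedSchema = StructType(Array(
  StructField("0", IntegerType, metadata = pathMeta.toMetadata)))
val info = new VariantAccessInfo("v", extractedSchema)
```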
private def rewriteAggregate(agg: Aggregate): LogicalPlan = agg.child match {
case PhysicalOperation(project, Nil, holder @ ScanBuilderHolder(_, _,
r: SupportsPushDownAggregates)) if CollapseProject.canCollapseExpressions(
@@ -589,6 +682,48 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
Project(projectList, scanRelation)
}
+ def buildScanWithPushedVariants(plan: LogicalPlan): LogicalPlan = plan.transform {
+ case p@PhysicalOperation(projectList, filters, holder: ScanBuilderHolder)
+ if holder.pushedVariants.isDefined =>
+ val variants = holder.pushedVariants.get
+ val attributeMap = holder.pushedVariantAttributeMap
+
+ // Build the scan
+ val scan = holder.builder.build()
+ val realOutput = toAttributes(scan.readSchema())
+ val wrappedScan = getWrappedScan(scan, holder)
+ val scanRelation = DataSourceV2ScanRelation(holder.relation, wrappedScan, realOutput)
+
+ // Create projection to map real output to expected output (with transformed types)
+ val outputProjection = realOutput.zip(holder.output).map { case (realAttr, expectedAttr) =>
+ Alias(realAttr, expectedAttr.name)(expectedAttr.exprId)
+ }
+
+ // Rewrite filter expressions using the variant transformation
+ val rewrittenFilters = if (filters.nonEmpty) {
+ val rewrittenFilterExprs = filters.map(variants.rewriteExpr(_, attributeMap))
+ Some(rewrittenFilterExprs.reduce(And))
+ } else {
+ None
+ }
+
+ // Rewrite project list expressions using the variant transformation
+ val rewrittenProjectList = projectList.map { e =>
+ val rewritten = variants.rewriteExpr(e, attributeMap)
+ rewritten match {
+ case n: NamedExpression => n
+ // When the variant column is directly selected, we replace the attribute
+ // reference with a struct access, which is not a NamedExpression. Wrap it with Alias.
+ case _ => Alias(rewritten, e.name)(e.exprId, e.qualifier)
+ }
+ }
+
+ // Build the plan: Project(outputProjection) -> [Filter?] -> scanRelation
+ val withProjection = Project(outputProjection, scanRelation)
+ val withFilter = rewrittenFilters.map(Filter(_, withProjection)).getOrElse(withProjection)
+ Project(rewrittenProjectList, withFilter)
+ }
+
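For reviewers following the plan construction, a schematic sketch (not actual explain output) of the shape this method produces for a query like `SELECT variant_get(v, '$.a', 'int') FROM t WHERE id > 0`; attribute names are illustrative:

```
Project [v_struct.`0` AS variant_get(v, '$.a', 'int')]             <- rewrittenProjectList
+- Filter (id > 0)                                                 <- rewrittenFilters
   +- Project [id_real AS id, v_real AS v_struct]                  <- outputProjection (re-keys exprIds)
      +- DataSourceV2ScanRelation [id_real, v_real: struct<0:int>] <- scan.readSchema() after pushdown
```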
def pruneColumns(plan: LogicalPlan): LogicalPlan = plan.transform {
case ScanOperation(project, filtersStayUp, filtersPushDown, sHolder: ScanBuilderHolder) =>
// column pruning
@@ -834,6 +969,10 @@ case class ScanBuilderHolder(
var joinedRelationsPushedDownOperators: Seq[PushedDownOperators] = Seq.empty[PushedDownOperators]
var pushedJoinOutputMap: AttributeMap[Expression] = AttributeMap.empty[Expression]
+
+ var pushedVariantAttributeMap: Map[ExprId, AttributeReference] = Map.empty
+
+ var pushedVariants: Option[VariantInRelation] = None
}
// A wrapper for v1 scan to carry the translated filters and the handled ones, along with
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala
index d347cb04f0bc..68c6cb1beeb1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala
@@ -25,7 +25,7 @@ import org.apache.parquet.hadoop.ParquetInputFormat
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
-import org.apache.spark.sql.connector.read.{PartitionReaderFactory, SupportsPushDownVariants, VariantAccessInfo}
+import org.apache.spark.sql.connector.read.{PartitionReaderFactory, VariantAccessInfo}
import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, PartitioningAwareFileIndex}
import org.apache.spark.sql.execution.datasources.parquet.{ParquetOptions, ParquetReadSupport, ParquetWriteSupport}
import org.apache.spark.sql.execution.datasources.v2.FileScan
@@ -48,8 +48,7 @@ case class ParquetScan(
pushedAggregate: Option[Aggregation] = None,
partitionFilters: Seq[Expression] = Seq.empty,
dataFilters: Seq[Expression] = Seq.empty,
- pushedVariantAccessInfo: Array[VariantAccessInfo] = Array.empty) extends FileScan
- with SupportsPushDownVariants {
+ pushedVariantAccessInfo: Array[VariantAccessInfo] = Array.empty) extends FileScan {
override def isSplitable(path: Path): Boolean = {
// If aggregate is pushed down, only the file footer will be read once,
// so file should not be split across multiple tasks.
@@ -58,11 +57,11 @@ case class ParquetScan(
// Build transformed schema if variant pushdown is active
private def effectiveReadDataSchema: StructType = {
- if (_pushedVariantAccess.isEmpty) {
+ if (pushedVariantAccessInfo.isEmpty) {
readDataSchema
} else {
// Build a mapping from column name to extracted schema
- val variantSchemaMap = _pushedVariantAccess.map(info =>
+ val variantSchemaMap = pushedVariantAccessInfo.map(info =>
info.columnName() -> info.extractedSchema()).toMap
// Transform the read data schema by replacing variant columns with their extracted schemas
@@ -84,10 +83,10 @@ case class ParquetScan(
// super.readSchema() combines readDataSchema + readPartitionSchema
// Apply variant transformation if variant pushdown is active
val baseSchema = super.readSchema()
- if (_pushedVariantAccess.isEmpty) {
+ if (pushedVariantAccessInfo.isEmpty) {
baseSchema
} else {
- val variantSchemaMap = _pushedVariantAccess.map(info =>
+ val variantSchemaMap = pushedVariantAccessInfo.map(info =>
info.columnName() -> info.extractedSchema()).toMap
StructType(baseSchema.map { field =>
variantSchemaMap.get(field.name) match {
@@ -99,23 +98,6 @@ case class ParquetScan(
}
}
- // SupportsPushDownVariants API implementation
- private var _pushedVariantAccess: Array[VariantAccessInfo] = pushedVariantAccessInfo
-
- override def pushVariantAccess(variantAccessInfo: Array[VariantAccessInfo]): Boolean = {
- // Parquet supports variant pushdown for all variant accesses
- if (variantAccessInfo.nonEmpty) {
- _pushedVariantAccess = variantAccessInfo
- true
- } else {
- false
- }
- }
-
- override def pushedVariantAccess(): Array[VariantAccessInfo] = {
- _pushedVariantAccess
- }
-
override def createReaderFactory(): PartitionReaderFactory = {
val effectiveSchema = effectiveReadDataSchema
val readDataSchemaAsJson = effectiveSchema.json
@@ -171,8 +153,8 @@ case class ParquetScan(
pushedAggregate.isEmpty && p.pushedAggregate.isEmpty
}
val pushedVariantEqual =
- java.util.Arrays.equals(_pushedVariantAccess.asInstanceOf[Array[Object]],
- p._pushedVariantAccess.asInstanceOf[Array[Object]])
+ java.util.Arrays.equals(pushedVariantAccessInfo.asInstanceOf[Array[Object]],
+ p.pushedVariantAccessInfo.asInstanceOf[Array[Object]])
super.equals(p) && dataSchema == p.dataSchema && options == p.options &&
equivalentFilters(pushedFilters, p.pushedFilters) && pushedDownAggEqual &&
pushedVariantEqual
@@ -189,8 +171,8 @@ case class ParquetScan(
}
override def getMetaData(): Map[String, String] = {
- val variantAccessStr = if (_pushedVariantAccess.nonEmpty) {
- _pushedVariantAccess.map(info =>
+ val variantAccessStr = if (pushedVariantAccessInfo.nonEmpty) {
+ pushedVariantAccessInfo.map(info =>
s"${info.columnName()}->${info.extractedSchema()}").mkString("[", ", ", "]")
} else {
"[]"
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala
index 01367675e65b..a8f993f20923 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala
@@ -22,7 +22,7 @@ import scala.jdk.CollectionConverters._
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec
import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
-import org.apache.spark.sql.connector.read.SupportsPushDownAggregates
+import org.apache.spark.sql.connector.read.{SupportsPushDownAggregates, SupportsPushDownVariants, VariantAccessInfo}
import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, PartitioningAwareFileIndex}
import org.apache.spark.sql.execution.datasources.parquet.{ParquetFilters, SparkToParquetSchemaConverter}
import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder
@@ -39,7 +39,8 @@ case class ParquetScanBuilder(
dataSchema: StructType,
options: CaseInsensitiveStringMap)
extends FileScanBuilder(sparkSession, fileIndex, dataSchema)
- with SupportsPushDownAggregates {
+ with SupportsPushDownAggregates
+ with SupportsPushDownVariants {
lazy val hadoopConf = {
val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap
// Hadoop Configurations are case sensitive.
@@ -50,6 +51,8 @@ case class ParquetScanBuilder(
private var pushedAggregations = Option.empty[Aggregation]
+ private var pushedVariantAccessInfo = Array.empty[VariantAccessInfo]
+
override protected val supportsNestedSchemaPruning: Boolean = true
override def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = {
@@ -99,6 +102,21 @@ case class ParquetScanBuilder(
}
}
+ // SupportsPushDownVariants API implementation
+ override def pushVariantAccess(variantAccessInfo: Array[VariantAccessInfo]): Boolean = {
+ // Parquet supports variant pushdown for all variant accesses
+ if (variantAccessInfo.nonEmpty) {
+ pushedVariantAccessInfo = variantAccessInfo
+ true
+ } else {
+ false
+ }
+ }
+
+ override def pushedVariantAccess(): Array[VariantAccessInfo] = {
+ pushedVariantAccessInfo
+ }
+
override def build(): ParquetScan = {
// the `finalSchema` is either pruned in pushAggregation (if aggregates are
// pushed down), or pruned in readDataSchema() (in regular column pruning). These
@@ -108,6 +126,6 @@ case class ParquetScanBuilder(
}
ParquetScan(sparkSession, hadoopConf, fileIndex, dataSchema, finalSchema,
readPartitionSchema(), pushedDataFilters, options, pushedAggregations,
- partitionFilters, dataFilters)
+ partitionFilters, dataFilters, pushedVariantAccessInfo)
}
}
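Finally, a sketch of the optimizer-side handshake that V2ScanRelationPushDown performs against any builder mixing in SupportsPushDownVariants (ParquetScanBuilder above is one such builder); the helper name is illustrative and not part of this patch.

```scala
import org.apache.spark.sql.connector.read.{ScanBuilder, SupportsPushDownVariants, VariantAccessInfo}

// Returns the accesses the builder actually accepted; an empty result means the plan
// must keep evaluating variant_get against the full variant value.
def tryPushVariantAccess(
    builder: ScanBuilder,
    requested: Array[VariantAccessInfo]): Array[VariantAccessInfo] = builder match {
  case b: SupportsPushDownVariants if requested.nonEmpty && b.pushVariantAccess(requested) =>
    b.pushedVariantAccess()
  case _ => Array.empty
}
```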