From b2683800a5128fc9cf26ef8b9fea2bc680619691 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 2 Dec 2025 01:02:04 -0800 Subject: [PATCH 1/2] Refactor SupportsPushDownVariants to be a ScanBuilder mix-in --- .../read/SupportsPushDownVariants.java | 9 ++-- .../datasources/PushVariantIntoScan.scala | 43 +++++++++---------- .../datasources/v2/parquet/ParquetScan.scala | 38 +++++----------- .../v2/parquet/ParquetScanBuilder.scala | 24 +++++++++-- 4 files changed, 57 insertions(+), 57 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownVariants.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownVariants.java index ff82e71bfd58..0fe47683bb0b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownVariants.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownVariants.java @@ -20,7 +20,7 @@ import org.apache.spark.annotation.Evolving; /** - * A mix-in interface for {@link Scan}. Data sources can implement this interface to + * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to * support pushing down variant field access operations to the data source. *

* When variant columns are accessed with specific field extractions (e.g., variant_get), @@ -31,16 +31,17 @@ *

    *
  1. Optimizer analyzes the query plan and identifies variant field accesses
  2. *
  3. Optimizer calls {@link #pushVariantAccess} with the access information
  4. - *
  5. Data source validates and stores the variant access information
  6. + *
  7. ScanBuilder validates and stores the variant access information
  8. *
  9. Optimizer retrieves pushed information via {@link #pushedVariantAccess}
  10. - *
  11. Data source uses the information to optimize reading in {@link #readSchema()} + *
  12. ScanBuilder builds a {@link Scan} with the variant access information
  13. + *
  14. Scan uses the information to optimize reading in {@link Scan#readSchema()} * and readers
  15. *
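A minimal Scala sketch of how a data source's ScanBuilder might adopt this mix-in, mirroring the ParquetScanBuilder changes further down in this patch. The names MyScanBuilder and the schema-only anonymous Scan are hypothetical and kept deliberately small; a real source would also wire up toBatch() and reader factories.

import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownVariants, VariantAccessInfo}
import org.apache.spark.sql.types.StructType

// After this patch, SupportsPushDownVariants extends ScanBuilder, so implementing the
// mix-in alone makes this class a ScanBuilder.
class MyScanBuilder(fullSchema: StructType) extends SupportsPushDownVariants {

  // Variant accesses pushed by the optimizer (see the workflow above).
  private var pushedVariants: Array[VariantAccessInfo] = Array.empty

  // Validate and store the variant access information.
  override def pushVariantAccess(variantAccessInfo: Array[VariantAccessInfo]): Boolean = {
    if (variantAccessInfo.nonEmpty) {
      pushedVariants = variantAccessInfo
      true
    } else {
      false
    }
  }

  // Report back what was actually pushed.
  override def pushedVariantAccess(): Array[VariantAccessInfo] = pushedVariants

  // Build a Scan whose readSchema() replaces each pushed variant column with its
  // extracted struct schema, as ParquetScan does below. (Schema-only Scan for brevity.)
  override def build(): Scan = new Scan {
    private val extracted =
      pushedVariants.map(info => info.columnName() -> info.extractedSchema()).toMap
    override def readSchema(): StructType = StructType(fullSchema.map { field =>
      extracted.get(field.name).map(t => field.copy(dataType = t)).getOrElse(field)
    })
  }
}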
* * @since 4.1.0 */ @Evolving -public interface SupportsPushDownVariants extends Scan { +public interface SupportsPushDownVariants extends ScanBuilder { /** * Pushes down variant field access information to the data source. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala index 2cf1a5e9b8cd..65122df6798a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns import org.apache.spark.sql.connector.read.{SupportsPushDownVariants, VariantAccessInfo} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation +import org.apache.spark.sql.execution.datasources.v2.ScanBuilderHolder import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -283,9 +283,8 @@ object PushVariantIntoScan extends Rule[LogicalPlan] { rewritePlan(p, projectList, filters, relation, hadoopFsRelation) case p@PhysicalOperation(projectList, filters, - scanRelation @ DataSourceV2ScanRelation( - relation, scan: SupportsPushDownVariants, output, _, _)) => - rewritePlanV2(p, projectList, filters, scanRelation, scan) + sHolder @ ScanBuilderHolder(_, _, builder: SupportsPushDownVariants)) => + rewritePlanV2(p, projectList, filters, sHolder, builder) } } @@ -338,21 +337,22 @@ object PushVariantIntoScan extends Rule[LogicalPlan] { // DataSource V2 rewrite method using SupportsPushDownVariants API // Key differences from V1 implementation: - // 1. V2 uses DataSourceV2ScanRelation instead of LogicalRelation - // 2. Uses SupportsPushDownVariants API instead of directly manipulating scan - // 3. Schema is already resolved in scanRelation.output (no need for relation.resolve()) - // 4. Scan rebuilding is handled by the scan implementation via the API - // Data sources like Delta and Iceberg can implement this API to support variant pushdown. + // 1. V2 uses ScanBuilderHolder instead of LogicalRelation + // 2. Uses SupportsPushDownVariants API on ScanBuilder instead of directly manipulating scan + // 3. Schema is already resolved in sHolder.output (no need for relation.resolve()) + // 4. Scan will be built later with the variant access information + // Data sources like Parquet (V2), Delta and Iceberg can implement this API to support + // variant pushdown. 
private def rewritePlanV2( originalPlan: LogicalPlan, projectList: Seq[NamedExpression], filters: Seq[Expression], - scanRelation: DataSourceV2ScanRelation, - scan: SupportsPushDownVariants): LogicalPlan = { + sHolder: ScanBuilderHolder, + builder: SupportsPushDownVariants): LogicalPlan = { val variants = new VariantInRelation - // Extract schema attributes from V2 scan relation - val schemaAttributes = scanRelation.output + // Extract schema attributes from scan builder holder + val schemaAttributes = sHolder.output // Construct schema for default value resolution val structSchema = StructType(schemaAttributes.map(a => @@ -395,11 +395,11 @@ object PushVariantIntoScan extends Rule[LogicalPlan] { // Call the API to push down variant access if (variantAccessInfoArray.isEmpty) return originalPlan - val pushed = scan.pushVariantAccess(variantAccessInfoArray) + val pushed = builder.pushVariantAccess(variantAccessInfoArray) if (!pushed) return originalPlan // Get what was actually pushed - val pushedVariantAccess = scan.pushedVariantAccess() + val pushedVariantAccess = builder.pushedVariantAccess() if (pushedVariantAccess.isEmpty) return originalPlan // Build new attribute mapping based on pushed variant access @@ -415,15 +415,14 @@ object PushVariantIntoScan extends Rule[LogicalPlan] { } }.toMap - val newOutput = scanRelation.output.map(a => attributeMap.getOrElse(a.exprId, a)) + val newOutput = sHolder.output.map(a => attributeMap.getOrElse(a.exprId, a)) - // The scan implementation should have updated its readSchema() based on the pushed info - // We just need to create a new scan relation with the updated output - val newScanRelation = scanRelation.copy( - output = newOutput - ) + // Update the scan builder holder's output with the new schema + // The scan will be built later with the variant access information already pushed to the + // builder + sHolder.output = newOutput - buildFilterAndProject(newScanRelation, projectList, filters, variants, attributeMap) + buildFilterAndProject(sHolder, projectList, filters, variants, attributeMap) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala index d347cb04f0bc..68c6cb1beeb1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala @@ -25,7 +25,7 @@ import org.apache.parquet.hadoop.ParquetInputFormat import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.connector.expressions.aggregate.Aggregation -import org.apache.spark.sql.connector.read.{PartitionReaderFactory, SupportsPushDownVariants, VariantAccessInfo} +import org.apache.spark.sql.connector.read.{PartitionReaderFactory, VariantAccessInfo} import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, PartitioningAwareFileIndex} import org.apache.spark.sql.execution.datasources.parquet.{ParquetOptions, ParquetReadSupport, ParquetWriteSupport} import org.apache.spark.sql.execution.datasources.v2.FileScan @@ -48,8 +48,7 @@ case class ParquetScan( pushedAggregate: Option[Aggregation] = None, partitionFilters: Seq[Expression] = Seq.empty, dataFilters: Seq[Expression] = Seq.empty, - pushedVariantAccessInfo: Array[VariantAccessInfo] = Array.empty) extends FileScan - with SupportsPushDownVariants { + 
pushedVariantAccessInfo: Array[VariantAccessInfo] = Array.empty) extends FileScan { override def isSplitable(path: Path): Boolean = { // If aggregate is pushed down, only the file footer will be read once, // so file should not be split across multiple tasks. @@ -58,11 +57,11 @@ case class ParquetScan( // Build transformed schema if variant pushdown is active private def effectiveReadDataSchema: StructType = { - if (_pushedVariantAccess.isEmpty) { + if (pushedVariantAccessInfo.isEmpty) { readDataSchema } else { // Build a mapping from column name to extracted schema - val variantSchemaMap = _pushedVariantAccess.map(info => + val variantSchemaMap = pushedVariantAccessInfo.map(info => info.columnName() -> info.extractedSchema()).toMap // Transform the read data schema by replacing variant columns with their extracted schemas @@ -84,10 +83,10 @@ case class ParquetScan( // super.readSchema() combines readDataSchema + readPartitionSchema // Apply variant transformation if variant pushdown is active val baseSchema = super.readSchema() - if (_pushedVariantAccess.isEmpty) { + if (pushedVariantAccessInfo.isEmpty) { baseSchema } else { - val variantSchemaMap = _pushedVariantAccess.map(info => + val variantSchemaMap = pushedVariantAccessInfo.map(info => info.columnName() -> info.extractedSchema()).toMap StructType(baseSchema.map { field => variantSchemaMap.get(field.name) match { @@ -99,23 +98,6 @@ case class ParquetScan( } } - // SupportsPushDownVariants API implementation - private var _pushedVariantAccess: Array[VariantAccessInfo] = pushedVariantAccessInfo - - override def pushVariantAccess(variantAccessInfo: Array[VariantAccessInfo]): Boolean = { - // Parquet supports variant pushdown for all variant accesses - if (variantAccessInfo.nonEmpty) { - _pushedVariantAccess = variantAccessInfo - true - } else { - false - } - } - - override def pushedVariantAccess(): Array[VariantAccessInfo] = { - _pushedVariantAccess - } - override def createReaderFactory(): PartitionReaderFactory = { val effectiveSchema = effectiveReadDataSchema val readDataSchemaAsJson = effectiveSchema.json @@ -171,8 +153,8 @@ case class ParquetScan( pushedAggregate.isEmpty && p.pushedAggregate.isEmpty } val pushedVariantEqual = - java.util.Arrays.equals(_pushedVariantAccess.asInstanceOf[Array[Object]], - p._pushedVariantAccess.asInstanceOf[Array[Object]]) + java.util.Arrays.equals(pushedVariantAccessInfo.asInstanceOf[Array[Object]], + p.pushedVariantAccessInfo.asInstanceOf[Array[Object]]) super.equals(p) && dataSchema == p.dataSchema && options == p.options && equivalentFilters(pushedFilters, p.pushedFilters) && pushedDownAggEqual && pushedVariantEqual @@ -189,8 +171,8 @@ case class ParquetScan( } override def getMetaData(): Map[String, String] = { - val variantAccessStr = if (_pushedVariantAccess.nonEmpty) { - _pushedVariantAccess.map(info => + val variantAccessStr = if (pushedVariantAccessInfo.nonEmpty) { + pushedVariantAccessInfo.map(info => s"${info.columnName()}->${info.extractedSchema()}").mkString("[", ", ", "]") } else { "[]" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala index 01367675e65b..a8f993f20923 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala @@ -22,7 +22,7 @@ import 
scala.jdk.CollectionConverters._ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec import org.apache.spark.sql.connector.expressions.aggregate.Aggregation -import org.apache.spark.sql.connector.read.SupportsPushDownAggregates +import org.apache.spark.sql.connector.read.{SupportsPushDownAggregates, SupportsPushDownVariants, VariantAccessInfo} import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, PartitioningAwareFileIndex} import org.apache.spark.sql.execution.datasources.parquet.{ParquetFilters, SparkToParquetSchemaConverter} import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder @@ -39,7 +39,8 @@ case class ParquetScanBuilder( dataSchema: StructType, options: CaseInsensitiveStringMap) extends FileScanBuilder(sparkSession, fileIndex, dataSchema) - with SupportsPushDownAggregates { + with SupportsPushDownAggregates + with SupportsPushDownVariants { lazy val hadoopConf = { val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap // Hadoop Configurations are case sensitive. @@ -50,6 +51,8 @@ case class ParquetScanBuilder( private var pushedAggregations = Option.empty[Aggregation] + private var pushedVariantAccessInfo = Array.empty[VariantAccessInfo] + override protected val supportsNestedSchemaPruning: Boolean = true override def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = { @@ -99,6 +102,21 @@ case class ParquetScanBuilder( } } + // SupportsPushDownVariants API implementation + override def pushVariantAccess(variantAccessInfo: Array[VariantAccessInfo]): Boolean = { + // Parquet supports variant pushdown for all variant accesses + if (variantAccessInfo.nonEmpty) { + pushedVariantAccessInfo = variantAccessInfo + true + } else { + false + } + } + + override def pushedVariantAccess(): Array[VariantAccessInfo] = { + pushedVariantAccessInfo + } + override def build(): ParquetScan = { // the `finalSchema` is either pruned in pushAggregation (if aggregates are // pushed down), or pruned in readDataSchema() (in regular column pruning). 
These @@ -108,6 +126,6 @@ case class ParquetScanBuilder( } ParquetScan(sparkSession, hadoopConf, fileIndex, dataSchema, finalSchema, readPartitionSchema(), pushedDataFilters, options, pushedAggregations, - partitionFilters, dataFilters) + partitionFilters, dataFilters, pushedVariantAccessInfo) } } From 712bffffafbabadbc001f32b49d18524c1ece592 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 2 Dec 2025 14:57:31 -0800 Subject: [PATCH 2/2] fix --- .../datasources/PushVariantIntoScan.scala | 112 +------------ .../v2/V2ScanRelationPushDown.scala | 147 +++++++++++++++++- 2 files changed, 145 insertions(+), 114 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala index 65122df6798a..7ac50d738f04 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala @@ -26,10 +26,8 @@ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project, Subquery} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns -import org.apache.spark.sql.connector.read.{SupportsPushDownVariants, VariantAccessInfo} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat -import org.apache.spark.sql.execution.datasources.v2.ScanBuilderHolder import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -281,10 +279,6 @@ object PushVariantIntoScan extends Rule[LogicalPlan] { relation @ LogicalRelationWithTable( hadoopFsRelation@HadoopFsRelation(_, _, _, _, _: ParquetFileFormat, _), _)) => rewritePlan(p, projectList, filters, relation, hadoopFsRelation) - - case p@PhysicalOperation(projectList, filters, - sHolder @ ScanBuilderHolder(_, _, builder: SupportsPushDownVariants)) => - rewritePlanV2(p, projectList, filters, sHolder, builder) } } @@ -332,112 +326,10 @@ object PushVariantIntoScan extends Rule[LogicalPlan] { hadoopFsRelation.sparkSession) val newRelation = relation.copy(relation = newHadoopFsRelation, output = newOutput.toIndexedSeq) - buildFilterAndProject(newRelation, projectList, filters, variants, attributeMap) - } - - // DataSource V2 rewrite method using SupportsPushDownVariants API - // Key differences from V1 implementation: - // 1. V2 uses ScanBuilderHolder instead of LogicalRelation - // 2. Uses SupportsPushDownVariants API on ScanBuilder instead of directly manipulating scan - // 3. Schema is already resolved in sHolder.output (no need for relation.resolve()) - // 4. Scan will be built later with the variant access information - // Data sources like Parquet (V2), Delta and Iceberg can implement this API to support - // variant pushdown. 
- private def rewritePlanV2( - originalPlan: LogicalPlan, - projectList: Seq[NamedExpression], - filters: Seq[Expression], - sHolder: ScanBuilderHolder, - builder: SupportsPushDownVariants): LogicalPlan = { - val variants = new VariantInRelation - - // Extract schema attributes from scan builder holder - val schemaAttributes = sHolder.output - - // Construct schema for default value resolution - val structSchema = StructType(schemaAttributes.map(a => - StructField(a.name, a.dataType, a.nullable, a.metadata))) - - val defaultValues = ResolveDefaultColumns.existenceDefaultValues(structSchema) - - // Add variant fields from the V2 scan schema - for ((a, defaultValue) <- schemaAttributes.zip(defaultValues)) { - variants.addVariantFields(a.exprId, a.dataType, defaultValue, Nil) - } - if (variants.mapping.isEmpty) return originalPlan - - // Collect requested fields from project list and filters - projectList.foreach(variants.collectRequestedFields) - filters.foreach(variants.collectRequestedFields) - - // If no variant columns remain after collection, return original plan - if (variants.mapping.forall(_._2.isEmpty)) return originalPlan - - // Build VariantAccessInfo array for the API - val variantAccessInfoArray = schemaAttributes.flatMap { attr => - variants.mapping.get(attr.exprId).flatMap(_.get(Nil)).map { fields => - // Build extracted schema for this variant column - val extractedFields = fields.toArray.sortBy(_._2).map { case (field, ordinal) => - StructField(ordinal.toString, field.targetType, metadata = field.path.toMetadata) - } - val extractedSchema = if (extractedFields.isEmpty) { - // Add placeholder field to avoid empty struct - val placeholder = VariantMetadata("$.__placeholder_field__", - failOnError = false, timeZoneId = "UTC") - StructType(Array(StructField("0", BooleanType, metadata = placeholder.toMetadata))) - } else { - StructType(extractedFields) - } - new VariantAccessInfo(attr.name, extractedSchema) - } - }.toArray - - // Call the API to push down variant access - if (variantAccessInfoArray.isEmpty) return originalPlan - - val pushed = builder.pushVariantAccess(variantAccessInfoArray) - if (!pushed) return originalPlan - - // Get what was actually pushed - val pushedVariantAccess = builder.pushedVariantAccess() - if (pushedVariantAccess.isEmpty) return originalPlan - - // Build new attribute mapping based on pushed variant access - val pushedColumnNames = pushedVariantAccess.map(_.columnName()).toSet - val attributeMap = schemaAttributes.map { a => - if (pushedColumnNames.contains(a.name) && variants.mapping.get(a.exprId).exists(_.nonEmpty)) { - val newType = variants.rewriteType(a.exprId, a.dataType, Nil) - val newAttr = AttributeReference(a.name, newType, a.nullable, a.metadata)( - qualifier = a.qualifier) - (a.exprId, newAttr) - } else { - (a.exprId, a) - } - }.toMap - - val newOutput = sHolder.output.map(a => attributeMap.getOrElse(a.exprId, a)) - - // Update the scan builder holder's output with the new schema - // The scan will be built later with the variant access information already pushed to the - // builder - sHolder.output = newOutput - - buildFilterAndProject(sHolder, projectList, filters, variants, attributeMap) - } - - /** - * Build the final Project(Filter(relation)) plan with rewritten expressions. 
- */ - private def buildFilterAndProject( - relation: LogicalPlan, - projectList: Seq[NamedExpression], - filters: Seq[Expression], - variants: VariantInRelation, - attributeMap: Map[ExprId, AttributeReference]): LogicalPlan = { val withFilter = if (filters.nonEmpty) { - Filter(filters.map(variants.rewriteExpr(_, attributeMap)).reduce(And), relation) + Filter(filters.map(variants.rewriteExpr(_, attributeMap)).reduce(And), newRelation) } else { - relation + newRelation } val newProjectList = projectList.map { e => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index 31a98e1ff96c..963551e01548 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -22,7 +22,7 @@ import java.util.Locale import scala.collection.mutable import org.apache.spark.internal.LogKeys.{AGGREGATE_FUNCTIONS, COLUMN_NAMES, GROUP_BY_EXPRS, JOIN_CONDITION, JOIN_TYPE, POST_SCAN_FILTERS, PUSHED_FILTERS, RELATION_NAME, RELATION_OUTPUT} -import org.apache.spark.sql.catalyst.expressions.{aggregate, Alias, And, Attribute, AttributeMap, AttributeReference, AttributeSet, Cast, Expression, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions.{aggregate, Alias, And, Attribute, AttributeMap, AttributeReference, AttributeSet, Cast, Expression, ExprId, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.optimizer.CollapseProject import org.apache.spark.sql.catalyst.planning.{PhysicalOperation, ScanOperation} @@ -32,10 +32,10 @@ import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.connector.expressions.{SortOrder => V2SortOrder} import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Avg, Count, CountStar, Max, Min, Sum} import org.apache.spark.sql.connector.expressions.filter.Predicate -import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownJoin, V1Scan} -import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownJoin, SupportsPushDownVariants, V1Scan, VariantAccessInfo} +import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, VariantInRelation, VariantMetadata} import org.apache.spark.sql.sources -import org.apache.spark.sql.types.{DataType, DecimalType, IntegerType, StructType} +import org.apache.spark.sql.types.{BooleanType, DataType, DecimalType, IntegerType, StructField, StructType} import org.apache.spark.sql.util.SchemaUtils._ import org.apache.spark.util.ArrayImplicits._ @@ -49,9 +49,11 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { pushDownFilters, pushDownJoin, pushDownAggregates, + pushDownVariants, pushDownLimitAndOffset, buildScanWithPushedAggregate, buildScanWithPushedJoin, + buildScanWithPushedVariants, pruneColumns) pushdownRules.foldLeft(plan) { (newPlan, pushDownRule) => @@ -318,6 +320,97 
@@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { case agg: Aggregate => rewriteAggregate(agg) } + def pushDownVariants(plan: LogicalPlan): LogicalPlan = plan.transformDown { + case p@PhysicalOperation(projectList, filters, sHolder @ ScanBuilderHolder(_, _, + builder: SupportsPushDownVariants)) + if conf.getConf(org.apache.spark.sql.internal.SQLConf.PUSH_VARIANT_INTO_SCAN) => + pushVariantAccess(p, projectList, filters, sHolder, builder) + } + + private def pushVariantAccess( + originalPlan: LogicalPlan, + projectList: Seq[NamedExpression], + filters: Seq[Expression], + sHolder: ScanBuilderHolder, + builder: SupportsPushDownVariants): LogicalPlan = { + val variants = new VariantInRelation + + // Extract schema attributes from scan builder holder + val schemaAttributes = sHolder.output + + // Construct schema for default value resolution + val structSchema = StructType(schemaAttributes.map(a => + StructField(a.name, a.dataType, a.nullable, a.metadata))) + + val defaultValues = org.apache.spark.sql.catalyst.util.ResolveDefaultColumns. + existenceDefaultValues(structSchema) + + // Add variant fields from the V2 scan schema + for ((a, defaultValue) <- schemaAttributes.zip(defaultValues)) { + variants.addVariantFields(a.exprId, a.dataType, defaultValue, Nil) + } + if (variants.mapping.isEmpty) return originalPlan + + // Collect requested fields from project list and filters + projectList.foreach(variants.collectRequestedFields) + filters.foreach(variants.collectRequestedFields) + + // If no variant columns remain after collection, return original plan + if (variants.mapping.forall(_._2.isEmpty)) return originalPlan + + // Build VariantAccessInfo array for the API + val variantAccessInfoArray = schemaAttributes.flatMap { attr => + variants.mapping.get(attr.exprId).flatMap(_.get(Nil)).map { fields => + // Build extracted schema for this variant column + val extractedFields = fields.toArray.sortBy(_._2).map { case (field, ordinal) => + StructField(ordinal.toString, field.targetType, metadata = field.path.toMetadata) + } + val extractedSchema = if (extractedFields.isEmpty) { + // Add placeholder field to avoid empty struct + val placeholder = VariantMetadata("$.__placeholder_field__", + failOnError = false, timeZoneId = "UTC") + StructType(Array(StructField("0", BooleanType, metadata = placeholder.toMetadata))) + } else { + StructType(extractedFields) + } + new VariantAccessInfo(attr.name, extractedSchema) + } + }.toArray + + // Call the API to push down variant access + if (variantAccessInfoArray.isEmpty) return originalPlan + + val pushed = builder.pushVariantAccess(variantAccessInfoArray) + if (!pushed) return originalPlan + + // Get what was actually pushed + val pushedVariantAccess = builder.pushedVariantAccess() + if (pushedVariantAccess.isEmpty) return originalPlan + + // Build new attribute mapping based on pushed variant access + val pushedColumnNames = pushedVariantAccess.map(_.columnName()).toSet + val attributeMap = schemaAttributes.map { a => + if (pushedColumnNames.contains(a.name) && variants.mapping.get(a.exprId).exists(_.nonEmpty)) { + val newType = variants.rewriteType(a.exprId, a.dataType, Nil) + val newAttr = AttributeReference(a.name, newType, a.nullable, a.metadata)( + qualifier = a.qualifier) + (a.exprId, newAttr) + } else { + (a.exprId, a.asInstanceOf[AttributeReference]) + } + }.toMap + + val newOutput = sHolder.output.map(a => attributeMap.getOrElse(a.exprId, a)) + + // Store the transformation info on the holder for later use + 
sHolder.pushedVariants = Some(variants) + sHolder.pushedVariantAttributeMap = attributeMap + sHolder.output = newOutput + + // Return the original plan unchanged - transformation happens in buildScanWithPushedVariants + originalPlan + } + private def rewriteAggregate(agg: Aggregate): LogicalPlan = agg.child match { case PhysicalOperation(project, Nil, holder @ ScanBuilderHolder(_, _, r: SupportsPushDownAggregates)) if CollapseProject.canCollapseExpressions( @@ -589,6 +682,48 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { Project(projectList, scanRelation) } + def buildScanWithPushedVariants(plan: LogicalPlan): LogicalPlan = plan.transform { + case p@PhysicalOperation(projectList, filters, holder: ScanBuilderHolder) + if holder.pushedVariants.isDefined => + val variants = holder.pushedVariants.get + val attributeMap = holder.pushedVariantAttributeMap + + // Build the scan + val scan = holder.builder.build() + val realOutput = toAttributes(scan.readSchema()) + val wrappedScan = getWrappedScan(scan, holder) + val scanRelation = DataSourceV2ScanRelation(holder.relation, wrappedScan, realOutput) + + // Create projection to map real output to expected output (with transformed types) + val outputProjection = realOutput.zip(holder.output).map { case (realAttr, expectedAttr) => + Alias(realAttr, expectedAttr.name)(expectedAttr.exprId) + } + + // Rewrite filter expressions using the variant transformation + val rewrittenFilters = if (filters.nonEmpty) { + val rewrittenFilterExprs = filters.map(variants.rewriteExpr(_, attributeMap)) + Some(rewrittenFilterExprs.reduce(And)) + } else { + None + } + + // Rewrite project list expressions using the variant transformation + val rewrittenProjectList = projectList.map { e => + val rewritten = variants.rewriteExpr(e, attributeMap) + rewritten match { + case n: NamedExpression => n + // When the variant column is directly selected, we replace the attribute + // reference with a struct access, which is not a NamedExpression. Wrap it with Alias. + case _ => Alias(rewritten, e.name)(e.exprId, e.qualifier) + } + } + + // Build the plan: Project(outputProjection) -> [Filter?] -> scanRelation + val withProjection = Project(outputProjection, scanRelation) + val withFilter = rewrittenFilters.map(Filter(_, withProjection)).getOrElse(withProjection) + Project(rewrittenProjectList, withFilter) + } + def pruneColumns(plan: LogicalPlan): LogicalPlan = plan.transform { case ScanOperation(project, filtersStayUp, filtersPushDown, sHolder: ScanBuilderHolder) => // column pruning @@ -834,6 +969,10 @@ case class ScanBuilderHolder( var joinedRelationsPushedDownOperators: Seq[PushedDownOperators] = Seq.empty[PushedDownOperators] var pushedJoinOutputMap: AttributeMap[Expression] = AttributeMap.empty[Expression] + + var pushedVariantAttributeMap: Map[ExprId, AttributeReference] = Map.empty + + var pushedVariants: Option[VariantInRelation] = None } // A wrapper for v1 scan to carry the translated filters and the handled ones, along with
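As a rough illustration of what the new pushDownVariants rule hands to such a ScanBuilder, consider a query like SELECT variant_get(v, '$.a', 'int') FROM tbl over a table with a variant column v. The column name, path, failOnError flag, and time zone below are illustrative assumptions; the rule derives them from the variant accesses it collects from the project list and filters.

import org.apache.spark.sql.connector.read.VariantAccessInfo
import org.apache.spark.sql.execution.datasources.VariantMetadata
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

// Extracted schema for column `v`: one ordinal-named field per requested path, with the
// path, error mode, and time zone carried in the field metadata (assumed values here).
val extractedSchema = StructType(Array(
  StructField("0", IntegerType,
    metadata = VariantMetadata("$.a", failOnError = true, timeZoneId = "UTC").toMetadata)))

// What pushVariantAccess would receive for this query: one entry per variant column.
val accessInfo = Array(new VariantAccessInfo("v", extractedSchema))

// ParquetScanBuilder stores this array and forwards it to ParquetScan, whose readSchema()
// then reports `v` as this struct type; buildScanWithPushedVariants rewrites the
// variant_get call into a field access on that struct.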