@@ -20,7 +20,7 @@
import org.apache.spark.annotation.Evolving;

/**
* A mix-in interface for {@link Scan}. Data sources can implement this interface to
* A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to
* support pushing down variant field access operations to the data source.
* <p>
* When variant columns are accessed with specific field extractions (e.g., variant_get),
@@ -31,16 +31,17 @@
* <ol>
* <li>Optimizer analyzes the query plan and identifies variant field accesses</li>
* <li>Optimizer calls {@link #pushVariantAccess} with the access information</li>
* <li>Data source validates and stores the variant access information</li>
* <li>ScanBuilder validates and stores the variant access information</li>
* <li>Optimizer retrieves pushed information via {@link #pushedVariantAccess}</li>
* <li>Data source uses the information to optimize reading in {@link #readSchema()}
* <li>ScanBuilder builds a {@link Scan} with the variant access information</li>
* <li>Scan uses the information to optimize reading in {@link Scan#readSchema()}
* and readers</li>
* </ol>
*
* @since 4.1.0
*/
@Evolving
public interface SupportsPushDownVariants extends Scan {
public interface SupportsPushDownVariants extends ScanBuilder {

/**
* Pushes down variant field access information to the data source.
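The Javadoc above describes the handshake from the optimizer's side only. As a rough illustration of the connector side, a DSv2 source might wire the revised mix-in into its `ScanBuilder` along the following lines. This is a sketch, not part of the PR: `MyScanBuilder`/`MyScan` are hypothetical, the override signatures are inferred from how `V2ScanRelationPushDown` calls the API later in this diff, and `extractedSchema()` is an assumed accessor on `VariantAccessInfo`.

```scala
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownVariants, VariantAccessInfo}
import org.apache.spark.sql.types.StructType

// Hypothetical connector-side builder; names are illustrative.
class MyScanBuilder(tableSchema: StructType) extends ScanBuilder with SupportsPushDownVariants {

  private var pushed: Array[VariantAccessInfo] = Array.empty[VariantAccessInfo]

  // Steps 2-3 of the flow above: accept only the variant columns this source can prune at read time.
  override def pushVariantAccess(infos: Array[VariantAccessInfo]): Boolean = {
    val supported = infos.filter(info => tableSchema.fieldNames.contains(info.columnName()))
    if (supported.isEmpty) {
      false // nothing accepted; the optimizer keeps the original plan
    } else {
      pushed = supported
      true
    }
  }

  // Step 4: report back exactly what was accepted so the optimizer can rewrite expressions.
  override def pushedVariantAccess(): Array[VariantAccessInfo] = pushed

  // Step 5: the built Scan carries the pushed access information.
  override def build(): Scan = new MyScan(tableSchema, pushed)
}

// Step 6: readSchema() replaces each pushed variant column with its extracted struct,
// so the readers only need to materialize the requested fields.
class MyScan(tableSchema: StructType, pushed: Array[VariantAccessInfo]) extends Scan {
  override def readSchema(): StructType = {
    val byName = pushed.map(info => info.columnName() -> info).toMap
    StructType(tableSchema.map { field =>
      byName.get(field.name)
        .map(info => field.copy(dataType = info.extractedSchema())) // assumed accessor
        .getOrElse(field)
    })
  }
}
```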
@@ -26,10 +26,8 @@ import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project, Subquery}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns
import org.apache.spark.sql.connector.read.{SupportsPushDownVariants, VariantAccessInfo}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

@@ -281,11 +279,6 @@ object PushVariantIntoScan extends Rule[LogicalPlan] {
relation @ LogicalRelationWithTable(
hadoopFsRelation@HadoopFsRelation(_, _, _, _, _: ParquetFileFormat, _), _)) =>
rewritePlan(p, projectList, filters, relation, hadoopFsRelation)

case p@PhysicalOperation(projectList, filters,
scanRelation @ DataSourceV2ScanRelation(
relation, scan: SupportsPushDownVariants, output, _, _)) =>
rewritePlanV2(p, projectList, filters, scanRelation, scan)
}
}

@@ -333,112 +326,10 @@
hadoopFsRelation.sparkSession)
val newRelation = relation.copy(relation = newHadoopFsRelation, output = newOutput.toIndexedSeq)

buildFilterAndProject(newRelation, projectList, filters, variants, attributeMap)
}

// DataSource V2 rewrite method using SupportsPushDownVariants API
// Key differences from V1 implementation:
// 1. V2 uses DataSourceV2ScanRelation instead of LogicalRelation
// 2. Uses SupportsPushDownVariants API instead of directly manipulating scan
// 3. Schema is already resolved in scanRelation.output (no need for relation.resolve())
// 4. Scan rebuilding is handled by the scan implementation via the API
// Data sources like Delta and Iceberg can implement this API to support variant pushdown.
private def rewritePlanV2(
originalPlan: LogicalPlan,
projectList: Seq[NamedExpression],
filters: Seq[Expression],
scanRelation: DataSourceV2ScanRelation,
scan: SupportsPushDownVariants): LogicalPlan = {
val variants = new VariantInRelation

// Extract schema attributes from V2 scan relation
val schemaAttributes = scanRelation.output

// Construct schema for default value resolution
val structSchema = StructType(schemaAttributes.map(a =>
StructField(a.name, a.dataType, a.nullable, a.metadata)))

val defaultValues = ResolveDefaultColumns.existenceDefaultValues(structSchema)

// Add variant fields from the V2 scan schema
for ((a, defaultValue) <- schemaAttributes.zip(defaultValues)) {
variants.addVariantFields(a.exprId, a.dataType, defaultValue, Nil)
}
if (variants.mapping.isEmpty) return originalPlan

// Collect requested fields from project list and filters
projectList.foreach(variants.collectRequestedFields)
filters.foreach(variants.collectRequestedFields)

// If no variant columns remain after collection, return original plan
if (variants.mapping.forall(_._2.isEmpty)) return originalPlan

// Build VariantAccessInfo array for the API
val variantAccessInfoArray = schemaAttributes.flatMap { attr =>
variants.mapping.get(attr.exprId).flatMap(_.get(Nil)).map { fields =>
// Build extracted schema for this variant column
val extractedFields = fields.toArray.sortBy(_._2).map { case (field, ordinal) =>
StructField(ordinal.toString, field.targetType, metadata = field.path.toMetadata)
}
val extractedSchema = if (extractedFields.isEmpty) {
// Add placeholder field to avoid empty struct
val placeholder = VariantMetadata("$.__placeholder_field__",
failOnError = false, timeZoneId = "UTC")
StructType(Array(StructField("0", BooleanType, metadata = placeholder.toMetadata)))
} else {
StructType(extractedFields)
}
new VariantAccessInfo(attr.name, extractedSchema)
}
}.toArray

// Call the API to push down variant access
if (variantAccessInfoArray.isEmpty) return originalPlan

val pushed = scan.pushVariantAccess(variantAccessInfoArray)
if (!pushed) return originalPlan

// Get what was actually pushed
val pushedVariantAccess = scan.pushedVariantAccess()
if (pushedVariantAccess.isEmpty) return originalPlan

// Build new attribute mapping based on pushed variant access
val pushedColumnNames = pushedVariantAccess.map(_.columnName()).toSet
val attributeMap = schemaAttributes.map { a =>
if (pushedColumnNames.contains(a.name) && variants.mapping.get(a.exprId).exists(_.nonEmpty)) {
val newType = variants.rewriteType(a.exprId, a.dataType, Nil)
val newAttr = AttributeReference(a.name, newType, a.nullable, a.metadata)(
qualifier = a.qualifier)
(a.exprId, newAttr)
} else {
(a.exprId, a)
}
}.toMap

val newOutput = scanRelation.output.map(a => attributeMap.getOrElse(a.exprId, a))

// The scan implementation should have updated its readSchema() based on the pushed info
// We just need to create a new scan relation with the updated output
val newScanRelation = scanRelation.copy(
output = newOutput
)

buildFilterAndProject(newScanRelation, projectList, filters, variants, attributeMap)
}

/**
* Build the final Project(Filter(relation)) plan with rewritten expressions.
*/
private def buildFilterAndProject(
relation: LogicalPlan,
projectList: Seq[NamedExpression],
filters: Seq[Expression],
variants: VariantInRelation,
attributeMap: Map[ExprId, AttributeReference]): LogicalPlan = {
val withFilter = if (filters.nonEmpty) {
Filter(filters.map(variants.rewriteExpr(_, attributeMap)).reduce(And), relation)
Filter(filters.map(variants.rewriteExpr(_, attributeMap)).reduce(And), newRelation)
} else {
relation
newRelation
}

val newProjectList = projectList.map { e =>
@@ -22,7 +22,7 @@ import java.util.Locale
import scala.collection.mutable

import org.apache.spark.internal.LogKeys.{AGGREGATE_FUNCTIONS, COLUMN_NAMES, GROUP_BY_EXPRS, JOIN_CONDITION, JOIN_TYPE, POST_SCAN_FILTERS, PUSHED_FILTERS, RELATION_NAME, RELATION_OUTPUT}
import org.apache.spark.sql.catalyst.expressions.{aggregate, Alias, And, Attribute, AttributeMap, AttributeReference, AttributeSet, Cast, Expression, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.{aggregate, Alias, And, Attribute, AttributeMap, AttributeReference, AttributeSet, Cast, Expression, ExprId, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.optimizer.CollapseProject
import org.apache.spark.sql.catalyst.planning.{PhysicalOperation, ScanOperation}
@@ -32,10 +32,10 @@ import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.connector.expressions.{SortOrder => V2SortOrder}
import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Avg, Count, CountStar, Max, Min, Sum}
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownJoin, V1Scan}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownJoin, SupportsPushDownVariants, V1Scan, VariantAccessInfo}
import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, VariantInRelation, VariantMetadata}
import org.apache.spark.sql.sources
import org.apache.spark.sql.types.{DataType, DecimalType, IntegerType, StructType}
import org.apache.spark.sql.types.{BooleanType, DataType, DecimalType, IntegerType, StructField, StructType}
import org.apache.spark.sql.util.SchemaUtils._
import org.apache.spark.util.ArrayImplicits._

@@ -49,9 +49,11 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
pushDownFilters,
pushDownJoin,
pushDownAggregates,
pushDownVariants,
pushDownLimitAndOffset,
buildScanWithPushedAggregate,
buildScanWithPushedJoin,
buildScanWithPushedVariants,
pruneColumns)

pushdownRules.foldLeft(plan) { (newPlan, pushDownRule) =>
@@ -318,6 +320,97 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
case agg: Aggregate => rewriteAggregate(agg)
}

def pushDownVariants(plan: LogicalPlan): LogicalPlan = plan.transformDown {
case p@PhysicalOperation(projectList, filters, sHolder @ ScanBuilderHolder(_, _,
builder: SupportsPushDownVariants))
if conf.getConf(org.apache.spark.sql.internal.SQLConf.PUSH_VARIANT_INTO_SCAN) =>
pushVariantAccess(p, projectList, filters, sHolder, builder)
}
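This rule is gated on the existing PUSH_VARIANT_INTO_SCAN flag and only fires when the holder's builder implements `SupportsPushDownVariants`. A minimal way to exercise it might look like the snippet below; the `events` table and `payload` VARIANT column are made up, and it assumes a DSv2 source whose builder implements the mix-in.

```scala
import org.apache.spark.sql.internal.SQLConf

// Enable variant pushdown (same flag the rule checks above).
spark.conf.set(SQLConf.PUSH_VARIANT_INTO_SCAN.key, true)

// Both the projected and the filtered variant_get accesses become candidates for pushdown.
spark.sql(
  """SELECT variant_get(payload, '$.user.id', 'int') AS user_id
    |FROM events
    |WHERE variant_get(payload, '$.active', 'boolean')""".stripMargin)
```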

private def pushVariantAccess(
originalPlan: LogicalPlan,
projectList: Seq[NamedExpression],
filters: Seq[Expression],
sHolder: ScanBuilderHolder,
builder: SupportsPushDownVariants): LogicalPlan = {
val variants = new VariantInRelation

// Extract schema attributes from scan builder holder
val schemaAttributes = sHolder.output

// Construct schema for default value resolution
val structSchema = StructType(schemaAttributes.map(a =>
StructField(a.name, a.dataType, a.nullable, a.metadata)))

val defaultValues = org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.
existenceDefaultValues(structSchema)

// Add variant fields from the V2 scan schema
for ((a, defaultValue) <- schemaAttributes.zip(defaultValues)) {
variants.addVariantFields(a.exprId, a.dataType, defaultValue, Nil)
}
if (variants.mapping.isEmpty) return originalPlan

// Collect requested fields from project list and filters
projectList.foreach(variants.collectRequestedFields)
filters.foreach(variants.collectRequestedFields)

// If no variant columns remain after collection, return original plan
if (variants.mapping.forall(_._2.isEmpty)) return originalPlan

// Build VariantAccessInfo array for the API
val variantAccessInfoArray = schemaAttributes.flatMap { attr =>
variants.mapping.get(attr.exprId).flatMap(_.get(Nil)).map { fields =>
// Build extracted schema for this variant column
val extractedFields = fields.toArray.sortBy(_._2).map { case (field, ordinal) =>
StructField(ordinal.toString, field.targetType, metadata = field.path.toMetadata)
}
val extractedSchema = if (extractedFields.isEmpty) {
// Add placeholder field to avoid empty struct
val placeholder = VariantMetadata("$.__placeholder_field__",
failOnError = false, timeZoneId = "UTC")
StructType(Array(StructField("0", BooleanType, metadata = placeholder.toMetadata)))
} else {
StructType(extractedFields)
}
new VariantAccessInfo(attr.name, extractedSchema)
}
}.toArray

// Call the API to push down variant access
if (variantAccessInfoArray.isEmpty) return originalPlan

val pushed = builder.pushVariantAccess(variantAccessInfoArray)
if (!pushed) return originalPlan

// Get what was actually pushed
val pushedVariantAccess = builder.pushedVariantAccess()
if (pushedVariantAccess.isEmpty) return originalPlan

// Build new attribute mapping based on pushed variant access
val pushedColumnNames = pushedVariantAccess.map(_.columnName()).toSet
val attributeMap = schemaAttributes.map { a =>
if (pushedColumnNames.contains(a.name) && variants.mapping.get(a.exprId).exists(_.nonEmpty)) {
val newType = variants.rewriteType(a.exprId, a.dataType, Nil)
val newAttr = AttributeReference(a.name, newType, a.nullable, a.metadata)(
qualifier = a.qualifier)
(a.exprId, newAttr)
} else {
(a.exprId, a.asInstanceOf[AttributeReference])
}
}.toMap

val newOutput = sHolder.output.map(a => attributeMap.getOrElse(a.exprId, a))

// Store the transformation info on the holder for later use
sHolder.pushedVariants = Some(variants)
sHolder.pushedVariantAttributeMap = attributeMap
sHolder.output = newOutput

// Return the original plan unchanged - transformation happens in buildScanWithPushedVariants
originalPlan
}
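To make the `VariantAccessInfo` construction above concrete, this is roughly what the rule would hand the builder for two field accesses on one variant column. The ordinal-to-path assignment follows collection order and the option values (`failOnError`, `timeZoneId`) are illustrative.

```scala
import org.apache.spark.sql.connector.read.VariantAccessInfo
import org.apache.spark.sql.execution.datasources.VariantMetadata
import org.apache.spark.sql.types._

// For: SELECT variant_get(v, '$.a', 'int'), variant_get(v, '$.b', 'string') FROM t
// the extracted schema names its fields by ordinal and stores each extraction path
// (plus cast options) in the field metadata.
val extractedSchema = StructType(Array(
  StructField("0", IntegerType,
    metadata = VariantMetadata("$.a", failOnError = true, timeZoneId = "UTC").toMetadata),
  StructField("1", StringType,
    metadata = VariantMetadata("$.b", failOnError = true, timeZoneId = "UTC").toMetadata)))

val accessInfo = new VariantAccessInfo("v", extractedSchema)
```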

private def rewriteAggregate(agg: Aggregate): LogicalPlan = agg.child match {
case PhysicalOperation(project, Nil, holder @ ScanBuilderHolder(_, _,
r: SupportsPushDownAggregates)) if CollapseProject.canCollapseExpressions(
@@ -589,6 +682,48 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
Project(projectList, scanRelation)
}

def buildScanWithPushedVariants(plan: LogicalPlan): LogicalPlan = plan.transform {
case p@PhysicalOperation(projectList, filters, holder: ScanBuilderHolder)
if holder.pushedVariants.isDefined =>
val variants = holder.pushedVariants.get
val attributeMap = holder.pushedVariantAttributeMap

// Build the scan
val scan = holder.builder.build()
val realOutput = toAttributes(scan.readSchema())
val wrappedScan = getWrappedScan(scan, holder)
val scanRelation = DataSourceV2ScanRelation(holder.relation, wrappedScan, realOutput)

// Create projection to map real output to expected output (with transformed types)
val outputProjection = realOutput.zip(holder.output).map { case (realAttr, expectedAttr) =>
Alias(realAttr, expectedAttr.name)(expectedAttr.exprId)
}

// Rewrite filter expressions using the variant transformation
val rewrittenFilters = if (filters.nonEmpty) {
val rewrittenFilterExprs = filters.map(variants.rewriteExpr(_, attributeMap))
Some(rewrittenFilterExprs.reduce(And))
} else {
None
}

// Rewrite project list expressions using the variant transformation
val rewrittenProjectList = projectList.map { e =>
val rewritten = variants.rewriteExpr(e, attributeMap)
rewritten match {
case n: NamedExpression => n
// When the variant column is directly selected, we replace the attribute
// reference with a struct access, which is not a NamedExpression. Wrap it with Alias.
case _ => Alias(rewritten, e.name)(e.exprId, e.qualifier)
}
}

// Build the plan: Project(outputProjection) -> [Filter?] -> scanRelation
val withProjection = Project(outputProjection, scanRelation)
val withFilter = rewrittenFilters.map(Filter(_, withProjection)).getOrElse(withProjection)
Project(rewrittenProjectList, withFilter)
}
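For orientation, the plan this rule produces for a query that both projects and filters on pushed variant fields stacks up roughly as sketched below (expression details elided; names taken from the code above).

```scala
// Project(rewrittenProjectList)        // variant_get calls rewritten into struct-field accesses
// +- Filter(rewrittenFilters)          // present only when the original plan had filters
//    +- Project(outputProjection)      // aliases scan.readSchema() output onto the expected exprIds
//       +- DataSourceV2ScanRelation(holder.relation, wrappedScan, realOutput)
```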

def pruneColumns(plan: LogicalPlan): LogicalPlan = plan.transform {
case ScanOperation(project, filtersStayUp, filtersPushDown, sHolder: ScanBuilderHolder) =>
// column pruning
@@ -834,6 +969,10 @@ case class ScanBuilderHolder(
var joinedRelationsPushedDownOperators: Seq[PushedDownOperators] = Seq.empty[PushedDownOperators]

var pushedJoinOutputMap: AttributeMap[Expression] = AttributeMap.empty[Expression]

var pushedVariantAttributeMap: Map[ExprId, AttributeReference] = Map.empty

var pushedVariants: Option[VariantInRelation] = None
}

// A wrapper for v1 scan to carry the translated filters and the handled ones, along with