SparkOptimizer.scala
@@ -40,11 +40,11 @@ class SparkOptimizer(
SchemaPruning,
GroupBasedRowLevelOperationScanPlanning,
V1Writes,
PushVariantIntoScan,
V2ScanRelationPushDown,
V2ScanPartitioningAndOrdering,
V2Writes,
PruneFileSourcePartitions)
PruneFileSourcePartitions,
PushVariantIntoScan)

override def preCBORules: Seq[Rule[LogicalPlan]] =
Seq(OptimizeMetadataOnlyDeleteFromTable)
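
This hunk reorders the optimizer's scan push-down rules: PushVariantIntoScan now runs last, after PruneFileSourcePartitions, rather than before V2ScanRelationPushDown. That is consistent with the changes to PushVariantIntoScan.scala below, which drop the DSv2 handling so the rule only rewrites V1 Parquet relations. For orientation, a minimal sketch of the kind of query the rule targets; the table and column names are hypothetical and not taken from this PR.

import org.apache.spark.sql.SparkSession

object VariantPushdownSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("variant-pushdown-sketch").getOrCreate()

    // A Parquet table with a VARIANT column (hypothetical names).
    spark.sql("CREATE TABLE events (id BIGINT, payload VARIANT) USING parquet")

    // Without PushVariantIntoScan, every variant_get call re-extracts its path from the
    // full variant value. With the rule, the scan itself produces a struct containing
    // only the requested paths, and the calls become plain struct field accesses.
    spark.sql(
      """SELECT variant_get(payload, '$.user.id', 'bigint') AS user_id
        |FROM events
        |WHERE variant_get(payload, '$.kind', 'string') = 'click'""".stripMargin
    ).explain(true)

    spark.stop()
  }
}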
PushVariantIntoScan.scala
@@ -28,7 +28,6 @@ import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

@@ -280,8 +279,6 @@ object PushVariantIntoScan extends Rule[LogicalPlan] {
relation @ LogicalRelationWithTable(
hadoopFsRelation@HadoopFsRelation(_, _, _, _, _: ParquetFileFormat, _), _)) =>
rewritePlan(p, projectList, filters, relation, hadoopFsRelation)
case p@PhysicalOperation(projectList, filters, relation: DataSourceV2Relation) =>
rewriteV2RelationPlan(p, projectList, filters, relation)
}
}
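
With the DataSourceV2Relation case gone, the rule again matches only Project/Filter stacks over a V1 LogicalRelation backed by a Parquet HadoopFsRelation; DSv2 plans fall through unchanged and are handled later by V2ScanRelationPushDown without any variant rewriting. A simplified sketch of that matching shape, not the actual Spark source, for readers who have not seen the surrounding method:

import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

object MatchOnlyV1ParquetSketch extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan.transformDown {
    case p @ PhysicalOperation(projectList, filters, relation) =>
      // Placeholder: the real rule calls rewritePlan(p, projectList, filters, ...) only
      // when `relation` is a LogicalRelationWithTable over a Parquet HadoopFsRelation;
      // everything else, including a DataSourceV2Relation, now stays untouched.
      p
  }
}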

@@ -291,91 +288,23 @@ object PushVariantIntoScan extends Rule[LogicalPlan] {
filters: Seq[Expression],
relation: LogicalRelation,
hadoopFsRelation: HadoopFsRelation): LogicalPlan = {
val variants = new VariantInRelation

val schemaAttributes = relation.resolve(hadoopFsRelation.dataSchema,
hadoopFsRelation.sparkSession.sessionState.analyzer.resolver)

// Collect variant fields from the relation output
val variants = collectAndRewriteVariants(schemaAttributes)
if (variants.mapping.isEmpty) return originalPlan

// Collect requested fields from projections and filters
projectList.foreach(variants.collectRequestedFields)
filters.foreach(variants.collectRequestedFields)
// `collectRequestedFields` may have removed all variant columns.
if (variants.mapping.forall(_._2.isEmpty)) return originalPlan

// Build attribute map with rewritten types
val attributeMap = buildAttributeMap(schemaAttributes, variants)

// Build new schema with variant types replaced by struct types
val newFields = schemaAttributes.map { a =>
val dataType = attributeMap(a.exprId).dataType
StructField(a.name, dataType, a.nullable, a.metadata)
val defaultValues = ResolveDefaultColumns.existenceDefaultValues(StructType(
schemaAttributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))))
for ((a, defaultValue) <- schemaAttributes.zip(defaultValues)) {
variants.addVariantFields(a.exprId, a.dataType, defaultValue, Nil)
}
// Update relation output attributes with new types
val newOutput = relation.output.map(a => attributeMap.getOrElse(a.exprId, a))

// Update HadoopFsRelation's data schema so the file source reads the struct columns
val newHadoopFsRelation = hadoopFsRelation.copy(dataSchema = StructType(newFields))(
hadoopFsRelation.sparkSession)
val newRelation = relation.copy(relation = newHadoopFsRelation, output = newOutput.toIndexedSeq)

// Build filter and project with rewritten expressions
buildFilterAndProject(newRelation, projectList, filters, variants, attributeMap)
}

private def rewriteV2RelationPlan(
originalPlan: LogicalPlan,
projectList: Seq[NamedExpression],
filters: Seq[Expression],
relation: DataSourceV2Relation): LogicalPlan = {

// Collect variant fields from the relation output
val variants = collectAndRewriteVariants(relation.output)
if (variants.mapping.isEmpty) return originalPlan

// Collect requested fields from projections and filters
projectList.foreach(variants.collectRequestedFields)
filters.foreach(variants.collectRequestedFields)
// `collectRequestedFields` may have removed all variant columns.
if (variants.mapping.forall(_._2.isEmpty)) return originalPlan

// Build attribute map with rewritten types
val attributeMap = buildAttributeMap(relation.output, variants)

// Update relation output attributes with new types
// Note: DSv2 doesn't need to update the schema in the relation itself. The schema will be
// communicated to the data source later via V2ScanRelationPushDown.pruneColumns() API.
val newOutput = relation.output.map(a => attributeMap.getOrElse(a.exprId, a))
val newRelation = relation.copy(output = newOutput.toIndexedSeq)

// Build filter and project with rewritten expressions
buildFilterAndProject(newRelation, projectList, filters, variants, attributeMap)
}

/**
* Collect variant fields and return initialized VariantInRelation.
*/
private def collectAndRewriteVariants(
schemaAttributes: Seq[Attribute]): VariantInRelation = {
val variants = new VariantInRelation
val defaultValues = ResolveDefaultColumns.existenceDefaultValues(StructType(
schemaAttributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))))

for ((a, defaultValue) <- schemaAttributes.zip(defaultValues)) {
variants.addVariantFields(a.exprId, a.dataType, defaultValue, Nil)
}

variants
}

/**
* Build attribute map with rewritten variant types.
*/
private def buildAttributeMap(
schemaAttributes: Seq[Attribute],
variants: VariantInRelation): Map[ExprId, AttributeReference] = {
schemaAttributes.map { a =>
val attributeMap = schemaAttributes.map { a =>
if (variants.mapping.get(a.exprId).exists(_.nonEmpty)) {
val newType = variants.rewriteType(a.exprId, a.dataType, Nil)
val newAttr = AttributeReference(a.name, newType, a.nullable, a.metadata)(
@@ -387,24 +316,21 @@ object PushVariantIntoScan extends Rule[LogicalPlan] {
(a.exprId, a.asInstanceOf[AttributeReference])
}
}.toMap
}
val newFields = schemaAttributes.map { a =>
val dataType = attributeMap(a.exprId).dataType
StructField(a.name, dataType, a.nullable, a.metadata)
}
val newOutput = relation.output.map(a => attributeMap.getOrElse(a.exprId, a))

/**
* Build the final Project(Filter(relation)) plan with rewritten expressions.
*/
private def buildFilterAndProject(
relation: LogicalPlan,
projectList: Seq[NamedExpression],
filters: Seq[Expression],
variants: VariantInRelation,
attributeMap: Map[ExprId, AttributeReference]): LogicalPlan = {
val newHadoopFsRelation = hadoopFsRelation.copy(dataSchema = StructType(newFields))(
hadoopFsRelation.sparkSession)
val newRelation = relation.copy(relation = newHadoopFsRelation, output = newOutput.toIndexedSeq)

val withFilter = if (filters.nonEmpty) {
Filter(filters.map(variants.rewriteExpr(_, attributeMap)).reduce(And), relation)
Filter(filters.map(variants.rewriteExpr(_, attributeMap)).reduce(And), newRelation)
} else {
relation
newRelation
}

val newProjectList = projectList.map { e =>
val rewritten = variants.rewriteExpr(e, attributeMap)
rewritten match {
@@ -415,7 +341,6 @@ object PushVariantIntoScan extends Rule[LogicalPlan] {
case _ => Alias(rewritten, e.name)(e.exprId, e.qualifier)
}
}

Project(newProjectList, withFilter)
}
}
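
Taken together, the restored inline rewritePlan does the whole rewrite in one pass: resolve the file schema, register each variant column with its existence default value, collect the paths requested by the projections and filters, rewrite the types of the variant columns that are actually used, rebuild the HadoopFsRelation with the new data schema, and reconstruct Filter and Project on top, re-aliasing rewritten expressions so they keep their original names and exprIds. A rough illustration of the schema rewrite, assuming the hypothetical query sketched earlier; the struct layout shown is made up, while the real layout is produced by VariantInRelation.rewriteType.

import org.apache.spark.sql.types._

object VariantSchemaRewriteSketch {
  // Scan schema before the rule, for the hypothetical table events(id BIGINT, payload VARIANT).
  val before: StructType = StructType(Seq(
    StructField("id", LongType),
    StructField("payload", VariantType)))

  // Scan schema after the rule: the variant column is replaced by a struct holding only
  // the requested paths, each typed as asked for by variant_get.
  val after: StructType = StructType(Seq(
    StructField("id", LongType),
    StructField("payload", StructType(Seq(
      StructField("user.id", LongType),
      StructField("kind", StringType))))))
}

If collectRequestedFields ends up dropping every variant column from the mapping, the early return in the hunk above leaves the plan untouched, so the rewrite only happens when at least one extracted path survives.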

This file was deleted.