diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownVariants.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownVariants.java
new file mode 100644
index 000000000000..ff82e71bfd58
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownVariants.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.read;
+
+import org.apache.spark.annotation.Evolving;
+
+/**
+ * A mix-in interface for {@link Scan}. Data sources can implement this interface to
+ * support pushing down variant field access operations to the data source.
+ * <p>
+ * When variant columns are accessed with specific field extractions (e.g., variant_get),
+ * the optimizer can push these accesses down to the data source. The data source can then
+ * read only the required fields from variant columns, reducing I/O and improving performance.
+ * <p>
+ * The typical workflow is:
+ * <ol>
+ *   <li>The optimizer calls {@link #pushVariantAccess(VariantAccessInfo[])} with the variant
+ *   field accesses collected from the query.</li>
+ *   <li>The data source decides which of these accesses it can support and stores them.</li>
+ *   <li>The optimizer calls {@link #pushedVariantAccess()} to find out what was accepted and
+ *   rewrites the query plan accordingly.</li>
+ * </ol>
+ *
+ * @since 4.1.0
+ */
+@Evolving
+public interface SupportsPushDownVariants extends Scan {
+
+ /**
+ * Pushes down variant field access information to the data source.
+ * <p>
+ * Implementations should validate whether the variant accesses can be pushed down based on
+ * the data source's capabilities. If some accesses cannot be pushed down, the implementation
+ * can choose to:
+ * <ul>
+ *   <li>push down only the subset of accesses it supports, or</li>
+ *   <li>reject the pushdown entirely and return false.</li>
+ * </ul>
+ * <p>
+ * The implementation should store the variant access information that can be pushed down.
+ * The stored information will be retrieved later via {@link #pushedVariantAccess()}.
+ *
+ * @param variantAccessInfo Array of variant access information, one per variant column
+ * @return true if at least some variant accesses were pushed down, false if none were pushed
+ */
+ boolean pushVariantAccess(VariantAccessInfo[] variantAccessInfo);
+
+ /**
+ * Returns the variant access information that has been pushed down to this scan.
+ * <p>
+ * This method is called by the optimizer after {@link #pushVariantAccess} to retrieve
+ * what variant accesses were actually accepted by the data source. The optimizer uses
+ * this information to rewrite the query plan.
+ * <p>
+ * If {@link #pushVariantAccess} was not called or returned false, this should return
+ * an empty array.
+ *
+ * @return Array of pushed down variant access information
+ */
+ VariantAccessInfo[] pushedVariantAccess();
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/VariantAccessInfo.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/VariantAccessInfo.java
new file mode 100644
index 000000000000..4f61a42d0519
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/VariantAccessInfo.java
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.read;
+
+import java.io.Serializable;
+import java.util.Objects;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.types.StructType;
+
+/**
+ * Variant access information that describes how variant fields are accessed in a query.
+ * <p>
+ * This class captures the information needed by data sources to optimize reading variant columns.
+ * Instead of reading the entire variant value, the data source can read only the fields that
+ * are actually accessed, represented as a structured schema.
+ * <p>
+ * For example, if a query accesses `variant_get(v, '$.a', 'int')` and
+ * `variant_get(v, '$.b', 'string')`, the extracted schema would be
+ * `struct<0:int, 1:string>` where field ordinals correspond to the access order.
+ *
+ * @since 4.1.0
+ */
+@Evolving
+public final class VariantAccessInfo implements Serializable {
+ private final String columnName;
+ private final StructType extractedSchema;
+
+ /**
+ * Creates variant access information for a variant column.
+ *
+ * @param columnName The name of the variant column
+ * @param extractedSchema The schema representing extracted fields from the variant.
+ *                        Each field represents one variant field access, with field names
+ *                        typically being ordinals (e.g., "0", "1", "2") and metadata
+ *                        containing variant-specific information like JSON path.
+ */
+ public VariantAccessInfo(String columnName, StructType extractedSchema) {
+ this.columnName = Objects.requireNonNull(columnName, "columnName cannot be null");
+ this.extractedSchema =
+ Objects.requireNonNull(extractedSchema, "extractedSchema cannot be null");
+ }
+
+ /**
+ * Returns the name of the variant column.
+ */
+ public String columnName() {
+ return columnName;
+ }
+
+ /**
+ * Returns the schema representing fields extracted from the variant column.
+ * <p>
+ * The schema structure is:
+ * <ul>
+ *   <li>Each top-level field corresponds to one variant field access, and the field name is
+ *   the ordinal of that access (e.g., "0", "1", "2").</li>
+ *   <li>Each field's data type is the type the extracted value should be returned as.</li>
+ *   <li>Each field's metadata carries the variant access details, such as the JSON path
+ *   being extracted.</li>
+ * </ul>
+ * <p>
+ * Data sources should use this schema to determine what fields to extract from the variant
+ * and what types they should be converted to.
+ */
+ public StructType extractedSchema() {
+ return extractedSchema;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ VariantAccessInfo that = (VariantAccessInfo) o;
+ return columnName.equals(that.columnName) &&
+ extractedSchema.equals(that.extractedSchema);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(columnName, extractedSchema);
+ }
+
+ @Override
+ public String toString() {
+ return "VariantAccessInfo{" +
+ "columnName='" + columnName + '\'' +
+ ", extractedSchema=" + extractedSchema +
+ '}';
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
index 9699d8a2563f..8edb59f49282 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
@@ -40,11 +40,11 @@ class SparkOptimizer(
SchemaPruning,
GroupBasedRowLevelOperationScanPlanning,
V1Writes,
- PushVariantIntoScan,
V2ScanRelationPushDown,
V2ScanPartitioningAndOrdering,
V2Writes,
- PruneFileSourcePartitions)
+ PruneFileSourcePartitions,
+ PushVariantIntoScan)
override def preCBORules: Seq[Rule[LogicalPlan]] =
Seq(OptimizeMetadataOnlyDeleteFromTable)
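
The reordering matters because the rule now matches `DataSourceV2ScanRelation`, which only
exists after `V2ScanRelationPushDown` has produced a `Scan`. A hedged, test-style sketch of how
the ordering could be asserted (it assumes `earlyScanPushDownRules` stays accessible from test
code, as `spark.sessionState` already is in the suites below):

```scala
// Sanity check: PushVariantIntoScan must run after V2ScanRelationPushDown.
val ruleNames = spark.sessionState.optimizer.earlyScanPushDownRules.map(_.ruleName)
assert(ruleNames.indexOf("org.apache.spark.sql.execution.datasources.PushVariantIntoScan") >
  ruleNames.indexOf("org.apache.spark.sql.execution.datasources.v2.V2ScanRelationPushDown"),
  "PushVariantIntoScan should see the Scan created by V2ScanRelationPushDown")
```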
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala
index 6ce53e3367c4..2cf1a5e9b8cd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala
@@ -26,9 +26,10 @@ import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project, Subquery}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns
+import org.apache.spark.sql.connector.read.{SupportsPushDownVariants, VariantAccessInfo}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
-import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -280,8 +281,11 @@ object PushVariantIntoScan extends Rule[LogicalPlan] {
relation @ LogicalRelationWithTable(
hadoopFsRelation@HadoopFsRelation(_, _, _, _, _: ParquetFileFormat, _), _)) =>
rewritePlan(p, projectList, filters, relation, hadoopFsRelation)
- case p@PhysicalOperation(projectList, filters, relation: DataSourceV2Relation) =>
- rewriteV2RelationPlan(p, projectList, filters, relation)
+
+ case p@PhysicalOperation(projectList, filters,
+ scanRelation @ DataSourceV2ScanRelation(
+ relation, scan: SupportsPushDownVariants, output, _, _)) =>
+ rewritePlanV2(p, projectList, filters, scanRelation, scan)
}
}
@@ -291,102 +295,135 @@ object PushVariantIntoScan extends Rule[LogicalPlan] {
filters: Seq[Expression],
relation: LogicalRelation,
hadoopFsRelation: HadoopFsRelation): LogicalPlan = {
+ val variants = new VariantInRelation
+
val schemaAttributes = relation.resolve(hadoopFsRelation.dataSchema,
hadoopFsRelation.sparkSession.sessionState.analyzer.resolver)
-
- // Collect variant fields from the relation output
- val variants = collectAndRewriteVariants(schemaAttributes)
+ val defaultValues = ResolveDefaultColumns.existenceDefaultValues(StructType(
+ schemaAttributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))))
+ for ((a, defaultValue) <- schemaAttributes.zip(defaultValues)) {
+ variants.addVariantFields(a.exprId, a.dataType, defaultValue, Nil)
+ }
if (variants.mapping.isEmpty) return originalPlan
- // Collect requested fields from projections and filters
projectList.foreach(variants.collectRequestedFields)
filters.foreach(variants.collectRequestedFields)
// `collectRequestedFields` may have removed all variant columns.
if (variants.mapping.forall(_._2.isEmpty)) return originalPlan
- // Build attribute map with rewritten types
- val attributeMap = buildAttributeMap(schemaAttributes, variants)
-
- // Build new schema with variant types replaced by struct types
+ val attributeMap = schemaAttributes.map { a =>
+ if (variants.mapping.get(a.exprId).exists(_.nonEmpty)) {
+ val newType = variants.rewriteType(a.exprId, a.dataType, Nil)
+ val newAttr = AttributeReference(a.name, newType, a.nullable, a.metadata)(
+ qualifier = a.qualifier)
+ (a.exprId, newAttr)
+ } else {
+ // `relation.resolve` actually returns `Seq[AttributeReference]`, although the return type
+ // is `Seq[Attribute]`.
+ (a.exprId, a.asInstanceOf[AttributeReference])
+ }
+ }.toMap
val newFields = schemaAttributes.map { a =>
val dataType = attributeMap(a.exprId).dataType
StructField(a.name, dataType, a.nullable, a.metadata)
}
- // Update relation output attributes with new types
val newOutput = relation.output.map(a => attributeMap.getOrElse(a.exprId, a))
- // Update HadoopFsRelation's data schema so the file source reads the struct columns
val newHadoopFsRelation = hadoopFsRelation.copy(dataSchema = StructType(newFields))(
hadoopFsRelation.sparkSession)
val newRelation = relation.copy(relation = newHadoopFsRelation, output = newOutput.toIndexedSeq)
- // Build filter and project with rewritten expressions
buildFilterAndProject(newRelation, projectList, filters, variants, attributeMap)
}
- private def rewriteV2RelationPlan(
+ // DataSource V2 rewrite method using SupportsPushDownVariants API
+ // Key differences from V1 implementation:
+ // 1. V2 uses DataSourceV2ScanRelation instead of LogicalRelation
+ // 2. Uses the SupportsPushDownVariants API instead of rewriting the relation's data schema
+ //    directly
+ // 3. Schema is already resolved in scanRelation.output (no need for relation.resolve())
+ // 4. Scan rebuilding is handled by the scan implementation via the API
+ // Data sources like Delta and Iceberg can implement this API to support variant pushdown.
+ private def rewritePlanV2(
originalPlan: LogicalPlan,
projectList: Seq[NamedExpression],
filters: Seq[Expression],
- relation: DataSourceV2Relation): LogicalPlan = {
+ scanRelation: DataSourceV2ScanRelation,
+ scan: SupportsPushDownVariants): LogicalPlan = {
+ val variants = new VariantInRelation
- // Collect variant fields from the relation output
- val variants = collectAndRewriteVariants(relation.output)
+ // Extract schema attributes from V2 scan relation
+ val schemaAttributes = scanRelation.output
+
+ // Construct schema for default value resolution
+ val structSchema = StructType(schemaAttributes.map(a =>
+ StructField(a.name, a.dataType, a.nullable, a.metadata)))
+
+ val defaultValues = ResolveDefaultColumns.existenceDefaultValues(structSchema)
+
+ // Add variant fields from the V2 scan schema
+ for ((a, defaultValue) <- schemaAttributes.zip(defaultValues)) {
+ variants.addVariantFields(a.exprId, a.dataType, defaultValue, Nil)
+ }
if (variants.mapping.isEmpty) return originalPlan
- // Collect requested fields from projections and filters
+ // Collect requested fields from project list and filters
projectList.foreach(variants.collectRequestedFields)
filters.foreach(variants.collectRequestedFields)
- // `collectRequestedFields` may have removed all variant columns.
+
+ // If no variant columns remain after collection, return original plan
if (variants.mapping.forall(_._2.isEmpty)) return originalPlan
- // Build attribute map with rewritten types
- val attributeMap = buildAttributeMap(relation.output, variants)
+ // Build VariantAccessInfo array for the API
+ val variantAccessInfoArray = schemaAttributes.flatMap { attr =>
+ variants.mapping.get(attr.exprId).flatMap(_.get(Nil)).map { fields =>
+ // Build extracted schema for this variant column
+ val extractedFields = fields.toArray.sortBy(_._2).map { case (field, ordinal) =>
+ StructField(ordinal.toString, field.targetType, metadata = field.path.toMetadata)
+ }
+ val extractedSchema = if (extractedFields.isEmpty) {
+ // Add placeholder field to avoid empty struct
+ val placeholder = VariantMetadata("$.__placeholder_field__",
+ failOnError = false, timeZoneId = "UTC")
+ StructType(Array(StructField("0", BooleanType, metadata = placeholder.toMetadata)))
+ } else {
+ StructType(extractedFields)
+ }
+ new VariantAccessInfo(attr.name, extractedSchema)
+ }
+ }.toArray
- // Update relation output attributes with new types
- // Note: DSv2 doesn't need to update the schema in the relation itself. The schema will be
- // communicated to the data source later via V2ScanRelationPushDown.pruneColumns() API.
- val newOutput = relation.output.map(a => attributeMap.getOrElse(a.exprId, a))
- val newRelation = relation.copy(output = newOutput.toIndexedSeq)
+ // Call the API to push down variant access
+ if (variantAccessInfoArray.isEmpty) return originalPlan
- // Build filter and project with rewritten expressions
- buildFilterAndProject(newRelation, projectList, filters, variants, attributeMap)
- }
+ val pushed = scan.pushVariantAccess(variantAccessInfoArray)
+ if (!pushed) return originalPlan
- /**
- * Collect variant fields and return initialized VariantInRelation.
- */
- private def collectAndRewriteVariants(
- schemaAttributes: Seq[Attribute]): VariantInRelation = {
- val variants = new VariantInRelation
- val defaultValues = ResolveDefaultColumns.existenceDefaultValues(StructType(
- schemaAttributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))))
+ // Get what was actually pushed
+ val pushedVariantAccess = scan.pushedVariantAccess()
+ if (pushedVariantAccess.isEmpty) return originalPlan
- for ((a, defaultValue) <- schemaAttributes.zip(defaultValues)) {
- variants.addVariantFields(a.exprId, a.dataType, defaultValue, Nil)
- }
-
- variants
- }
-
- /**
- * Build attribute map with rewritten variant types.
- */
- private def buildAttributeMap(
- schemaAttributes: Seq[Attribute],
- variants: VariantInRelation): Map[ExprId, AttributeReference] = {
- schemaAttributes.map { a =>
- if (variants.mapping.get(a.exprId).exists(_.nonEmpty)) {
+ // Build new attribute mapping based on pushed variant access
+ val pushedColumnNames = pushedVariantAccess.map(_.columnName()).toSet
+ val attributeMap = schemaAttributes.map { a =>
+ if (pushedColumnNames.contains(a.name) && variants.mapping.get(a.exprId).exists(_.nonEmpty)) {
val newType = variants.rewriteType(a.exprId, a.dataType, Nil)
val newAttr = AttributeReference(a.name, newType, a.nullable, a.metadata)(
qualifier = a.qualifier)
(a.exprId, newAttr)
} else {
- // `relation.resolve` actually returns `Seq[AttributeReference]`, although the return type
- // is `Seq[Attribute]`.
- (a.exprId, a.asInstanceOf[AttributeReference])
+ (a.exprId, a)
}
}.toMap
+
+ val newOutput = scanRelation.output.map(a => attributeMap.getOrElse(a.exprId, a))
+
+ // The scan implementation should have updated its readSchema() based on the pushed info
+ // We just need to create a new scan relation with the updated output
+ val newScanRelation = scanRelation.copy(
+ output = newOutput
+ )
+
+ buildFilterAndProject(newScanRelation, projectList, filters, variants, attributeMap)
}
/**
@@ -398,7 +435,6 @@ object PushVariantIntoScan extends Rule[LogicalPlan] {
filters: Seq[Expression],
variants: VariantInRelation,
attributeMap: Map[ExprId, AttributeReference]): LogicalPlan = {
-
val withFilter = if (filters.nonEmpty) {
Filter(filters.map(variants.rewriteExpr(_, attributeMap)).reduce(And), relation)
} else {
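
For a connector other than the built-in Parquet source, the contract `rewritePlanV2` relies on
is small: accept (some of) the pushed accesses, report them back from `pushedVariantAccess()`,
and reflect the extracted struct types in `readSchema()`. A hedged sketch follows; the class
name and its fields are hypothetical, not part of any real connector.

```scala
import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownVariants, VariantAccessInfo}
import org.apache.spark.sql.types.{StructType, VariantType}

class ExampleVariantScan(dataSchema: StructType) extends Scan with SupportsPushDownVariants {
  private var pushed: Array[VariantAccessInfo] = Array.empty

  override def pushVariantAccess(info: Array[VariantAccessInfo]): Boolean = {
    // Accept only accesses on top-level variant columns; reject everything else.
    pushed = info.filter(i => dataSchema.fields.exists(f =>
      f.name == i.columnName() && f.dataType == VariantType))
    pushed.nonEmpty
  }

  override def pushedVariantAccess(): Array[VariantAccessInfo] = pushed

  override def readSchema(): StructType = {
    // Replace each pushed variant column with its extracted struct, as the optimizer expects.
    val bySchema = pushed.map(i => i.columnName() -> i.extractedSchema()).toMap
    StructType(dataSchema.map { f =>
      bySchema.get(f.name).map(s => f.copy(dataType = s)).getOrElse(f)
    })
  }
}
```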
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala
index ec41e746469d..d347cb04f0bc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala
@@ -25,7 +25,7 @@ import org.apache.parquet.hadoop.ParquetInputFormat
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
-import org.apache.spark.sql.connector.read.PartitionReaderFactory
+import org.apache.spark.sql.connector.read.{PartitionReaderFactory, SupportsPushDownVariants, VariantAccessInfo}
import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, PartitioningAwareFileIndex}
import org.apache.spark.sql.execution.datasources.parquet.{ParquetOptions, ParquetReadSupport, ParquetWriteSupport}
import org.apache.spark.sql.execution.datasources.v2.FileScan
@@ -47,21 +47,78 @@ case class ParquetScan(
options: CaseInsensitiveStringMap,
pushedAggregate: Option[Aggregation] = None,
partitionFilters: Seq[Expression] = Seq.empty,
- dataFilters: Seq[Expression] = Seq.empty) extends FileScan {
+ dataFilters: Seq[Expression] = Seq.empty,
+ pushedVariantAccessInfo: Array[VariantAccessInfo] = Array.empty) extends FileScan
+ with SupportsPushDownVariants {
override def isSplitable(path: Path): Boolean = {
// If aggregate is pushed down, only the file footer will be read once,
// so file should not be split across multiple tasks.
pushedAggregate.isEmpty
}
+ // Build transformed schema if variant pushdown is active
+ private def effectiveReadDataSchema: StructType = {
+ if (_pushedVariantAccess.isEmpty) {
+ readDataSchema
+ } else {
+ // Build a mapping from column name to extracted schema
+ val variantSchemaMap = _pushedVariantAccess.map(info =>
+ info.columnName() -> info.extractedSchema()).toMap
+
+ // Transform the read data schema by replacing variant columns with their extracted schemas
+ StructType(readDataSchema.map { field =>
+ variantSchemaMap.get(field.name) match {
+ case Some(extractedSchema) => field.copy(dataType = extractedSchema)
+ case None => field
+ }
+ })
+ }
+ }
+
override def readSchema(): StructType = {
// If aggregate is pushed down, schema has already been pruned in `ParquetScanBuilder`
// and no need to call super.readSchema()
- if (pushedAggregate.nonEmpty) readDataSchema else super.readSchema()
+ if (pushedAggregate.nonEmpty) {
+ effectiveReadDataSchema
+ } else {
+ // super.readSchema() combines readDataSchema + readPartitionSchema
+ // Apply variant transformation if variant pushdown is active
+ val baseSchema = super.readSchema()
+ if (_pushedVariantAccess.isEmpty) {
+ baseSchema
+ } else {
+ val variantSchemaMap = _pushedVariantAccess.map(info =>
+ info.columnName() -> info.extractedSchema()).toMap
+ StructType(baseSchema.map { field =>
+ variantSchemaMap.get(field.name) match {
+ case Some(extractedSchema) => field.copy(dataType = extractedSchema)
+ case None => field
+ }
+ })
+ }
+ }
+ }
+
+ // SupportsPushDownVariants API implementation
+ private var _pushedVariantAccess: Array[VariantAccessInfo] = pushedVariantAccessInfo
+
+ override def pushVariantAccess(variantAccessInfo: Array[VariantAccessInfo]): Boolean = {
+ // Parquet supports variant pushdown for all variant accesses
+ if (variantAccessInfo.nonEmpty) {
+ _pushedVariantAccess = variantAccessInfo
+ true
+ } else {
+ false
+ }
+ }
+
+ override def pushedVariantAccess(): Array[VariantAccessInfo] = {
+ _pushedVariantAccess
}
override def createReaderFactory(): PartitionReaderFactory = {
- val readDataSchemaAsJson = readDataSchema.json
+ val effectiveSchema = effectiveReadDataSchema
+ val readDataSchemaAsJson = effectiveSchema.json
hadoopConf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[ParquetReadSupport].getName)
hadoopConf.set(
ParquetReadSupport.SPARK_ROW_REQUESTED_SCHEMA,
@@ -99,7 +156,7 @@ case class ParquetScan(
conf,
broadcastedConf,
dataSchema,
- readDataSchema,
+ effectiveSchema,
readPartitionSchema,
pushedFilters,
pushedAggregate,
@@ -113,8 +170,12 @@ case class ParquetScan(
} else {
pushedAggregate.isEmpty && p.pushedAggregate.isEmpty
}
+ val pushedVariantEqual =
+ java.util.Arrays.equals(_pushedVariantAccess.asInstanceOf[Array[Object]],
+ p._pushedVariantAccess.asInstanceOf[Array[Object]])
super.equals(p) && dataSchema == p.dataSchema && options == p.options &&
- equivalentFilters(pushedFilters, p.pushedFilters) && pushedDownAggEqual
+ equivalentFilters(pushedFilters, p.pushedFilters) && pushedDownAggEqual &&
+ pushedVariantEqual
case _ => false
}
@@ -128,8 +189,15 @@ case class ParquetScan(
}
override def getMetaData(): Map[String, String] = {
+ val variantAccessStr = if (_pushedVariantAccess.nonEmpty) {
+ _pushedVariantAccess.map(info =>
+ s"${info.columnName()}->${info.extractedSchema()}").mkString("[", ", ", "]")
+ } else {
+ "[]"
+ }
super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters.toImmutableArraySeq)) ++
Map("PushedAggregation" -> pushedAggregationsStr) ++
- Map("PushedGroupBy" -> pushedGroupByStr)
+ Map("PushedGroupBy" -> pushedGroupByStr) ++
+ Map("PushedVariantAccess" -> variantAccessStr)
}
}
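
As a quick way to observe the new ParquetScan behavior end to end, here is a hedged usage
sketch: the path is made up, and the exact plan rendering is not guaranteed, but the
"PushedVariantAccess" entry added to getMetaData() above is what should surface in the scan's
metadata.

```scala
import org.apache.spark.sql.internal.SQLConf

// Write variant data with the default (V1) Parquet path, as the tests below do.
spark.range(1)
  .selectExpr("parse_json('{\"a\": 1}') as v")
  .write.mode("overwrite").parquet("/tmp/variant_pushdown_demo")

// Read back through the V2 path with variant pushdown enabled.
spark.conf.set(SQLConf.PUSH_VARIANT_INTO_SCAN.key, "true")
spark.conf.set(SQLConf.USE_V1_SOURCE_LIST.key, "")
val df = spark.read.parquet("/tmp/variant_pushdown_demo")
  .selectExpr("variant_get(v, '$.a', 'int') as a")
df.explain("formatted")
// The scan node of the formatted plan is expected to include a PushedVariantAccess entry that
// maps column v to its extracted schema, alongside PushedFilters and PushedAggregation.
```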
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScanSuite.scala
index 08a9a306eec3..41b78881b788 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScanSuite.scala
@@ -18,27 +18,29 @@
package org.apache.spark.sql.execution.datasources
import org.apache.spark.SparkConf
+import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.variant._
import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types._
-class PushVariantIntoScanSuite extends SharedSparkSession {
+trait PushVariantIntoScanSuiteBase extends SharedSparkSession {
override def sparkConf: SparkConf =
super.sparkConf.set(SQLConf.PUSH_VARIANT_INTO_SCAN.key, "true")
- private def localTimeZone = spark.sessionState.conf.sessionLocalTimeZone
+ protected def localTimeZone = spark.sessionState.conf.sessionLocalTimeZone
// Return a `StructField` with the expected `VariantMetadata`.
- private def field(ordinal: Int, dataType: DataType, path: String,
+ protected def field(ordinal: Int, dataType: DataType, path: String,
failOnError: Boolean = true, timeZone: String = localTimeZone): StructField =
StructField(ordinal.toString, dataType,
metadata = VariantMetadata(path, failOnError, timeZone).toMetadata)
// Validate an `Alias` expression has the expected name and child.
- private def checkAlias(expr: Expression, expectedName: String, expected: Expression): Unit = {
+ protected def checkAlias(expr: Expression, expectedName: String, expected: Expression): Unit = {
expr match {
case Alias(child, name) =>
assert(name == expectedName)
@@ -47,9 +49,20 @@ class PushVariantIntoScanSuite extends SharedSparkSession {
}
}
+}
+
+// V1 DataSource tests with parameterized reader type
+abstract class PushVariantIntoScanV1SuiteBase extends PushVariantIntoScanSuiteBase {
+ protected def vectorizedReaderEnabled: Boolean
+ protected def readerName: String
+
+ override def sparkConf: SparkConf =
+ super.sparkConf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key,
+ vectorizedReaderEnabled.toString)
+
private def testOnFormats(fn: String => Unit): Unit = {
for (format <- Seq("PARQUET")) {
- test("test - " + format) {
+ test(s"test - $format ($readerName)") {
withTable("T") {
fn(format)
}
@@ -195,7 +208,7 @@ class PushVariantIntoScanSuite extends SharedSparkSession {
}
}
- test("No push down for JSON") {
+ test(s"No push down for JSON ($readerName)") {
withTable("T") {
sql("create table T (v variant) using JSON")
sql("select variant_get(v, '$.a') from T").queryExecution.optimizedPlan match {
@@ -207,3 +220,399 @@ class PushVariantIntoScanSuite extends SharedSparkSession {
}
}
}
+
+// V1 DataSource tests - Row-based reader
+class PushVariantIntoScanSuite extends PushVariantIntoScanV1SuiteBase {
+ override protected def vectorizedReaderEnabled: Boolean = false
+ override protected def readerName: String = "row-based reader"
+}
+
+// V1 DataSource tests - Vectorized reader
+class PushVariantIntoScanVectorizedSuite extends PushVariantIntoScanV1SuiteBase {
+ override protected def vectorizedReaderEnabled: Boolean = true
+ override protected def readerName: String = "vectorized reader"
+}
+
+// V2 DataSource tests with parameterized reader type
+abstract class PushVariantIntoScanV2SuiteBase extends QueryTest with PushVariantIntoScanSuiteBase {
+ protected def vectorizedReaderEnabled: Boolean
+ protected def readerName: String
+
+ override def sparkConf: SparkConf =
+ super.sparkConf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key,
+ vectorizedReaderEnabled.toString)
+
+ test(s"V2 test - basic variant field extraction ($readerName)") {
+ withTempPath { dir =>
+ val path = dir.getCanonicalPath
+
+ // Use V1 to write Parquet files with actual variant data
+ withTable("temp_v1") {
+ sql(s"create table temp_v1 (v variant, s string) using PARQUET location '$path'")
+ sql("insert into temp_v1 values " +
+ "(parse_json('{\"a\": 1, \"b\": 2.5}'), 'test1'), " +
+ "(parse_json('{\"a\": 2, \"b\": 3.5}'), 'test2')")
+ }
+
+ // Use V2 to read back
+ withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") {
+ val df = spark.read.parquet(path)
+ df.createOrReplaceTempView("T_V2")
+
+ val query = "select variant_get(v, '$.a', 'int') as a, v, " +
+ "cast(v as struct) as v_cast from T_V2"
+
+ val expectedRows = withSQLConf(SQLConf.PUSH_VARIANT_INTO_SCAN.key -> "false") {
+ sql(query).collect()
+ }
+
+ // Validate results are the same with and without pushdown
+ checkAnswer(sql(query), expectedRows)
+
+ // Test the variant pushdown
+ sql(query).queryExecution.optimizedPlan match {
+ case Project(projectList, scanRelation: DataSourceV2ScanRelation) =>
+ val output = scanRelation.output
+ val v = output(0)
+ // Check that variant pushdown happened - v should be a struct, not variant
+ assert(v.dataType.isInstanceOf[StructType],
+ s"Expected v to be struct type after pushdown, but got ${v.dataType}")
+ val vStruct = v.dataType.asInstanceOf[StructType]
+ assert(vStruct.fields.length == 3,
+ s"Expected 3 fields in struct, got ${vStruct.fields.length}")
+ assert(vStruct.fields(0).dataType == IntegerType)
+ assert(vStruct.fields(1).dataType == VariantType)
+ assert(vStruct.fields(2).dataType.isInstanceOf[StructType])
+ case other =>
+ fail(s"Expected V2 scan relation with variant pushdown, " +
+ s"got ${other.getClass.getName}")
+ }
+ }
+ }
+ }
+
+ test(s"V2 test - placeholder field with filter ($readerName)") {
+ withTempPath { dir =>
+ val path = dir.getCanonicalPath
+
+ withTable("temp_v1") {
+ sql(s"create table temp_v1 (v variant) using PARQUET location '$path'")
+ sql("insert into temp_v1 values (parse_json('{\"a\": 1}'))")
+ }
+
+ withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") {
+ val df = spark.read.parquet(path)
+ df.createOrReplaceTempView("T_V2")
+
+ val query = "select 1 from T_V2 where isnotnull(v)"
+
+ val expectedRows = withSQLConf(SQLConf.PUSH_VARIANT_INTO_SCAN.key -> "false") {
+ sql(query).collect()
+ }
+
+ // Validate results are the same with and without pushdown
+ checkAnswer(sql(query), expectedRows)
+
+ sql(query)
+ .queryExecution.optimizedPlan match {
+ case Project(_, Filter(condition, scanRelation: DataSourceV2ScanRelation)) =>
+ val output = scanRelation.output
+ val v = output(0)
+ assert(condition == IsNotNull(v))
+ assert(v.dataType == StructType(Array(
+ field(0, BooleanType, "$.__placeholder_field__", failOnError = false,
+ timeZone = "UTC"))))
+ case other => fail(s"Expected filtered V2 scan relation, got ${other.getClass.getName}")
+ }
+ }
+ }
+ }
+
+ test(s"V2 test - arithmetic and try_variant_get ($readerName)") {
+ withTempPath { dir =>
+ val path = dir.getCanonicalPath
+
+ withTable("temp_v1") {
+ sql(s"create table temp_v1 (v variant) using PARQUET location '$path'")
+ sql("insert into temp_v1 values " +
+ "(parse_json('{\"a\": 1, \"b\": \"hello\"}')), " +
+ "(parse_json('{\"a\": 2, \"b\": \"world\"}'))")
+ }
+
+ withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") {
+ val df = spark.read.parquet(path)
+ df.createOrReplaceTempView("T_V2")
+
+ val query = "select variant_get(v, '$.a', 'int') + 1 as a, " +
+ "try_variant_get(v, '$.b', 'string') as b from T_V2 " +
+ "where variant_get(v, '$.a', 'int') = 1"
+
+ val expectedRows = withSQLConf(SQLConf.PUSH_VARIANT_INTO_SCAN.key -> "false") {
+ sql(query).collect()
+ }
+
+ // Validate results are the same with and without pushdown
+ checkAnswer(sql(query), expectedRows)
+
+ sql(query).queryExecution.optimizedPlan match {
+ case Project(_, Filter(_, scanRelation: DataSourceV2ScanRelation)) =>
+ val output = scanRelation.output
+ val v = output(0)
+ assert(v.dataType.isInstanceOf[StructType],
+ s"Expected v to be struct type, but got ${v.dataType}")
+ val vStruct = v.dataType.asInstanceOf[StructType]
+ assert(vStruct.fields.length == 2, s"Expected 2 fields in struct")
+ assert(vStruct.fields(0).dataType == IntegerType)
+ assert(vStruct.fields(1).dataType == StringType)
+ case other => fail(s"Expected filtered V2 scan relation, got ${other.getClass.getName}")
+ }
+ }
+ }
+ }
+
+ test(s"V2 test - nested variant in struct ($readerName)") {
+ withTempPath { dir =>
+ val path = dir.getCanonicalPath
+
+ withTable("temp_v1") {
+ sql(s"create table temp_v1 (vs struct