diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownVariants.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownVariants.java new file mode 100644 index 000000000000..ff82e71bfd58 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownVariants.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.read; + +import org.apache.spark.annotation.Evolving; + +/** + * A mix-in interface for {@link Scan}. Data sources can implement this interface to + * support pushing down variant field access operations to the data source. + *
+ * When variant columns are accessed with specific field extractions (e.g., variant_get),
+ * the optimizer can push these accesses down to the data source. The data source can then
+ * read only the required fields from variant columns, reducing I/O and improving performance.
+ * <p>
+ * The typical workflow is:
+ * <ol>
+ *   <li>Optimizer analyzes the query plan and identifies variant field accesses</li>
+ *   <li>Optimizer calls {@link #pushVariantAccess} with the access information</li>
+ *   <li>Data source validates and stores the variant access information</li>
+ *   <li>Optimizer retrieves pushed information via {@link #pushedVariantAccess}</li>
+ *   <li>Data source uses the information to optimize reading in {@link #readSchema()}
+ *   and readers</li>
+ * </ol>
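+ * <p>
+ * For illustration only, the exchange between the optimizer and a scan roughly looks as
+ * follows ({@code scan} stands for an instance of this interface):
+ * <pre>{@code
+ * // Sketch: real extracted schemas also attach path metadata to each field.
+ * StructType extracted = new StructType()
+ *     .add("0", DataTypes.IntegerType)   // e.g. variant_get(v, '$.a', 'int')
+ *     .add("1", DataTypes.StringType);   // e.g. variant_get(v, '$.b', 'string')
+ * VariantAccessInfo[] accesses = { new VariantAccessInfo("v", extracted) };
+ * if (scan.pushVariantAccess(accesses)) {
+ *   // The optimizer rewrites the plan based on what was actually accepted.
+ *   VariantAccessInfo[] pushed = scan.pushedVariantAccess();
+ * }
+ * }</pre>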
+ * + * @since 4.1.0 + */ +@Evolving +public interface SupportsPushDownVariants extends Scan { + + /** + * Pushes down variant field access information to the data source. + *
+ * Implementations should validate whether the variant accesses can be pushed down based on
+ * the data source's capabilities. If some accesses cannot be pushed down, the implementation
+ * can choose to push down only the subset it supports, or to reject the pushdown entirely by
+ * returning false.
+ * <p>
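+ * For illustration only, an implementation might keep just the accesses it can handle
+ * ({@code canExtract} is a hypothetical capability check; {@code pushed} is the state
+ * returned by {@link #pushedVariantAccess()}):
+ * <pre>{@code
+ * public boolean pushVariantAccess(VariantAccessInfo[] variantAccessInfo) {
+ *   this.pushed = Arrays.stream(variantAccessInfo)
+ *       .filter(this::canExtract)
+ *       .toArray(VariantAccessInfo[]::new);
+ *   return this.pushed.length > 0;
+ * }
+ * }</pre>
+ * <p>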
+ * The implementation should store the variant access information that can be pushed down. + * The stored information will be retrieved later via {@link #pushedVariantAccess()}. + * + * @param variantAccessInfo Array of variant access information, one per variant column + * @return true if at least some variant accesses were pushed down, false if none were pushed + */ + boolean pushVariantAccess(VariantAccessInfo[] variantAccessInfo); + + /** + * Returns the variant access information that has been pushed down to this scan. + *
+ * This method is called by the optimizer after {@link #pushVariantAccess} to retrieve + * what variant accesses were actually accepted by the data source. The optimizer uses + * this information to rewrite the query plan. + *
+ * If {@link #pushVariantAccess} was not called or returned false, this should return + * an empty array. + * + * @return Array of pushed down variant access information + */ + VariantAccessInfo[] pushedVariantAccess(); +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/VariantAccessInfo.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/VariantAccessInfo.java new file mode 100644 index 000000000000..4f61a42d0519 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/VariantAccessInfo.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.read; + +import java.io.Serializable; +import java.util.Objects; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.types.StructType; + +/** + * Variant access information that describes how variant fields are accessed in a query. + *
+ * This class captures the information needed by data sources to optimize reading variant columns. + * Instead of reading the entire variant value, the data source can read only the fields that + * are actually accessed, represented as a structured schema. + *
+ * For example, if a query accesses `variant_get(v, '$.a', 'int')` and + * `variant_get(v, '$.b', 'string')`, the extracted schema would be + * `struct<0:int, 1:string>` where field ordinals correspond to the access order. + * + * @since 4.1.0 + */ +@Evolving +public final class VariantAccessInfo implements Serializable { + private final String columnName; + private final StructType extractedSchema; + + /** + * Creates variant access information for a variant column. + * + * @param columnName The name of the variant column + * @param extractedSchema The schema representing extracted fields from the variant. + * Each field represents one variant field access, with field names + * typically being ordinals (e.g., "0", "1", "2") and metadata + * containing variant-specific information like JSON path. + */ + public VariantAccessInfo(String columnName, StructType extractedSchema) { + this.columnName = Objects.requireNonNull(columnName, "columnName cannot be null"); + this.extractedSchema = + Objects.requireNonNull(extractedSchema, "extractedSchema cannot be null"); + } + + /** + * Returns the name of the variant column. + */ + public String columnName() { + return columnName; + } + + /** + * Returns the schema representing fields extracted from the variant column. + *
+ * Each field in the schema represents one variant field access: the field name is the access
+ * ordinal (e.g., "0", "1", "2"), the field type is the target type of the extraction, and the
+ * field metadata carries variant-specific information such as the JSON path.
+ * <p>
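+ * For illustration only, a reader might walk the extracted schema roughly like this
+ * ({@code info} stands for a {@code VariantAccessInfo} instance):
+ * <pre>{@code
+ * for (StructField field : info.extractedSchema().fields()) {
+ *   DataType targetType = field.dataType();  // type to convert the extracted value to
+ *   Metadata meta = field.metadata();        // variant-specific info such as the JSON path
+ *   // ... configure reading of this extraction from the variant column ...
+ * }
+ * }</pre>
+ * <p>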
+ * Data sources should use this schema to determine what fields to extract from the variant + * and what types they should be converted to. + */ + public StructType extractedSchema() { + return extractedSchema; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + VariantAccessInfo that = (VariantAccessInfo) o; + return columnName.equals(that.columnName) && + extractedSchema.equals(that.extractedSchema); + } + + @Override + public int hashCode() { + return Objects.hash(columnName, extractedSchema); + } + + @Override + public String toString() { + return "VariantAccessInfo{" + + "columnName='" + columnName + '\'' + + ", extractedSchema=" + extractedSchema + + '}'; + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 9699d8a2563f..8edb59f49282 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -40,11 +40,11 @@ class SparkOptimizer( SchemaPruning, GroupBasedRowLevelOperationScanPlanning, V1Writes, - PushVariantIntoScan, V2ScanRelationPushDown, V2ScanPartitioningAndOrdering, V2Writes, - PruneFileSourcePartitions) + PruneFileSourcePartitions, + PushVariantIntoScan) override def preCBORules: Seq[Rule[LogicalPlan]] = Seq(OptimizeMetadataOnlyDeleteFromTable) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala index 6ce53e3367c4..2cf1a5e9b8cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala @@ -26,9 +26,10 @@ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project, Subquery} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns +import org.apache.spark.sql.connector.read.{SupportsPushDownVariants, VariantAccessInfo} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -280,8 +281,11 @@ object PushVariantIntoScan extends Rule[LogicalPlan] { relation @ LogicalRelationWithTable( hadoopFsRelation@HadoopFsRelation(_, _, _, _, _: ParquetFileFormat, _), _)) => rewritePlan(p, projectList, filters, relation, hadoopFsRelation) - case p@PhysicalOperation(projectList, filters, relation: DataSourceV2Relation) => - rewriteV2RelationPlan(p, projectList, filters, relation) + + case p@PhysicalOperation(projectList, filters, + scanRelation @ DataSourceV2ScanRelation( + relation, scan: SupportsPushDownVariants, output, _, _)) => + rewritePlanV2(p, projectList, filters, scanRelation, scan) } } @@ -291,102 +295,135 @@ object PushVariantIntoScan extends Rule[LogicalPlan] { filters: Seq[Expression], relation: LogicalRelation, hadoopFsRelation: HadoopFsRelation): LogicalPlan = { + val variants = new VariantInRelation + val 
schemaAttributes = relation.resolve(hadoopFsRelation.dataSchema, hadoopFsRelation.sparkSession.sessionState.analyzer.resolver) - - // Collect variant fields from the relation output - val variants = collectAndRewriteVariants(schemaAttributes) + val defaultValues = ResolveDefaultColumns.existenceDefaultValues(StructType( + schemaAttributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata)))) + for ((a, defaultValue) <- schemaAttributes.zip(defaultValues)) { + variants.addVariantFields(a.exprId, a.dataType, defaultValue, Nil) + } if (variants.mapping.isEmpty) return originalPlan - // Collect requested fields from projections and filters projectList.foreach(variants.collectRequestedFields) filters.foreach(variants.collectRequestedFields) // `collectRequestedFields` may have removed all variant columns. if (variants.mapping.forall(_._2.isEmpty)) return originalPlan - // Build attribute map with rewritten types - val attributeMap = buildAttributeMap(schemaAttributes, variants) - - // Build new schema with variant types replaced by struct types + val attributeMap = schemaAttributes.map { a => + if (variants.mapping.get(a.exprId).exists(_.nonEmpty)) { + val newType = variants.rewriteType(a.exprId, a.dataType, Nil) + val newAttr = AttributeReference(a.name, newType, a.nullable, a.metadata)( + qualifier = a.qualifier) + (a.exprId, newAttr) + } else { + // `relation.resolve` actually returns `Seq[AttributeReference]`, although the return type + // is `Seq[Attribute]`. + (a.exprId, a.asInstanceOf[AttributeReference]) + } + }.toMap val newFields = schemaAttributes.map { a => val dataType = attributeMap(a.exprId).dataType StructField(a.name, dataType, a.nullable, a.metadata) } - // Update relation output attributes with new types val newOutput = relation.output.map(a => attributeMap.getOrElse(a.exprId, a)) - // Update HadoopFsRelation's data schema so the file source reads the struct columns val newHadoopFsRelation = hadoopFsRelation.copy(dataSchema = StructType(newFields))( hadoopFsRelation.sparkSession) val newRelation = relation.copy(relation = newHadoopFsRelation, output = newOutput.toIndexedSeq) - // Build filter and project with rewritten expressions buildFilterAndProject(newRelation, projectList, filters, variants, attributeMap) } - private def rewriteV2RelationPlan( + // DataSource V2 rewrite method using SupportsPushDownVariants API + // Key differences from V1 implementation: + // 1. V2 uses DataSourceV2ScanRelation instead of LogicalRelation + // 2. Uses SupportsPushDownVariants API instead of directly manipulating scan + // 3. Schema is already resolved in scanRelation.output (no need for relation.resolve()) + // 4. Scan rebuilding is handled by the scan implementation via the API + // Data sources like Delta and Iceberg can implement this API to support variant pushdown. 
+ private def rewritePlanV2( originalPlan: LogicalPlan, projectList: Seq[NamedExpression], filters: Seq[Expression], - relation: DataSourceV2Relation): LogicalPlan = { + scanRelation: DataSourceV2ScanRelation, + scan: SupportsPushDownVariants): LogicalPlan = { + val variants = new VariantInRelation - // Collect variant fields from the relation output - val variants = collectAndRewriteVariants(relation.output) + // Extract schema attributes from V2 scan relation + val schemaAttributes = scanRelation.output + + // Construct schema for default value resolution + val structSchema = StructType(schemaAttributes.map(a => + StructField(a.name, a.dataType, a.nullable, a.metadata))) + + val defaultValues = ResolveDefaultColumns.existenceDefaultValues(structSchema) + + // Add variant fields from the V2 scan schema + for ((a, defaultValue) <- schemaAttributes.zip(defaultValues)) { + variants.addVariantFields(a.exprId, a.dataType, defaultValue, Nil) + } if (variants.mapping.isEmpty) return originalPlan - // Collect requested fields from projections and filters + // Collect requested fields from project list and filters projectList.foreach(variants.collectRequestedFields) filters.foreach(variants.collectRequestedFields) - // `collectRequestedFields` may have removed all variant columns. + + // If no variant columns remain after collection, return original plan if (variants.mapping.forall(_._2.isEmpty)) return originalPlan - // Build attribute map with rewritten types - val attributeMap = buildAttributeMap(relation.output, variants) + // Build VariantAccessInfo array for the API + val variantAccessInfoArray = schemaAttributes.flatMap { attr => + variants.mapping.get(attr.exprId).flatMap(_.get(Nil)).map { fields => + // Build extracted schema for this variant column + val extractedFields = fields.toArray.sortBy(_._2).map { case (field, ordinal) => + StructField(ordinal.toString, field.targetType, metadata = field.path.toMetadata) + } + val extractedSchema = if (extractedFields.isEmpty) { + // Add placeholder field to avoid empty struct + val placeholder = VariantMetadata("$.__placeholder_field__", + failOnError = false, timeZoneId = "UTC") + StructType(Array(StructField("0", BooleanType, metadata = placeholder.toMetadata))) + } else { + StructType(extractedFields) + } + new VariantAccessInfo(attr.name, extractedSchema) + } + }.toArray - // Update relation output attributes with new types - // Note: DSv2 doesn't need to update the schema in the relation itself. The schema will be - // communicated to the data source later via V2ScanRelationPushDown.pruneColumns() API. - val newOutput = relation.output.map(a => attributeMap.getOrElse(a.exprId, a)) - val newRelation = relation.copy(output = newOutput.toIndexedSeq) + // Call the API to push down variant access + if (variantAccessInfoArray.isEmpty) return originalPlan - // Build filter and project with rewritten expressions - buildFilterAndProject(newRelation, projectList, filters, variants, attributeMap) - } + val pushed = scan.pushVariantAccess(variantAccessInfoArray) + if (!pushed) return originalPlan - /** - * Collect variant fields and return initialized VariantInRelation. 
- */ - private def collectAndRewriteVariants( - schemaAttributes: Seq[Attribute]): VariantInRelation = { - val variants = new VariantInRelation - val defaultValues = ResolveDefaultColumns.existenceDefaultValues(StructType( - schemaAttributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata)))) + // Get what was actually pushed + val pushedVariantAccess = scan.pushedVariantAccess() + if (pushedVariantAccess.isEmpty) return originalPlan - for ((a, defaultValue) <- schemaAttributes.zip(defaultValues)) { - variants.addVariantFields(a.exprId, a.dataType, defaultValue, Nil) - } - - variants - } - - /** - * Build attribute map with rewritten variant types. - */ - private def buildAttributeMap( - schemaAttributes: Seq[Attribute], - variants: VariantInRelation): Map[ExprId, AttributeReference] = { - schemaAttributes.map { a => - if (variants.mapping.get(a.exprId).exists(_.nonEmpty)) { + // Build new attribute mapping based on pushed variant access + val pushedColumnNames = pushedVariantAccess.map(_.columnName()).toSet + val attributeMap = schemaAttributes.map { a => + if (pushedColumnNames.contains(a.name) && variants.mapping.get(a.exprId).exists(_.nonEmpty)) { val newType = variants.rewriteType(a.exprId, a.dataType, Nil) val newAttr = AttributeReference(a.name, newType, a.nullable, a.metadata)( qualifier = a.qualifier) (a.exprId, newAttr) } else { - // `relation.resolve` actually returns `Seq[AttributeReference]`, although the return type - // is `Seq[Attribute]`. - (a.exprId, a.asInstanceOf[AttributeReference]) + (a.exprId, a) } }.toMap + + val newOutput = scanRelation.output.map(a => attributeMap.getOrElse(a.exprId, a)) + + // The scan implementation should have updated its readSchema() based on the pushed info + // We just need to create a new scan relation with the updated output + val newScanRelation = scanRelation.copy( + output = newOutput + ) + + buildFilterAndProject(newScanRelation, projectList, filters, variants, attributeMap) } /** @@ -398,7 +435,6 @@ object PushVariantIntoScan extends Rule[LogicalPlan] { filters: Seq[Expression], variants: VariantInRelation, attributeMap: Map[ExprId, AttributeReference]): LogicalPlan = { - val withFilter = if (filters.nonEmpty) { Filter(filters.map(variants.rewriteExpr(_, attributeMap)).reduce(And), relation) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala index ec41e746469d..d347cb04f0bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala @@ -25,7 +25,7 @@ import org.apache.parquet.hadoop.ParquetInputFormat import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.connector.expressions.aggregate.Aggregation -import org.apache.spark.sql.connector.read.PartitionReaderFactory +import org.apache.spark.sql.connector.read.{PartitionReaderFactory, SupportsPushDownVariants, VariantAccessInfo} import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, PartitioningAwareFileIndex} import org.apache.spark.sql.execution.datasources.parquet.{ParquetOptions, ParquetReadSupport, ParquetWriteSupport} import org.apache.spark.sql.execution.datasources.v2.FileScan @@ -47,21 +47,78 @@ case class ParquetScan( options: CaseInsensitiveStringMap, 
pushedAggregate: Option[Aggregation] = None, partitionFilters: Seq[Expression] = Seq.empty, - dataFilters: Seq[Expression] = Seq.empty) extends FileScan { + dataFilters: Seq[Expression] = Seq.empty, + pushedVariantAccessInfo: Array[VariantAccessInfo] = Array.empty) extends FileScan + with SupportsPushDownVariants { override def isSplitable(path: Path): Boolean = { // If aggregate is pushed down, only the file footer will be read once, // so file should not be split across multiple tasks. pushedAggregate.isEmpty } + // Build transformed schema if variant pushdown is active + private def effectiveReadDataSchema: StructType = { + if (_pushedVariantAccess.isEmpty) { + readDataSchema + } else { + // Build a mapping from column name to extracted schema + val variantSchemaMap = _pushedVariantAccess.map(info => + info.columnName() -> info.extractedSchema()).toMap + + // Transform the read data schema by replacing variant columns with their extracted schemas + StructType(readDataSchema.map { field => + variantSchemaMap.get(field.name) match { + case Some(extractedSchema) => field.copy(dataType = extractedSchema) + case None => field + } + }) + } + } + override def readSchema(): StructType = { // If aggregate is pushed down, schema has already been pruned in `ParquetScanBuilder` // and no need to call super.readSchema() - if (pushedAggregate.nonEmpty) readDataSchema else super.readSchema() + if (pushedAggregate.nonEmpty) { + effectiveReadDataSchema + } else { + // super.readSchema() combines readDataSchema + readPartitionSchema + // Apply variant transformation if variant pushdown is active + val baseSchema = super.readSchema() + if (_pushedVariantAccess.isEmpty) { + baseSchema + } else { + val variantSchemaMap = _pushedVariantAccess.map(info => + info.columnName() -> info.extractedSchema()).toMap + StructType(baseSchema.map { field => + variantSchemaMap.get(field.name) match { + case Some(extractedSchema) => field.copy(dataType = extractedSchema) + case None => field + } + }) + } + } + } + + // SupportsPushDownVariants API implementation + private var _pushedVariantAccess: Array[VariantAccessInfo] = pushedVariantAccessInfo + + override def pushVariantAccess(variantAccessInfo: Array[VariantAccessInfo]): Boolean = { + // Parquet supports variant pushdown for all variant accesses + if (variantAccessInfo.nonEmpty) { + _pushedVariantAccess = variantAccessInfo + true + } else { + false + } + } + + override def pushedVariantAccess(): Array[VariantAccessInfo] = { + _pushedVariantAccess } override def createReaderFactory(): PartitionReaderFactory = { - val readDataSchemaAsJson = readDataSchema.json + val effectiveSchema = effectiveReadDataSchema + val readDataSchemaAsJson = effectiveSchema.json hadoopConf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[ParquetReadSupport].getName) hadoopConf.set( ParquetReadSupport.SPARK_ROW_REQUESTED_SCHEMA, @@ -99,7 +156,7 @@ case class ParquetScan( conf, broadcastedConf, dataSchema, - readDataSchema, + effectiveSchema, readPartitionSchema, pushedFilters, pushedAggregate, @@ -113,8 +170,12 @@ case class ParquetScan( } else { pushedAggregate.isEmpty && p.pushedAggregate.isEmpty } + val pushedVariantEqual = + java.util.Arrays.equals(_pushedVariantAccess.asInstanceOf[Array[Object]], + p._pushedVariantAccess.asInstanceOf[Array[Object]]) super.equals(p) && dataSchema == p.dataSchema && options == p.options && - equivalentFilters(pushedFilters, p.pushedFilters) && pushedDownAggEqual + equivalentFilters(pushedFilters, p.pushedFilters) && pushedDownAggEqual && + 
pushedVariantEqual case _ => false } @@ -128,8 +189,15 @@ case class ParquetScan( } override def getMetaData(): Map[String, String] = { + val variantAccessStr = if (_pushedVariantAccess.nonEmpty) { + _pushedVariantAccess.map(info => + s"${info.columnName()}->${info.extractedSchema()}").mkString("[", ", ", "]") + } else { + "[]" + } super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters.toImmutableArraySeq)) ++ Map("PushedAggregation" -> pushedAggregationsStr) ++ - Map("PushedGroupBy" -> pushedGroupByStr) + Map("PushedGroupBy" -> pushedGroupByStr) ++ + Map("PushedVariantAccess" -> variantAccessStr) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScanSuite.scala index 08a9a306eec3..41b78881b788 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScanSuite.scala @@ -18,27 +18,29 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.SparkConf +import org.apache.spark.sql.QueryTest import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.variant._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ -class PushVariantIntoScanSuite extends SharedSparkSession { +trait PushVariantIntoScanSuiteBase extends SharedSparkSession { override def sparkConf: SparkConf = super.sparkConf.set(SQLConf.PUSH_VARIANT_INTO_SCAN.key, "true") - private def localTimeZone = spark.sessionState.conf.sessionLocalTimeZone + protected def localTimeZone = spark.sessionState.conf.sessionLocalTimeZone // Return a `StructField` with the expected `VariantMetadata`. - private def field(ordinal: Int, dataType: DataType, path: String, + protected def field(ordinal: Int, dataType: DataType, path: String, failOnError: Boolean = true, timeZone: String = localTimeZone): StructField = StructField(ordinal.toString, dataType, metadata = VariantMetadata(path, failOnError, timeZone).toMetadata) // Validate an `Alias` expression has the expected name and child. 
- private def checkAlias(expr: Expression, expectedName: String, expected: Expression): Unit = { + protected def checkAlias(expr: Expression, expectedName: String, expected: Expression): Unit = { expr match { case Alias(child, name) => assert(name == expectedName) @@ -47,9 +49,20 @@ class PushVariantIntoScanSuite extends SharedSparkSession { } } +} + +// V1 DataSource tests with parameterized reader type +abstract class PushVariantIntoScanV1SuiteBase extends PushVariantIntoScanSuiteBase { + protected def vectorizedReaderEnabled: Boolean + protected def readerName: String + + override def sparkConf: SparkConf = + super.sparkConf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, + vectorizedReaderEnabled.toString) + private def testOnFormats(fn: String => Unit): Unit = { for (format <- Seq("PARQUET")) { - test("test - " + format) { + test(s"test - $format ($readerName)") { withTable("T") { fn(format) } @@ -195,7 +208,7 @@ class PushVariantIntoScanSuite extends SharedSparkSession { } } - test("No push down for JSON") { + test(s"No push down for JSON ($readerName)") { withTable("T") { sql("create table T (v variant) using JSON") sql("select variant_get(v, '$.a') from T").queryExecution.optimizedPlan match { @@ -207,3 +220,399 @@ class PushVariantIntoScanSuite extends SharedSparkSession { } } } + +// V1 DataSource tests - Row-based reader +class PushVariantIntoScanSuite extends PushVariantIntoScanV1SuiteBase { + override protected def vectorizedReaderEnabled: Boolean = false + override protected def readerName: String = "row-based reader" +} + +// V1 DataSource tests - Vectorized reader +class PushVariantIntoScanVectorizedSuite extends PushVariantIntoScanV1SuiteBase { + override protected def vectorizedReaderEnabled: Boolean = true + override protected def readerName: String = "vectorized reader" +} + +// V2 DataSource tests with parameterized reader type +abstract class PushVariantIntoScanV2SuiteBase extends QueryTest with PushVariantIntoScanSuiteBase { + protected def vectorizedReaderEnabled: Boolean + protected def readerName: String + + override def sparkConf: SparkConf = + super.sparkConf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, + vectorizedReaderEnabled.toString) + + test(s"V2 test - basic variant field extraction ($readerName)") { + withTempPath { dir => + val path = dir.getCanonicalPath + + // Use V1 to write Parquet files with actual variant data + withTable("temp_v1") { + sql(s"create table temp_v1 (v variant, s string) using PARQUET location '$path'") + sql("insert into temp_v1 values " + + "(parse_json('{\"a\": 1, \"b\": 2.5}'), 'test1'), " + + "(parse_json('{\"a\": 2, \"b\": 3.5}'), 'test2')") + } + + // Use V2 to read back + withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") { + val df = spark.read.parquet(path) + df.createOrReplaceTempView("T_V2") + + val query = "select variant_get(v, '$.a', 'int') as a, v, " + + "cast(v as struct) as v_cast from T_V2" + + val expectedRows = withSQLConf(SQLConf.PUSH_VARIANT_INTO_SCAN.key -> "false") { + sql(query).collect() + } + + // Validate results are the same with and without pushdown + checkAnswer(sql(query), expectedRows) + + // Test the variant pushdown + sql(query).queryExecution.optimizedPlan match { + case Project(projectList, scanRelation: DataSourceV2ScanRelation) => + val output = scanRelation.output + val v = output(0) + // Check that variant pushdown happened - v should be a struct, not variant + assert(v.dataType.isInstanceOf[StructType], + s"Expected v to be struct type after pushdown, but got ${v.dataType}") + val 
vStruct = v.dataType.asInstanceOf[StructType] + assert(vStruct.fields.length == 3, + s"Expected 3 fields in struct, got ${vStruct.fields.length}") + assert(vStruct.fields(0).dataType == IntegerType) + assert(vStruct.fields(1).dataType == VariantType) + assert(vStruct.fields(2).dataType.isInstanceOf[StructType]) + case other => + fail(s"Expected V2 scan relation with variant pushdown, " + + s"got ${other.getClass.getName}") + } + } + } + } + + test(s"V2 test - placeholder field with filter ($readerName)") { + withTempPath { dir => + val path = dir.getCanonicalPath + + withTable("temp_v1") { + sql(s"create table temp_v1 (v variant) using PARQUET location '$path'") + sql("insert into temp_v1 values (parse_json('{\"a\": 1}'))") + } + + withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") { + val df = spark.read.parquet(path) + df.createOrReplaceTempView("T_V2") + + val query = "select 1 from T_V2 where isnotnull(v)" + + val expectedRows = withSQLConf(SQLConf.PUSH_VARIANT_INTO_SCAN.key -> "false") { + sql(query).collect() + } + + // Validate results are the same with and without pushdown + checkAnswer(sql(query), expectedRows) + + sql(query) + .queryExecution.optimizedPlan match { + case Project(_, Filter(condition, scanRelation: DataSourceV2ScanRelation)) => + val output = scanRelation.output + val v = output(0) + assert(condition == IsNotNull(v)) + assert(v.dataType == StructType(Array( + field(0, BooleanType, "$.__placeholder_field__", failOnError = false, + timeZone = "UTC")))) + case other => fail(s"Expected filtered V2 scan relation, got ${other.getClass.getName}") + } + } + } + } + + test(s"V2 test - arithmetic and try_variant_get ($readerName)") { + withTempPath { dir => + val path = dir.getCanonicalPath + + withTable("temp_v1") { + sql(s"create table temp_v1 (v variant) using PARQUET location '$path'") + sql("insert into temp_v1 values " + + "(parse_json('{\"a\": 1, \"b\": \"hello\"}')), " + + "(parse_json('{\"a\": 2, \"b\": \"world\"}'))") + } + + withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") { + val df = spark.read.parquet(path) + df.createOrReplaceTempView("T_V2") + + val query = "select variant_get(v, '$.a', 'int') + 1 as a, " + + "try_variant_get(v, '$.b', 'string') as b from T_V2 " + + "where variant_get(v, '$.a', 'int') = 1" + + val expectedRows = withSQLConf(SQLConf.PUSH_VARIANT_INTO_SCAN.key -> "false") { + sql(query).collect() + } + + // Validate results are the same with and without pushdown + checkAnswer(sql(query), expectedRows) + + sql(query).queryExecution.optimizedPlan match { + case Project(_, Filter(_, scanRelation: DataSourceV2ScanRelation)) => + val output = scanRelation.output + val v = output(0) + assert(v.dataType.isInstanceOf[StructType], + s"Expected v to be struct type, but got ${v.dataType}") + val vStruct = v.dataType.asInstanceOf[StructType] + assert(vStruct.fields.length == 2, s"Expected 2 fields in struct") + assert(vStruct.fields(0).dataType == IntegerType) + assert(vStruct.fields(1).dataType == StringType) + case other => fail(s"Expected filtered V2 scan relation, got ${other.getClass.getName}") + } + } + } + } + + test(s"V2 test - nested variant in struct ($readerName)") { + withTempPath { dir => + val path = dir.getCanonicalPath + + withTable("temp_v1") { + sql(s"create table temp_v1 (vs struct) " + + s"using PARQUET location '$path'") + sql("insert into temp_v1 select named_struct('v1', parse_json('{\"a\": 1, \"b\": 2}'), " + + "'v2', parse_json('{\"a\": 3}'), 'i', 100)") + } + + withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") { + val df = 
spark.read.parquet(path) + df.createOrReplaceTempView("T_V2") + + val query = "select variant_get(vs.v1, '$.a', 'int') as a, " + + "variant_get(vs.v1, '$.b', 'int') as b, " + + "variant_get(vs.v2, '$.a', 'int') as a2, vs.i from T_V2" + + val expectedRows = withSQLConf(SQLConf.PUSH_VARIANT_INTO_SCAN.key -> "false") { + sql(query).collect() + } + + // Validate results are the same with and without pushdown + checkAnswer(sql(query), expectedRows) + + sql(query).queryExecution.optimizedPlan match { + case Project(_, scanRelation: DataSourceV2ScanRelation) => + val output = scanRelation.output + val vs = output(0) + assert(vs.dataType.isInstanceOf[StructType]) + val vsStruct = vs.dataType.asInstanceOf[StructType] + // Should have 3 fields: v1 (struct), v2 (struct), i (int) + assert(vsStruct.fields.length == 3, s"Expected 3 fields in vs") + case other => fail(s"Expected V2 scan relation, got ${other.getClass.getName}") + } + } + } + } + + test(s"V2 test - no pushdown when struct is used ($readerName)") { + withTempPath { dir => + val path = dir.getCanonicalPath + + withTable("temp_v1") { + sql(s"create table temp_v1 (vs struct) " + + s"using PARQUET location '$path'") + sql("insert into temp_v1 select named_struct('v1', parse_json('{\"a\": 1}'), " + + "'v2', parse_json('{\"a\": 2}'), 'i', 100)") + } + + withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") { + val df = spark.read.parquet(path) + df.createOrReplaceTempView("T_V2") + + val query = "select vs, variant_get(vs.v1, '$.a', 'int') as a from T_V2" + + val expectedRows = withSQLConf(SQLConf.PUSH_VARIANT_INTO_SCAN.key -> "false") { + sql(query).collect() + } + + // Validate results are the same with and without pushdown + checkAnswer(sql(query), expectedRows) + + sql(query).queryExecution.optimizedPlan match { + case Project(_, scanRelation: DataSourceV2ScanRelation) => + val output = scanRelation.output + val vs = output(0) + assert(vs.dataType.isInstanceOf[StructType]) + val vsStruct = vs.dataType.asInstanceOf[StructType] + // When struct is used directly, variants inside should NOT be pushed down + val v1Field = vsStruct.fields.find(_.name == "v1").get + assert(v1Field.dataType == VariantType, + s"Expected v1 to remain VariantType, but got ${v1Field.dataType}") + case other => fail(s"Expected V2 scan relation, got ${other.getClass.getName}") + } + } + } + } + + test(s"V2 test - no pushdown for variant in array ($readerName)") { + withTempPath { dir => + val path = dir.getCanonicalPath + + withTable("temp_v1") { + sql(s"create table temp_v1 (va array) using PARQUET location '$path'") + sql("insert into temp_v1 select array(parse_json('{\"a\": 1}'))") + } + + withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") { + val df = spark.read.parquet(path) + df.createOrReplaceTempView("T_V2") + + val query = "select variant_get(va[0], '$.a', 'int') as a from T_V2" + + val expectedRows = withSQLConf(SQLConf.PUSH_VARIANT_INTO_SCAN.key -> "false") { + sql(query).collect() + } + + // Validate results are the same with and without pushdown + checkAnswer(sql(query), expectedRows) + + sql(query).queryExecution.optimizedPlan match { + case Project(_, scanRelation: DataSourceV2ScanRelation) => + val output = scanRelation.output + val va = output(0) + assert(va.dataType.isInstanceOf[ArrayType]) + val arrayType = va.dataType.asInstanceOf[ArrayType] + assert(arrayType.elementType == VariantType, + s"Expected array element to be VariantType, but got ${arrayType.elementType}") + case other => fail(s"Expected V2 scan relation, got ${other.getClass.getName}") + } + } + 
} + } + + test(s"V2 test - no pushdown for variant with default ($readerName)") { + withTempPath { dir => + val path = dir.getCanonicalPath + + withTable("temp_v1") { + sql(s"create table temp_v1 (vd variant default parse_json('1')) " + + s"using PARQUET location '$path'") + sql("insert into temp_v1 select parse_json('{\"a\": 1}')") + } + + withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") { + val df = spark.read.parquet(path) + df.createOrReplaceTempView("T_V2") + + val query = "select variant_get(vd, '$.a', 'int') as a from T_V2" + + val expectedRows = withSQLConf(SQLConf.PUSH_VARIANT_INTO_SCAN.key -> "false") { + sql(query).collect() + } + + // Validate results are the same with and without pushdown + checkAnswer(sql(query), expectedRows) + + sql(query) + .queryExecution.optimizedPlan match { + case Project(_, scanRelation: DataSourceV2ScanRelation) => + val output = scanRelation.output + val vd = output(0) + assert(vd.dataType == VariantType, + s"Expected vd to remain VariantType, but got ${vd.dataType}") + case other => fail(s"Expected V2 scan relation, got ${other.getClass.getName}") + } + } + } + } + + test(s"V2 test - no pushdown for non-literal path ($readerName)") { + withTempPath { dir => + val path = dir.getCanonicalPath + + withTable("temp_v1") { + sql(s"create table temp_v1 (v variant, s string) using PARQUET location '$path'") + sql("insert into temp_v1 values (parse_json('{\"a\": 1, \"b\": 2}'), '$.a')") + } + + withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") { + val df = spark.read.parquet(path) + df.createOrReplaceTempView("T_V2") + + val query = "select variant_get(v, '$.a', 'int') as a, " + + "variant_get(v, s, 'int') as v2, v, " + + "cast(v as struct) as v3 from T_V2" + + val expectedRows = withSQLConf(SQLConf.PUSH_VARIANT_INTO_SCAN.key -> "false") { + sql(query).collect() + } + + // Validate results are the same with and without pushdown + checkAnswer(sql(query), expectedRows) + + sql(query).queryExecution.optimizedPlan match { + case Project(_, scanRelation: DataSourceV2ScanRelation) => + val output = scanRelation.output + val v = output(0) + assert(v.dataType.isInstanceOf[StructType]) + val vStruct = v.dataType.asInstanceOf[StructType] + // Should have 3 fields: literal path extraction, full variant, cast + assert(vStruct.fields.length == 3, + s"Expected 3 fields in struct, got ${vStruct.fields.length}") + assert(vStruct.fields(0).dataType == IntegerType) + assert(vStruct.fields(1).dataType == VariantType) + assert(vStruct.fields(2).dataType.isInstanceOf[StructType]) + case other => fail(s"Expected V2 scan relation, got ${other.getClass.getName}") + } + } + } + } + + test(s"V2 No push down for JSON ($readerName)") { + withTempPath { dir => + val path = dir.getCanonicalPath + + // Use V1 to write JSON files with variant data + withTable("temp_v1_json") { + sql(s"create table temp_v1_json (v variant) using JSON location '$path'") + sql("insert into temp_v1_json values (parse_json('{\"a\": 1}'))") + } + + // Use V2 to read back + withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") { + val df = spark.read.format("json").load(path) + df.createOrReplaceTempView("T_V2_JSON") + + val query = "select v from T_V2_JSON" + + val expectedRows = withSQLConf(SQLConf.PUSH_VARIANT_INTO_SCAN.key -> "false") { + sql(query).collect() + } + + // Validate results are the same with and without pushdown + checkAnswer(sql(query), expectedRows) + + // JSON V2 reader performs schema inference - it won't preserve variant type + // It will infer the schema as a typed struct instead + 
sql(query).queryExecution.optimizedPlan match { + case scanRelation: DataSourceV2ScanRelation => + val output = scanRelation.output + // JSON format with V2 infers schema, so variant becomes a typed struct + assert(output(0).dataType != VariantType, + s"Expected non-variant type for JSON V2 due to schema inference, " + + s"got ${output(0).dataType}") + case other => + fail(s"Expected V2 scan relation, got ${other.getClass.getName}") + } + } + } + } +} + +// V2 DataSource tests - Row-based reader +class PushVariantIntoScanV2Suite extends PushVariantIntoScanV2SuiteBase { + override protected def vectorizedReaderEnabled: Boolean = false + override protected def readerName: String = "row-based reader" +} + +// V2 DataSource tests - Vectorized reader +class PushVariantIntoScanV2VectorizedSuite extends PushVariantIntoScanV2SuiteBase { + override protected def vectorizedReaderEnabled: Boolean = true + override protected def readerName: String = "vectorized reader" +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/VariantV2ReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/VariantV2ReadSuite.scala deleted file mode 100644 index a6521dfe76da..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/VariantV2ReadSuite.scala +++ /dev/null @@ -1,148 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution.datasources.v2 - -import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.execution.datasources.VariantMetadata -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types.{IntegerType, StringType, StructType, VariantType} - -class VariantV2ReadSuite extends QueryTest with SharedSparkSession { - - private val testCatalogClass = "org.apache.spark.sql.connector.catalog.InMemoryTableCatalog" - - private def withV2Catalog(f: => Unit): Unit = { - withSQLConf( - SQLConf.DEFAULT_CATALOG.key -> "testcat", - s"spark.sql.catalog.testcat" -> testCatalogClass, - SQLConf.USE_V1_SOURCE_LIST.key -> "", - SQLConf.PUSH_VARIANT_INTO_SCAN.key -> "true", - SQLConf.VARIANT_ALLOW_READING_SHREDDED.key -> "true") { - f - } - } - - test("DSV2: push variant_get fields") { - withV2Catalog { - sql("DROP TABLE IF EXISTS testcat.ns.users") - sql( - """CREATE TABLE testcat.ns.users ( - | id bigint, - | name string, - | v variant, - | vd variant default parse_json('1') - |) USING parquet""".stripMargin) - - val out = sql( - """ - |SELECT - | id, - | variant_get(v, '$.username', 'string') as username, - | variant_get(v, '$.age', 'int') as age - |FROM testcat.ns.users - |WHERE variant_get(v, '$.status', 'string') = 'active' - |""".stripMargin) - - checkAnswer(out, Seq.empty) - - // Verify variant column rewrite - val optimized = out.queryExecution.optimizedPlan - val relOutput = optimized.collectFirst { - case s: DataSourceV2ScanRelation => s.output - }.getOrElse(fail("Expected DSv2 relation in optimized plan")) - - val vAttr = relOutput.find(_.name == "v").getOrElse(fail("Missing 'v' column")) - vAttr.dataType match { - case s: StructType => - assert(s.fields.length == 3, - s"Expected 3 fields (username, age, status), got ${s.fields.length}") - assert(s.fields.forall(_.metadata.contains(VariantMetadata.METADATA_KEY)), - "All fields should have VariantMetadata") - - val paths = s.fields.map(f => VariantMetadata.fromMetadata(f.metadata).path).toSet - assert(paths == Set("$.username", "$.age", "$.status"), - s"Expected username, age, status paths, got: $paths") - - val fieldTypes = s.fields.map(_.dataType).toSet - assert(fieldTypes.contains(StringType), "Expected StringType for string fields") - assert(fieldTypes.contains(IntegerType), "Expected IntegerType for age") - - case other => - fail(s"Expected StructType for 'v', got: $other") - } - - // Verify variant with default value is NOT rewritten - relOutput.find(_.name == "vd").foreach { vdAttr => - assert(vdAttr.dataType == VariantType, - "Variant column with default value should not be rewritten") - } - } - } - - test("DSV2: nested column pruning for variant struct") { - withV2Catalog { - sql("DROP TABLE IF EXISTS testcat.ns.users2") - sql( - """CREATE TABLE testcat.ns.users2 ( - | id bigint, - | name string, - | v variant - |) USING parquet""".stripMargin) - - val out = sql( - """ - |SELECT id, variant_get(v, '$.username', 'string') as username - |FROM testcat.ns.users2 - |""".stripMargin) - - checkAnswer(out, Seq.empty) - - val scan = out.queryExecution.executedPlan.collectFirst { - case b: BatchScanExec => b.scan - }.getOrElse(fail("Expected BatchScanExec in physical plan")) - - val readSchema = scan.readSchema() - - // Verify 'v' field exists and is a struct - val vField = readSchema.fields.find(_.name == "v").getOrElse( - fail("Expected 'v' field in read schema") - ) - - vField.dataType match { - case s: StructType => - 
assert(s.fields.length == 1, - "Expected only 1 field ($.username) in pruned schema, got " + s.fields.length + ": " + - s.fields.map(f => VariantMetadata.fromMetadata(f.metadata).path).mkString(", ")) - - val field = s.fields(0) - assert(field.metadata.contains(VariantMetadata.METADATA_KEY), - "Field should have VariantMetadata") - - val metadata = VariantMetadata.fromMetadata(field.metadata) - assert(metadata.path == "$.username", - "Expected path '$.username', got '" + metadata.path + "'") - assert(field.dataType == StringType, - s"Expected StringType, got ${field.dataType}") - - case other => - fail(s"Expected StructType for 'v' after rewrite and pruning, got: $other") - } - } - } -}