From 1be5fe4ac7e49329b8df5322909b6510e58ceb03 Mon Sep 17 00:00:00 2001
From: Szehon Ho
Date: Tue, 15 Jul 2025 17:52:04 -0700
Subject: [PATCH 1/4] [MINOR][SQL] Add unit test for V2 write schema evolution

---
 .../sql/connector/catalog/InMemoryBaseTable.scala  |  3 ++-
 .../sql/connector/DataSourceV2DataFrameSuite.scala | 14 ++++++++++++++
 2 files changed, 16 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
index d032829d8645f..f7a129a326e43 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
@@ -337,7 +338,8 @@ abstract class InMemoryBaseTable(
     TableCapability.STREAMING_WRITE,
     TableCapability.OVERWRITE_BY_FILTER,
     TableCapability.OVERWRITE_DYNAMIC,
-    TableCapability.TRUNCATE)
+    TableCapability.TRUNCATE,
+    TableCapability.ACCEPT_ANY_SCHEMA)
 
   override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
     new InMemoryScanBuilder(schema, options)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala
index 342eefa1a6f63..c83de583cbc52 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala
@@ -846,6 +846,20 @@ class DataSourceV2DataFrameSuite
     }
   }
 
+  test("insert with schema evolution") {
+    val tableName = "testcat.ns1.ns2.tbl"
+    withTable(tableName) {
+      val columns = Array(
+        Column.create("c1", IntegerType)
+      )
+      val tableInfo = new TableInfo.Builder().withColumns(columns).build()
+      catalog("testcat").createTable(Identifier.of(Array("ns1", "ns2"), "tbl"), tableInfo)
+      Seq((1, "a"), (2, "b"), (3, "c")).toDF("c1", "c2")
+        .writeTo(tableName)
+        .append()
+    }
+  }
+
   private def executeAndKeepPhysicalPlan[T <: SparkPlan](func: => Unit): T = {
     val qe = withQueryExecutionsCaptured(spark) {
       func

From 548fb1f76321a2ccdaa55a593e3a9bc75698ea56 Mon Sep 17 00:00:00 2001
From: Szehon Ho
Date: Wed, 16 Jul 2025 16:08:02 -0700
Subject: [PATCH 2/4] Test fix

---
 .../connector/catalog/InMemoryBaseTable.scala | 43 ++++++++++++++++---
 .../DataSourceV2DataFrameSuite.scala          | 28 ++++++++----
 2 files changed, 57 insertions(+), 14 deletions(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
index f7a129a326e43..39afae2b10cb7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
@@ -23,6 +23,7 @@ import java.util
 import java.util.OptionalLong
 
 import scala.collection.mutable
+import scala.jdk.CollectionConverters._
 
 import com.google.common.base.Objects
 
@@ -50,7 +51,7 @@ import org.apache.spark.util.ArrayImplicits._
  */
 abstract class InMemoryBaseTable(
     val name: String,
-    override val columns: Array[Column],
+    val initialColumns: Array[Column],
     override val partitioning: Array[Transform],
     override val properties: util.Map[String, String],
     override val constraints: Array[Constraint] = Array.empty,
@@ -68,6 +69,10 @@ abstract class InMemoryBaseTable(
   // Stores the table version validated during the last `ALTER TABLE ... ADD CONSTRAINT` operation.
   private var validatedTableVersion: String = null
 
+  private var tableColumns: Array[Column] = initialColumns
+
+  override def columns(): Array[Column] = tableColumns
+
   override def currentVersion(): String = currentTableVersion.toString
 
   def setCurrentVersion(version: String): Unit = {
@@ -114,7 +119,7 @@ abstract class InMemoryBaseTable(
     }
   }
 
-  override val schema: StructType = CatalogV2Util.v2ColumnsToStructType(columns)
+  override def schema(): StructType = CatalogV2Util.v2ColumnsToStructType(columns())
 
   // purposely exposes a metadata column that conflicts with a data column in some tests
   override val metadataColumns: Array[MetadataColumn] = Array(IndexColumn, PartitionKeyColumn)
@@ -127,6 +132,8 @@ abstract class InMemoryBaseTable(
   private val allowUnsupportedTransforms =
     properties.getOrDefault("allow-unsupported-transforms", "false").toBoolean
 
+  private val acceptAnySchema = properties.getOrDefault("accept-any-schema", "false").toBoolean
+
   partitioning.foreach {
     case _: IdentityTransform =>
     case _: YearsTransform =>
@@ -257,9 +264,9 @@ abstract class InMemoryBaseTable(
     val newRows = new BufferedRows(to)
     rows.rows.foreach { r =>
       val newRow = new GenericInternalRow(r.numFields)
-      for (i <- 0 until r.numFields) newRow.update(i, r.get(i, schema(i).dataType))
+      for (i <- 0 until r.numFields) newRow.update(i, r.get(i, schema()(i).dataType))
       for (i <- 0 until partitionSchema.length) {
-        val j = schema.fieldIndex(partitionSchema(i).name)
+        val j = schema().fieldIndex(partitionSchema(i).name)
         newRow.update(j, to(i))
       }
       newRows.withRow(newRow)
@@ -331,7 +338,7 @@ abstract class InMemoryBaseTable(
     this
   }
 
-  override def capabilities: util.Set[TableCapability] = util.EnumSet.of(
+  def baseCapabilities: Set[TableCapability] = Set(
     TableCapability.BATCH_READ,
     TableCapability.BATCH_WRITE,
     TableCapability.STREAMING_WRITE,
@@ -340,6 +347,14 @@ abstract class InMemoryBaseTable(
     TableCapability.TRUNCATE,
     TableCapability.ACCEPT_ANY_SCHEMA)
 
+  override def capabilities(): util.Set[TableCapability] = {
+    if (acceptAnySchema) {
+      (baseCapabilities ++ Set(TableCapability.ACCEPT_ANY_SCHEMA)).asJava
+    } else {
+      baseCapabilities.asJava
+    }
+  }
+
   override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
     new InMemoryScanBuilder(schema, options)
   }
@@ -558,7 +573,12 @@ abstract class InMemoryBaseTable(
       advisoryPartitionSize.getOrElse(0)
     }
 
-    override def toBatch: BatchWrite = writer
+    override def toBatch: BatchWrite = {
+      val newSchema = info.schema()
+      tableColumns = CatalogV2Util.structTypeToV2Columns(
+        mergeSchema(CatalogV2Util.v2ColumnsToStructType(columns()), newSchema))
+      writer
+    }
 
     override def toStreaming: StreamingWrite = streamingWriter match {
      case exc: StreamingNotSupportedOperation => exc.throwsException()
@@ -572,6 +592,17 @@ abstract class InMemoryBaseTable(
     override def reportDriverMetrics(): Array[CustomTaskMetric] = {
       Array(new InMemoryCustomDriverTaskMetric(rows.size))
     }
+
+    def mergeSchema(oldType: StructType, newType: StructType): StructType = {
+      val (oldFields, newFields) = (oldType.fields, newType.fields)
+
+      // this does not override the old field with a new field of the same name for now
+      val nameToFieldMap = oldFields.map(f => f.name -> f).toMap
+      val remainingNewFields = newFields.filterNot(f => nameToFieldMap.contains(f.name))
+
+      // Create the merged struct with the new fields appended at the end of the struct.
+      StructType(oldFields ++ remainingNewFields)
+    }
   }
 }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala
index c83de583cbc52..4996268ea7732 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.connector
 
 import java.util.Collections
 
+import scala.jdk.CollectionConverters._
+
 import org.apache.spark.{SparkConf, SparkException}
 import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode}
 import org.apache.spark.sql.QueryTest.withQueryExecutionsCaptured
@@ -848,15 +850,25 @@ class DataSourceV2DataFrameSuite
 
   test("insert with schema evolution") {
     val tableName = "testcat.ns1.ns2.tbl"
+    val ident = Identifier.of(Array("ns1", "ns2"), "tbl")
     withTable(tableName) {
-      val columns = Array(
-        Column.create("c1", IntegerType)
-      )
-      val tableInfo = new TableInfo.Builder().withColumns(columns).build()
-      catalog("testcat").createTable(Identifier.of(Array("ns1", "ns2"), "tbl"), tableInfo)
-      Seq((1, "a"), (2, "b"), (3, "c")).toDF("c1", "c2")
-        .writeTo(tableName)
-        .append()
+      val tableInfo = new TableInfo.Builder().
+        withColumns(
+          Array(Column.create("c1", IntegerType)))
+        .withProperties(
+          Map("accept-any-schema" -> "true").asJava)
+        .build()
+      catalog("testcat").createTable(ident, tableInfo)
+
+      val data = Seq((1, "a"), (2, "b"), (3, "c")).toDF("c1", "c2")
+      data.writeTo(tableName).append()
+
+      checkAnswer(spark.table(tableName), data)
+      val cols = catalog("testcat").loadTable(ident).columns()
+      val expectedCols = Array(
+        Column.create("c1", IntegerType),
+        Column.create("c2", StringType))
+      assert(cols === expectedCols)
     }
   }

From 59a6d18abfa1cd48141b94a15497f0bc5a7db032 Mon Sep 17 00:00:00 2001
From: Szehon Ho
Date: Wed, 16 Jul 2025 18:49:47 -0700
Subject: [PATCH 3/4] Fix test 2

---
 .../apache/spark/sql/connector/catalog/InMemoryBaseTable.scala | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
index 39afae2b10cb7..bdfab0c6b439d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
@@ -344,8 +344,7 @@ abstract class InMemoryBaseTable(
     TableCapability.STREAMING_WRITE,
     TableCapability.OVERWRITE_BY_FILTER,
     TableCapability.OVERWRITE_DYNAMIC,
-    TableCapability.TRUNCATE,
-    TableCapability.ACCEPT_ANY_SCHEMA)
+    TableCapability.TRUNCATE)
 
   override def capabilities(): util.Set[TableCapability] = {
     if (acceptAnySchema) {

From 393862c6d76234a518b14f24b85c6d178bf474b8 Mon Sep 17 00:00:00 2001
From: Szehon Ho
Date: Thu, 17 Jul 2025 17:46:49 -0700
Subject: [PATCH 4/4] review comments

---
 .../connector/catalog/InMemoryBaseTable.scala | 14 ++++-
 .../DataSourceV2DataFrameSuite.scala          | 53 ++++++++++++-------
 2 files changed, 46 insertions(+), 21 deletions(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
index bdfab0c6b439d..4efb08a67829e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
@@ -29,7 +29,7 @@ import com.google.common.base.Objects
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, JoinedRow, MetadataStructFieldWithLogicalName}
-import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateTimeUtils, ResolveDefaultColumns}
+import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils, DateTimeUtils, ResolveDefaultColumns}
 import org.apache.spark.sql.connector.catalog.constraints.Constraint
 import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
 import org.apache.spark.sql.connector.expressions._
@@ -39,6 +39,7 @@ import org.apache.spark.sql.connector.read.colstats.{ColumnStatistics, Histogram
 import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, Partitioning, UnknownPartitioning}
 import org.apache.spark.sql.connector.write._
 import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.connector.SupportsStreamingUpdateAsAppend
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
@@ -596,12 +597,21 @@ abstract class InMemoryBaseTable(
       val (oldFields, newFields) = (oldType.fields, newType.fields)
 
       // this does not override the old field with a new field of the same name for now
-      val nameToFieldMap = oldFields.map(f => f.name -> f).toMap
+      val nameToFieldMap = toFieldMap(oldFields)
       val remainingNewFields = newFields.filterNot(f => nameToFieldMap.contains(f.name))
 
       // Create the merged struct with the new fields appended at the end of the struct.
       StructType(oldFields ++ remainingNewFields)
     }
+
+    def toFieldMap(fields: Array[StructField]): Map[String, StructField] = {
+      val fieldMap = fields.map(field => field.name -> field).toMap
+      if (SQLConf.get.caseSensitiveAnalysis) {
+        fieldMap
+      } else {
+        CaseInsensitiveMap(fieldMap)
+      }
+    }
   }
 }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala
index 4996268ea7732..652c6275fe8d9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala
@@ -848,27 +848,42 @@ class DataSourceV2DataFrameSuite
     }
   }
 
-  test("insert with schema evolution") {
+  test("SPARK-52860: insert with schema evolution") {
     val tableName = "testcat.ns1.ns2.tbl"
     val ident = Identifier.of(Array("ns1", "ns2"), "tbl")
-    withTable(tableName) {
-      val tableInfo = new TableInfo.Builder().
-        withColumns(
-          Array(Column.create("c1", IntegerType)))
-        .withProperties(
-          Map("accept-any-schema" -> "true").asJava)
-        .build()
-      catalog("testcat").createTable(ident, tableInfo)
-
-      val data = Seq((1, "a"), (2, "b"), (3, "c")).toDF("c1", "c2")
-      data.writeTo(tableName).append()
-
-      checkAnswer(spark.table(tableName), data)
-      val cols = catalog("testcat").loadTable(ident).columns()
-      val expectedCols = Array(
-        Column.create("c1", IntegerType),
-        Column.create("c2", StringType))
-      assert(cols === expectedCols)
+    Seq(true, false).foreach { caseSensitive =>
+      withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) {
+        withTable(tableName) {
+          val tableInfo = new TableInfo.Builder().
+            withColumns(
+              Array(Column.create("c1", IntegerType)))
+            .withProperties(
+              Map("accept-any-schema" -> "true").asJava)
+            .build()
+          catalog("testcat").createTable(ident, tableInfo)
+
+          val data = Seq((1, "a"), (2, "b"), (3, "c"))
+          val df = if (caseSensitive) {
+            data.toDF("c1", "C1")
+          } else {
+            data.toDF("c1", "c2")
+          }
+          df.writeTo(tableName).append()
+          checkAnswer(spark.table(tableName), df)
+
+          val cols = catalog("testcat").loadTable(ident).columns()
+          val expectedCols = if (caseSensitive) {
+            Array(
+              Column.create("c1", IntegerType),
+              Column.create("C1", StringType))
+          } else {
+            Array(
+              Column.create("c1", IntegerType),
+              Column.create("c2", StringType))
+          }
+          assert(cols === expectedCols)
+        }
+      }
     }
   }
 
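Note for reviewers: the schema-evolution behavior these patches wire into InMemoryBaseTable is easiest to reason about in isolation. Below is a rough, standalone Scala sketch of the append-only merge that toBatch/mergeSchema performs; the object name, the explicit caseSensitive parameter, and the main method are illustrative only and are not part of the patch (the patched code instead resolves case sensitivity through SQLConf.get.caseSensitiveAnalysis and CaseInsensitiveMap).

    import java.util.Locale

    import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

    object SchemaMergeSketch {
      // Append-only merge: fields already on the table keep their position and type;
      // incoming fields whose names are not present yet are appended at the end.
      def mergeSchema(oldType: StructType, newType: StructType, caseSensitive: Boolean): StructType = {
        def norm(name: String): String = if (caseSensitive) name else name.toLowerCase(Locale.ROOT)
        val existingNames = oldType.fields.map(f => norm(f.name)).toSet
        val appended = newType.fields.filterNot(f => existingNames.contains(norm(f.name)))
        StructType(oldType.fields ++ appended)
      }

      def main(args: Array[String]): Unit = {
        val tableSchema = StructType(Seq(StructField("c1", IntegerType)))
        val incoming = StructType(Seq(StructField("c1", IntegerType), StructField("c2", StringType)))
        // Prints "c1, c2": c1 keeps its original definition, only c2 is appended.
        println(mergeSchema(tableSchema, incoming, caseSensitive = false).fieldNames.mkString(", "))
      }
    }

This mirrors why the case-sensitive branch of the new test expects both "c1" and "C1" as distinct columns, while the case-insensitive branch treats an incoming "c1" as the existing column and only appends "c2".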