diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
index d032829d8645f..4efb08a67829e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
@@ -23,12 +23,13 @@ import java.util
 import java.util.OptionalLong
 
 import scala.collection.mutable
+import scala.jdk.CollectionConverters._
 
 import com.google.common.base.Objects
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, JoinedRow, MetadataStructFieldWithLogicalName}
-import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateTimeUtils, ResolveDefaultColumns}
+import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils, DateTimeUtils, ResolveDefaultColumns}
 import org.apache.spark.sql.connector.catalog.constraints.Constraint
 import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
 import org.apache.spark.sql.connector.expressions._
@@ -38,6 +39,7 @@ import org.apache.spark.sql.connector.read.colstats.{ColumnStatistics, Histogram
 import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, Partitioning, UnknownPartitioning}
 import org.apache.spark.sql.connector.write._
 import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.connector.SupportsStreamingUpdateAsAppend
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
@@ -50,7 +52,7 @@ import org.apache.spark.util.ArrayImplicits._
  */
 abstract class InMemoryBaseTable(
     val name: String,
-    override val columns: Array[Column],
+    val initialColumns: Array[Column],
     override val partitioning: Array[Transform],
     override val properties: util.Map[String, String],
     override val constraints: Array[Constraint] = Array.empty,
@@ -68,6 +70,10 @@ abstract class InMemoryBaseTable(
   // Stores the table version validated during the last `ALTER TABLE ... ADD CONSTRAINT` operation.
   private var validatedTableVersion: String = null
 
+  private var tableColumns: Array[Column] = initialColumns
+
+  override def columns(): Array[Column] = tableColumns
+
   override def currentVersion(): String = currentTableVersion.toString
 
   def setCurrentVersion(version: String): Unit = {
@@ -114,7 +120,7 @@ abstract class InMemoryBaseTable(
     }
   }
 
-  override val schema: StructType = CatalogV2Util.v2ColumnsToStructType(columns)
+  override def schema(): StructType = CatalogV2Util.v2ColumnsToStructType(columns())
 
   // purposely exposes a metadata column that conflicts with a data column in some tests
   override val metadataColumns: Array[MetadataColumn] = Array(IndexColumn, PartitionKeyColumn)
@@ -127,6 +133,8 @@ abstract class InMemoryBaseTable(
   private val allowUnsupportedTransforms =
     properties.getOrDefault("allow-unsupported-transforms", "false").toBoolean
 
+  private val acceptAnySchema = properties.getOrDefault("accept-any-schema", "false").toBoolean
+
   partitioning.foreach {
     case _: IdentityTransform =>
     case _: YearsTransform =>
@@ -257,9 +265,9 @@ abstract class InMemoryBaseTable(
     val newRows = new BufferedRows(to)
     rows.rows.foreach { r =>
       val newRow = new GenericInternalRow(r.numFields)
-      for (i <- 0 until r.numFields) newRow.update(i, r.get(i, schema(i).dataType))
+      for (i <- 0 until r.numFields) newRow.update(i, r.get(i, schema()(i).dataType))
       for (i <- 0 until partitionSchema.length) {
-        val j = schema.fieldIndex(partitionSchema(i).name)
+        val j = schema().fieldIndex(partitionSchema(i).name)
         newRow.update(j, to(i))
       }
       newRows.withRow(newRow)
@@ -331,7 +339,7 @@ abstract class InMemoryBaseTable(
     this
   }
 
-  override def capabilities: util.Set[TableCapability] = util.EnumSet.of(
+  def baseCapabilities: Set[TableCapability] = Set(
     TableCapability.BATCH_READ,
     TableCapability.BATCH_WRITE,
     TableCapability.STREAMING_WRITE,
@@ -339,6 +347,14 @@ abstract class InMemoryBaseTable(
     TableCapability.OVERWRITE_DYNAMIC,
     TableCapability.TRUNCATE)
 
+  override def capabilities(): util.Set[TableCapability] = {
+    if (acceptAnySchema) {
+      (baseCapabilities ++ Set(TableCapability.ACCEPT_ANY_SCHEMA)).asJava
+    } else {
+      baseCapabilities.asJava
+    }
+  }
+
   override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
     new InMemoryScanBuilder(schema, options)
   }
@@ -557,7 +573,12 @@ abstract class InMemoryBaseTable(
       advisoryPartitionSize.getOrElse(0)
     }
 
-    override def toBatch: BatchWrite = writer
+    override def toBatch: BatchWrite = {
+      val newSchema = info.schema()
+      tableColumns = CatalogV2Util.structTypeToV2Columns(
+        mergeSchema(CatalogV2Util.v2ColumnsToStructType(columns()), newSchema))
+      writer
+    }
 
     override def toStreaming: StreamingWrite = streamingWriter match {
       case exc: StreamingNotSupportedOperation => exc.throwsException()
@@ -571,6 +592,26 @@ abstract class InMemoryBaseTable(
     override def reportDriverMetrics(): Array[CustomTaskMetric] = {
       Array(new InMemoryCustomDriverTaskMetric(rows.size))
     }
+
+    def mergeSchema(oldType: StructType, newType: StructType): StructType = {
+      val (oldFields, newFields) = (oldType.fields, newType.fields)
+
+      // This does not overwrite an existing field with a new field of the same name for now.
+      val nameToFieldMap = toFieldMap(oldFields)
+      val remainingNewFields = newFields.filterNot(f => nameToFieldMap.contains(f.name))
+
+      // Create the merged struct with the new fields appended at the end of the struct.
+      StructType(oldFields ++ remainingNewFields)
+    }
+
+    def toFieldMap(fields: Array[StructField]): Map[String, StructField] = {
+      val fieldMap = fields.map(field => field.name -> field).toMap
+      if (SQLConf.get.caseSensitiveAnalysis) {
+        fieldMap
+      } else {
+        CaseInsensitiveMap(fieldMap)
+      }
+    }
   }
 }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala
index 342eefa1a6f63..652c6275fe8d9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.connector
 
 import java.util.Collections
 
+import scala.jdk.CollectionConverters._
+
 import org.apache.spark.{SparkConf, SparkException}
 import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode}
 import org.apache.spark.sql.QueryTest.withQueryExecutionsCaptured
@@ -846,6 +848,45 @@ class DataSourceV2DataFrameSuite
     }
   }
 
+  test("SPARK-52860: insert with schema evolution") {
+    val tableName = "testcat.ns1.ns2.tbl"
+    val ident = Identifier.of(Array("ns1", "ns2"), "tbl")
+    Seq(true, false).foreach { caseSensitive =>
+      withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) {
+        withTable(tableName) {
+          val tableInfo = new TableInfo.Builder()
+            .withColumns(
+              Array(Column.create("c1", IntegerType)))
+            .withProperties(
+              Map("accept-any-schema" -> "true").asJava)
+            .build()
+          catalog("testcat").createTable(ident, tableInfo)
+
+          val data = Seq((1, "a"), (2, "b"), (3, "c"))
+          val df = if (caseSensitive) {
+            data.toDF("c1", "C1")
+          } else {
+            data.toDF("c1", "c2")
+          }
+          df.writeTo(tableName).append()
+          checkAnswer(spark.table(tableName), df)
+
+          val cols = catalog("testcat").loadTable(ident).columns()
+          val expectedCols = if (caseSensitive) {
+            Array(
+              Column.create("c1", IntegerType),
+              Column.create("C1", StringType))
+          } else {
+            Array(
+              Column.create("c1", IntegerType),
+              Column.create("c2", StringType))
+          }
+          assert(cols === expectedCols)
+        }
+      }
+    }
+  }
+
   private def executeAndKeepPhysicalPlan[T <: SparkPlan](func: => Unit): T = {
     val qe = withQueryExecutionsCaptured(spark) {
       func