[SPARK-52860][SQL][Test] Support V2 write schema evolution in InMemoryTable #51506

Open · wants to merge 4 commits into master

@@ -23,12 +23,13 @@ import java.util
import java.util.OptionalLong

import scala.collection.mutable
import scala.jdk.CollectionConverters._

import com.google.common.base.Objects

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, JoinedRow, MetadataStructFieldWithLogicalName}
import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateTimeUtils, ResolveDefaultColumns}
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils, DateTimeUtils, ResolveDefaultColumns}
import org.apache.spark.sql.connector.catalog.constraints.Constraint
import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
import org.apache.spark.sql.connector.expressions._
@@ -38,6 +39,7 @@ import org.apache.spark.sql.connector.read.colstats.{ColumnStatistics, Histogram
import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, Partitioning, UnknownPartitioning}
import org.apache.spark.sql.connector.write._
import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.connector.SupportsStreamingUpdateAsAppend
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
@@ -50,7 +52,7 @@ import org.apache.spark.util.ArrayImplicits._
*/
abstract class InMemoryBaseTable(
val name: String,
override val columns: Array[Column],
val initialColumns: Array[Column],
override val partitioning: Array[Transform],
override val properties: util.Map[String, String],
override val constraints: Array[Constraint] = Array.empty,
@@ -68,6 +70,10 @@ abstract class InMemoryBaseTable(
// Stores the table version validated during the last `ALTER TABLE ... ADD CONSTRAINT` operation.
private var validatedTableVersion: String = null

private var tableColumns: Array[Column] = initialColumns

override def columns(): Array[Column] = tableColumns

override def currentVersion(): String = currentTableVersion.toString

def setCurrentVersion(version: String): Unit = {
@@ -114,7 +120,7 @@
}
}

override val schema: StructType = CatalogV2Util.v2ColumnsToStructType(columns)
override def schema(): StructType = CatalogV2Util.v2ColumnsToStructType(columns())

// purposely exposes a metadata column that conflicts with a data column in some tests
override val metadataColumns: Array[MetadataColumn] = Array(IndexColumn, PartitionKeyColumn)
@@ -127,6 +133,8 @@ abstract class InMemoryBaseTable(
private val allowUnsupportedTransforms =
properties.getOrDefault("allow-unsupported-transforms", "false").toBoolean

private val acceptAnySchema = properties.getOrDefault("accept-any-schema", "false").toBoolean

partitioning.foreach {
case _: IdentityTransform =>
case _: YearsTransform =>
@@ -257,9 +265,9 @@
val newRows = new BufferedRows(to)
rows.rows.foreach { r =>
val newRow = new GenericInternalRow(r.numFields)
for (i <- 0 until r.numFields) newRow.update(i, r.get(i, schema(i).dataType))
for (i <- 0 until r.numFields) newRow.update(i, r.get(i, schema()(i).dataType))
for (i <- 0 until partitionSchema.length) {
val j = schema.fieldIndex(partitionSchema(i).name)
val j = schema().fieldIndex(partitionSchema(i).name)
newRow.update(j, to(i))
}
newRows.withRow(newRow)
@@ -331,14 +339,22 @@
this
}

override def capabilities: util.Set[TableCapability] = util.EnumSet.of(
def baseCapabilities: Set[TableCapability] = Set(
TableCapability.BATCH_READ,
TableCapability.BATCH_WRITE,
TableCapability.STREAMING_WRITE,
TableCapability.OVERWRITE_BY_FILTER,
TableCapability.OVERWRITE_DYNAMIC,
TableCapability.TRUNCATE)

override def capabilities(): util.Set[TableCapability] = {
if (acceptAnySchema) {
(baseCapabilities ++ Set(TableCapability.ACCEPT_ANY_SCHEMA)).asJava
} else {
baseCapabilities.asJava
}
}

override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
new InMemoryScanBuilder(schema, options)
}
@@ -557,7 +573,12 @@
advisoryPartitionSize.getOrElse(0)
}

override def toBatch: BatchWrite = writer
override def toBatch: BatchWrite = {
val newSchema = info.schema()
tableColumns = CatalogV2Util.structTypeToV2Columns(
mergeSchema(CatalogV2Util.v2ColumnsToStructType(columns()), newSchema))
writer
}

override def toStreaming: StreamingWrite = streamingWriter match {
case exc: StreamingNotSupportedOperation => exc.throwsException()
@@ -571,6 +592,26 @@
override def reportDriverMetrics(): Array[CustomTaskMetric] = {
Array(new InMemoryCustomDriverTaskMetric(rows.size))
}

def mergeSchema(oldType: StructType, newType: StructType): StructType = {
val (oldFields, newFields) = (oldType.fields, newType.fields)

// For now, this does not override an existing field with a new field of the same name.
val nameToFieldMap = toFieldMap(oldFields)
val remainingNewFields = newFields.filterNot(f => nameToFieldMap.contains(f.name))

Member:

According to the code, do we assume case-sensitive column names always?

Member Author:

Yeah, it's up to the DSV2 implementation. For example, Iceberg uses a write config [1], which defaults to the Spark config spark.sql.caseSensitive [2]. ref:

  1. https://github.com/apache/iceberg/blob/main/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkWriteBuilder.java#L199
  2. https://github.com/apache/iceberg/blob/main/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java#L456

I was just making a simple example for InMemoryTable; should I make it more complex and take this property into account?

Member Author:

OK, I added this support in InMemoryTable as well, thanks! (See the sketch after this diff for how the merge behaves.)


// Create the merged struct with the new fields appended at the end.
StructType(oldFields ++ remainingNewFields)
}

def toFieldMap(fields: Array[StructField]): Map[String, StructField] = {
val fieldMap = fields.map(field => field.name -> field).toMap
if (SQLConf.get.caseSensitiveAnalysis) {
fieldMap
} else {
CaseInsensitiveMap(fieldMap)
}
}
}
}
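
For illustration only (not part of the diff): a minimal, self-contained sketch of how the append-only merge above behaves. The object `MergeSchemaDemo` and its `main` method are hypothetical; the merge logic mirrors `mergeSchema`/`toFieldMap` from the diff and assumes the Spark catalyst classes (`StructType`, `SQLConf`, `CaseInsensitiveMap`) are available on the classpath.

```scala
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

object MergeSchemaDemo {
  // Append-only merge: existing fields win; new fields whose names are not
  // already present (matched per spark.sql.caseSensitive) are appended at the end.
  def mergeSchema(oldType: StructType, newType: StructType): StructType = {
    val byName: Map[String, StructField] = oldType.fields.map(f => f.name -> f).toMap
    val lookup =
      if (SQLConf.get.caseSensitiveAnalysis) byName else CaseInsensitiveMap(byName)
    val appended = newType.fields.filterNot(f => lookup.contains(f.name))
    StructType(oldType.fields ++ appended)
  }

  def main(args: Array[String]): Unit = {
    val table = new StructType().add("c1", IntegerType)
    val incoming = new StructType().add("C1", StringType).add("c2", StringType)
    // With the default spark.sql.caseSensitive=false, "C1" matches the existing
    // "c1", so only "c2" is appended and the merged schema is: c1, c2.
    println(mergeSchema(table, incoming).fieldNames.mkString(", "))
  }
}
```

Under spark.sql.caseSensitive=true the same call would append both C1 and c2, which is exactly what the new test below exercises.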

@@ -19,6 +19,8 @@ package org.apache.spark.sql.connector

import java.util.Collections

import scala.jdk.CollectionConverters._

import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode}
import org.apache.spark.sql.QueryTest.withQueryExecutionsCaptured
@@ -846,6 +848,45 @@ class DataSourceV2DataFrameSuite
}
}

test("SPARK-52860: insert with schema evolution") {
val tableName = "testcat.ns1.ns2.tbl"
val ident = Identifier.of(Array("ns1", "ns2"), "tbl")
Seq(true, false).foreach { caseSensitive =>
withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) {
withTable(tableName) {
val tableInfo = new TableInfo.Builder()
.withColumns(Array(Column.create("c1", IntegerType)))
.withProperties(Map("accept-any-schema" -> "true").asJava)
.build()
catalog("testcat").createTable(ident, tableInfo)

val data = Seq((1, "a"), (2, "b"), (3, "c"))
val df = if (caseSensitive) {
data.toDF("c1", "C1")
} else {
data.toDF("c1", "c2")
}
df.writeTo(tableName).append()
checkAnswer(spark.table(tableName), df)

val cols = catalog("testcat").loadTable(ident).columns()
val expectedCols = if (caseSensitive) {
Array(
Column.create("c1", IntegerType),
Column.create("C1", StringType))
} else {
Array(
Column.create("c1", IntegerType),
Column.create("c2", StringType))
}
assert(cols === expectedCols)
}
}
}
}

private def executeAndKeepPhysicalPlan[T <: SparkPlan](func: => Unit): T = {
val qe = withQueryExecutionsCaptured(spark) {
func