diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index 698afa4860027..5fca5de1ea254 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -1342,8 +1342,8 @@ collateClause nonTrivialPrimitiveType : STRING collateClause? - | (CHARACTER | CHAR) (LEFT_PAREN length=INTEGER_VALUE RIGHT_PAREN)? - | VARCHAR (LEFT_PAREN length=INTEGER_VALUE RIGHT_PAREN)? + | (CHARACTER | CHAR) (LEFT_PAREN length=INTEGER_VALUE RIGHT_PAREN)? collateClause? + | VARCHAR (LEFT_PAREN length=INTEGER_VALUE RIGHT_PAREN)? collateClause? | (DECIMAL | DEC | NUMERIC) (LEFT_PAREN precision=INTEGER_VALUE (COMMA scale=INTEGER_VALUE)? RIGHT_PAREN)? | INTERVAL diff --git a/sql/api/src/main/java/org/apache/spark/sql/types/DataTypes.java b/sql/api/src/main/java/org/apache/spark/sql/types/DataTypes.java index 0034b8e715183..1db2be1915051 100644 --- a/sql/api/src/main/java/org/apache/spark/sql/types/DataTypes.java +++ b/sql/api/src/main/java/org/apache/spark/sql/types/DataTypes.java @@ -261,20 +261,47 @@ public static StructType createStructType(StructField[] fields) { } /** - * Creates a CharType with the given length. + * Creates a StringType with the given collation. + * + * @since 4.1.0 + */ + public static StringType createStringType(int collation) { + return new StringType(collation, NoConstraint$.MODULE$); + } + + /** + * Creates a CharType with the given length and `UTF8_BINARY` collation. * * @since 4.0.0 */ public static CharType createCharType(int length) { - return new CharType(length); + return new DefaultCharType(length); } /** - * Creates a VarcharType with the given length. + * Creates a CharType with the given length and collation. 
+ * + * @since 4.1.0 + */ + public static CharType createCharType(int length, int collationId) { + return new CharType(length, collationId); + } + + /** + * Creates a VarcharType with the given length and `UTF8_BINARY` collation. * * @since 4.0.0 */ public static VarcharType createVarcharType(int length) { - return new VarcharType(length); + return new DefaultVarcharType(length); + } + + /** + * Creates a VarcharType with the given length and collation. + * + * @since 4.1.0 + */ + public static VarcharType createVarcharType(int length, int collationId) { + return new VarcharType(length, collationId); } } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 620278c66d21d..8d3ed314e5e4d 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -81,9 +81,9 @@ object RowEncoder extends DataTypeErrorsBase { case DoubleType => BoxedDoubleEncoder case dt: DecimalType => JavaDecimalEncoder(dt, lenientSerialization = true) case BinaryType => BinaryEncoder - case CharType(length) if SqlApiConf.get.preserveCharVarcharTypeInfo => + case CharType(length, _) if SqlApiConf.get.preserveCharVarcharTypeInfo => CharEncoder(length) - case VarcharType(length) if SqlApiConf.get.preserveCharVarcharTypeInfo => + case VarcharType(length, _) if SqlApiConf.get.preserveCharVarcharTypeInfo => VarcharEncoder(length) case s: StringType if StringHelper.isPlainString(s) => StringEncoder case TimestampType if SqlApiConf.get.datetimeJava8ApiEnabled => InstantEncoder(lenient) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala index beb7061a841a8..21856def479ae 100644 ---
a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala @@ -17,12 +17,9 @@ package org.apache.spark.sql.catalyst.parser import java.util.Locale - import scala.jdk.CollectionConverters._ - import org.antlr.v4.runtime.Token import org.antlr.v4.runtime.tree.ParseTree - import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ import org.apache.spark.sql.catalyst.util.CollationFactory @@ -30,7 +27,7 @@ import org.apache.spark.sql.catalyst.util.SparkParserUtils.{string, withOrigin} import org.apache.spark.sql.connector.catalog.IdentityColumnSpec import org.apache.spark.sql.errors.QueryParsingErrors import org.apache.spark.sql.internal.SqlApiConf -import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, CalendarIntervalType, CharType, DataType, DateType, DayTimeIntervalType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, MetadataBuilder, NullType, ShortType, StringType, StructField, StructType, TimestampNTZType, TimestampType, TimeType, VarcharType, VariantType, YearMonthIntervalType} +import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, CalendarIntervalType, CharType, DataType, DateType, DayTimeIntervalType, DecimalType, DefaultCharType, DefaultVarcharType, DoubleType, FloatType, IntegerType, LongType, MapType, MetadataBuilder, NullType, ShortType, StringType, StructField, StructType, TimeType, TimestampNTZType, TimestampType, VarcharType, VariantType, YearMonthIntervalType} class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] { protected def typedVisit[T](ctx: ParseTree): T = { @@ -84,13 +81,35 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] { StringType(collationId) } case CHARACTER | CHAR => - if (currentCtx.length == null) { - throw QueryParsingErrors.charVarcharTypeMissingLengthError(typeCtx.getText, 
ctx) - } else CharType(currentCtx.length.getText.toInt) + (Option(currentCtx.length), Option(currentCtx.collateClause)) match { + case (None, _) => + throw QueryParsingErrors.charVarcharTypeMissingLengthError(typeCtx.getText, ctx) + case (_, Some(_)) if !SqlApiConf.get.charVarcharCollationsEnabled => + // TODO: throw correct error + throw QueryParsingErrors.charVarcharTypeMissingLengthError(typeCtx.getText, ctx) + case (Some(length), None) => + DefaultCharType(length.getText.toInt) + case (Some(length), Some(collateClause)) => + val collationNameParts = visitCollateClause(collateClause).toArray + val collationId = CollationFactory.collationNameToId( + CollationFactory.resolveFullyQualifiedName(collationNameParts)) + CharType(length.getText.toInt, collationId) + } case VARCHAR => - if (currentCtx.length == null) { - throw QueryParsingErrors.charVarcharTypeMissingLengthError(typeCtx.getText, ctx) - } else VarcharType(currentCtx.length.getText.toInt) + (Option(currentCtx.length), Option(currentCtx.collateClause)) match { + case (None, _) => + throw QueryParsingErrors.charVarcharTypeMissingLengthError(typeCtx.getText, ctx) + case (_, Some(_)) if !SqlApiConf.get.charVarcharCollationsEnabled => + // TODO: throw correct error + throw QueryParsingErrors.charVarcharTypeMissingLengthError(typeCtx.getText, ctx) + case (Some(length), None) => + DefaultVarcharType(length.getText.toInt) + case (Some(length), Some(collateClause)) => + val collationNameParts = visitCollateClause(collateClause).toArray + val collationId = CollationFactory.collationNameToId( + CollationFactory.resolveFullyQualifiedName(collationNameParts)) + VarcharType(length.getText.toInt, collationId) + } case DECIMAL | DEC | NUMERIC => if (currentCtx.precision == null) { DecimalType.USER_DEFAULT diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkCharVarcharUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkCharVarcharUtils.scala index 
51b2c40f9bf2e..70c92053b812b 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkCharVarcharUtils.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkCharVarcharUtils.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.util import org.apache.spark.sql.errors.DataTypeErrors import org.apache.spark.sql.internal.SqlApiConf -import org.apache.spark.sql.types.{ArrayType, CharType, DataType, MapType, StringType, StructType, VarcharType} +import org.apache.spark.sql.types.{ArrayType, CharType, DataType, DefaultCharType, DefaultVarcharType, MapType, StringType, StructType, VarcharType} trait SparkCharVarcharUtils { @@ -54,7 +54,28 @@ trait SparkCharVarcharUtils { StructType(fields.map { field => field.copy(dataType = replaceCharVarcharWithString(field.dataType)) }) - case CharType(_) | VarcharType(_) if !SqlApiConf.get.preserveCharVarcharTypeInfo => StringType + case DefaultCharType(_) if !SqlApiConf.get.preserveCharVarcharTypeInfo => + if (!QueryTagger.getActiveTagsInLocalThread.contains(QueryTag.CHAR_TAG)) { + QueryTagger.addTags(Seq(QueryTag.CHAR_TAG)) + } + StringType + case CharType(_, collationId) if !SqlApiConf.get.preserveCharVarcharTypeInfo => + if (!QueryTagger.getActiveTagsInLocalThread.contains(QueryTag.CHAR_TAG)) { + QueryTagger.addTags(Seq(QueryTag.CHAR_TAG)) + } + if (SqlApiConf.get.charVarcharCollationsEnabled) StringType(collationId) + else StringType + case DefaultVarcharType(_) if !SqlApiConf.get.preserveCharVarcharTypeInfo => + if (!QueryTagger.getActiveTagsInLocalThread.contains(QueryTag.VARCHAR_TAG)) { + QueryTagger.addTags(Seq(QueryTag.VARCHAR_TAG)) + } + StringType + case VarcharType(_, collationId) if !SqlApiConf.get.preserveCharVarcharTypeInfo => + if (!QueryTagger.getActiveTagsInLocalThread.contains(QueryTag.VARCHAR_TAG)) { + QueryTagger.addTags(Seq(QueryTag.VARCHAR_TAG)) + } + if (SqlApiConf.get.charVarcharCollationsEnabled) StringType(collationId) + else StringType case _ => dt } } diff 
--git a/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala index 76449f1704d26..f91259dd6201d 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala @@ -41,6 +41,7 @@ private[sql] trait SqlApiConf { def allowNegativeScaleOfDecimalEnabled: Boolean def charVarcharAsString: Boolean def preserveCharVarcharTypeInfo: Boolean + def charVarcharCollationsEnabled: Boolean def datetimeJava8ApiEnabled: Boolean def sessionLocalTimeZone: String def legacyTimeParserPolicy: LegacyBehaviorPolicy.Value @@ -82,6 +83,7 @@ private[sql] object DefaultSqlApiConf extends SqlApiConf { override def allowNegativeScaleOfDecimalEnabled: Boolean = false override def charVarcharAsString: Boolean = false override def preserveCharVarcharTypeInfo: Boolean = false + override def charVarcharCollationsEnabled: Boolean = false override def datetimeJava8ApiEnabled: Boolean = false override def sessionLocalTimeZone: String = TimeZone.getDefault.getID override def legacyTimeParserPolicy: LegacyBehaviorPolicy.Value = LegacyBehaviorPolicy.CORRECTED diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/CharType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/CharType.scala index 68dad6c87c01e..f4cbff6189d1d 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/CharType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/CharType.scala @@ -23,13 +23,33 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.util.CollationFactory @Experimental -case class CharType(length: Int) - extends StringType(CollationFactory.UTF8_BINARY_COLLATION_ID, FixedLength(length)) { +case class CharType(length: Int, override val collationId: Int = 0) + extends StringType(collationId, FixedLength(length)) { require(length >= 0, "The length of char type cannot be 
negative.") override def defaultSize: Int = length - override def typeName: String = s"char($length)" - override def jsonValue: JValue = JString(typeName) - override def toString: String = s"CharType($length)" + override def typeName: String = + if (isUTF8BinaryCollation) s"char($length)" + else s"char($length) collate $collationName" + override def jsonValue: JValue = JString(s"char($length)") + override def toString: String = + if (isUTF8BinaryCollation) s"CharType($length)" + else s"CharType($length, $collationName)" private[spark] override def asNullable: CharType = this } + +/** + * A variant of [[CharType]] defined without explicit collation. + */ +@Experimental +class DefaultCharType(override val length: Int) + extends CharType(length, CollationFactory.UTF8_BINARY_COLLATION_ID) { + override def typeName: String = s"char($length) collate $collationName" + override def toString: String = s"CharType($length, $collationName)" +} + +@Experimental +object DefaultCharType { + def apply(length: Int): DefaultCharType = new DefaultCharType(length) + def unapply(defaultCharType: DefaultCharType): Option[Int] = Some(defaultCharType.length) +} diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala index 6f4f5dc255a0a..854a06868ae51 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -328,7 +328,11 @@ object DataType { fieldPath: String, fieldType: String, collationMap: Map[String, String]): Unit = { - if (collationMap.contains(fieldPath) && fieldType != "string") { + val isValidType = fieldType match { + case "string" | CHAR_TYPE(_) | VARCHAR_TYPE(_) => true + case _ => false + } + if (collationMap.contains(fieldPath) && !isValidType) { throw new SparkIllegalArgumentException( errorClass = "INVALID_JSON_DATA_TYPE_FOR_COLLATIONS", messageParameters = Map("jsonType" -> fieldType)) diff 
--git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index 5fec578b03581..2dc80d640e42d 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -184,8 +184,8 @@ case object StringHelper extends PartialOrdering[StringConstraint] { } def removeCollation(s: StringType): StringType = s match { - case CharType(length) => CharType(length) - case VarcharType(length) => VarcharType(length) + case CharType(length, _) => DefaultCharType(length) + case VarcharType(length, _) => DefaultVarcharType(length) case _: StringType => StringType } } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/VarcharType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/VarcharType.scala index 22f7947b25037..742de24417d3f 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/VarcharType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/VarcharType.scala @@ -22,13 +22,34 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.util.CollationFactory @Experimental -case class VarcharType(length: Int) - extends StringType(CollationFactory.UTF8_BINARY_COLLATION_ID, MaxLength(length)) { +case class VarcharType(length: Int, override val collationId: Int = 0) + extends StringType(collationId, MaxLength(length)) { require(length >= 0, "The length of varchar type cannot be negative.") override def defaultSize: Int = length - override def typeName: String = s"varchar($length)" - override def jsonValue: JValue = JString(typeName) - override def toString: String = s"VarcharType($length)" + override def typeName: String = + if (isUTF8BinaryCollation) s"varchar($length)" + else s"varchar($length) collate $collationName" + /** [[jsonValue]] does not have collation, same as for [[StringType]] */ + override def jsonValue: JValue = 
JString(s"varchar($length)") + override def toString: String = + if (isUTF8BinaryCollation) s"VarcharType($length)" + else s"VarcharType($length, $collationName)" private[spark] override def asNullable: VarcharType = this } + +/** + * A variant of [[VarcharType]] defined without explicit collation. + */ +@Experimental +class DefaultVarcharType(override val length: Int) + extends VarcharType(length, CollationFactory.UTF8_BINARY_COLLATION_ID) { + override def typeName: String = s"varchar($length) collate $collationName" + override def toString: String = s"VarcharType($length, $collationName)" +} + +@Experimental +object DefaultVarcharType { + def apply(length: Int): DefaultVarcharType = new DefaultVarcharType(length) + def unapply(defaultVarcharType: DefaultVarcharType): Option[Int] = Some(defaultVarcharType.length) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 440cfb7132429..0b7234a6d4913 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -66,8 +66,8 @@ object CatalystTypeConverters { case arrayType: ArrayType => ArrayConverter(arrayType.elementType) case mapType: MapType => MapConverter(mapType.keyType, mapType.valueType) case structType: StructType => StructConverter(structType) - case CharType(length) => new CharConverter(length) - case VarcharType(length) => new VarcharConverter(length) + case CharType(length, _) => new CharConverter(length) + case VarcharType(length, _) => new VarcharConverter(length) case _: StringType => StringConverter case DateType if SQLConf.get.datetimeJava8ApiEnabled => LocalDateConverter case DateType => DateConverter diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ApplyCharTypePaddingHelper.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ApplyCharTypePaddingHelper.scala index cf7ea21ee6f72..7a0f398f13fec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ApplyCharTypePaddingHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ApplyCharTypePaddingHelper.scala @@ -103,7 +103,7 @@ object ApplyCharTypePaddingHelper { CharVarcharUtils .getRawType(attr.metadata) .flatMap { - case CharType(length) => + case CharType(length, _) => val (nulls, literalChars) = list.map(_.eval().asInstanceOf[UTF8String]).partition(_ == null) val literalCharLengths = literalChars.map(_.numChars()) @@ -164,7 +164,7 @@ object ApplyCharTypePaddingHelper { lit: Expression): Option[Seq[Expression]] = { if (expr.dataType == StringType) { CharVarcharUtils.getRawType(metadata).flatMap { - case CharType(length) => + case CharType(length, _) => val str = lit.eval().asInstanceOf[UTF8String] if (str == null) { None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 5a58c24bc190b..68d3e8dc5f611 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -1090,9 +1090,9 @@ trait CheckAnalysis extends LookupCatalog with QueryErrorsBase with PlanToString // We don't need to handle nested types here which shall fail before. 
def canAlterColumnType(from: DataType, to: DataType): Boolean = (from, to) match { - case (CharType(l1), CharType(l2)) => l1 == l2 - case (CharType(l1), VarcharType(l2)) => l1 <= l2 - case (VarcharType(l1), VarcharType(l2)) => l1 <= l2 + case (CharType(l1, _), CharType(l2, _)) => l1 == l2 + case (CharType(l1, _), VarcharType(l2, _)) => l1 <= l2 + case (VarcharType(l1, _), VarcharType(l2, _)) => l1 <= l2 case _ => Cast.canUpCast(from, to) } if (!canAlterColumnType(field.dataType, newDataType)) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 925735654b73e..02ff5ff8bbdd1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -170,7 +170,7 @@ object Literal { case _: DayTimeIntervalType if v.isInstanceOf[Duration] => Literal(CatalystTypeConverters.createToCatalystConverter(dataType)(v), dataType) case _: ObjectType => Literal(v, dataType) - case CharType(_) | VarcharType(_) if SQLConf.get.preserveCharVarcharTypeInfo => + case CharType(_, _) | VarcharType(_, _) if SQLConf.get.preserveCharVarcharTypeInfo => Literal(CatalystTypeConverters.createToCatalystConverter(dataType)(v), dataType) case _ => Literal(CatalystTypeConverters.convertToCatalyst(v), dataType) } @@ -203,10 +203,10 @@ object Literal { case t: TimeType => create(0L, t) case it: DayTimeIntervalType => create(0L, it) case it: YearMonthIntervalType => create(0, it) - case CharType(length) => + case CharType(length, _) => create(CharVarcharCodegenUtils.charTypeWriteSideCheck(UTF8String.fromString(""), length), dataType) - case VarcharType(length) => + case VarcharType(length, _) => create(CharVarcharCodegenUtils.varcharTypeWriteSideCheck(UTF8String.fromString(""), length), dataType) case st: StringType => 
Literal(UTF8String.fromString(""), st) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala index 5831a29c00a19..3331cc225b171 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala @@ -349,7 +349,7 @@ case object VariantGet { */ def checkDataType(dataType: DataType, allowStructsAndMaps: Boolean = true): Boolean = dataType match { - case CharType(_) | VarcharType(_) => false + case CharType(_, _) | VarcharType(_, _) => false case _: NumericType | BooleanType | _: StringType | BinaryType | _: DatetimeType | VariantType => true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala index 1084e99731510..5590157dadf68 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala @@ -40,8 +40,8 @@ object PhysicalDataType { case ShortType => PhysicalShortType case IntegerType => PhysicalIntegerType case LongType => PhysicalLongType - case VarcharType(_) => PhysicalStringType(StringType.collationId) - case CharType(_) => PhysicalStringType(StringType.collationId) + case VarcharType(_, collationId) => PhysicalStringType(collationId) + case CharType(_, collationId) => PhysicalStringType(collationId) case s: StringType => PhysicalStringType(s.collationId) case FloatType => PhysicalFloatType case DoubleType => PhysicalDoubleType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala index 6ba7e528ea230..aeae5a8f5e982 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala @@ -63,7 +63,7 @@ object CharVarcharUtils extends Logging with SparkCharVarcharUtils { StructType(fields.map { field => field.copy(dataType = replaceCharWithVarchar(field.dataType)) }) - case CharType(length) => VarcharType(length) + case CharType(length, collationId) => VarcharType(length, collationId) case _ => dt } @@ -161,25 +161,25 @@ object CharVarcharUtils extends Logging with SparkCharVarcharUtils { charFuncName: Option[String], varcharFuncName: Option[String]): Expression = { dt match { - case CharType(length) if charFuncName.isDefined => + case CharType(length, collationId) if charFuncName.isDefined => StaticInvoke( classOf[CharVarcharCodegenUtils], if (SQLConf.get.preserveCharVarcharTypeInfo) { - CharType(length) + CharType(length, collationId) } else { - StringType + StringType(collationId) }, charFuncName.get, expr :: Literal(length) :: Nil, returnNullable = false) - case VarcharType(length) if varcharFuncName.isDefined => + case VarcharType(length, collationId) if varcharFuncName.isDefined => StaticInvoke( classOf[CharVarcharCodegenUtils], if (SQLConf.get.preserveCharVarcharTypeInfo) { - VarcharType(length) + VarcharType(length, collationId) } else { - StringType + StringType(collationId) }, varcharFuncName.get, expr :: Literal(length) :: Nil, @@ -262,8 +262,10 @@ object CharVarcharUtils extends Logging with SparkCharVarcharUtils { private def typeWithWiderCharLength(type1: DataType, type2: DataType): DataType = { (type1, type2) match { - case (CharType(len1), CharType(len2)) => - CharType(math.max(len1, len2)) + case (CharType(len1, collationId1), CharType(len2, collationId2)) => + assert(collationId1 == collationId2, + "Collations of 
CharType should be the same for comparison.") + CharType(math.max(len1, len2), collationId1) case (StructType(fields1), StructType(fields2)) => assert(fields1.length == fields2.length) StructType(fields1.zip(fields2).map { case (left, right) => @@ -281,7 +283,10 @@ object CharVarcharUtils extends Logging with SparkCharVarcharUtils { typeWithTargetCharLength: DataType, alwaysPad: Boolean): Option[Expression] = { (rawType, typeWithTargetCharLength) match { - case (CharType(len), CharType(target)) if alwaysPad || target > len => + case (CharType(len, collationId1), CharType(target, collationId2)) + if alwaysPad || target > len => + assert(collationId1 == collationId2, + "Collations of CharType should be the same for comparison.") Some(StringRPad(expr, Literal(target))) case (StructType(fields), StructType(targets)) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 1d54d814a358b..0dda92dbff28c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -5312,6 +5312,12 @@ object SQLConf { .checkValue(_ >= 0, "The value must be non-negative.") .createWithDefault(8) + val CHAR_VARCHAR_COLLATIONS_ENABLED = buildConf("spark.sql.charVarcharCollationsEnabled") + .doc("When true, Spark allows creation of collated char/varchar types.") + .version("4.1.0") + .booleanConf + .createWithDefault(true) + val OPTIMIZE_NULL_AWARE_ANTI_JOIN = buildConf("spark.sql.optimizeNullAwareAntiJoin") .internal() @@ -7028,6 +7034,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def preserveCharVarcharTypeInfo: Boolean = getConf(SQLConf.PRESERVE_CHAR_VARCHAR_TYPE_INFO) + def charVarcharCollationsEnabled: Boolean = getConf(SQLConf.CHAR_VARCHAR_COLLATIONS_ENABLED) + def readSideCharPadding: Boolean = getConf(SQLConf.READ_SIDE_CHAR_PADDING) def cliPrintHeader:
Boolean = getConf(SQLConf.CLI_PRINT_HEADER) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/PartitioningUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/PartitioningUtils.scala index 35a30431616c8..a383a9627a8da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/PartitioningUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/PartitioningUtils.scala @@ -80,14 +80,14 @@ private[sql] object PartitioningUtils { val normalizedVal = if (SQLConf.get.charVarcharAsString) value else normalizedFiled.dataType match { - case CharType(len) if value != null && value != DEFAULT_PARTITION_NAME => + case CharType(len, _) if value != null && value != DEFAULT_PARTITION_NAME => val v = value match { case Some(str: String) => Some(charTypeWriteSideCheck(str, len)) case str: String => charTypeWriteSideCheck(str, len) case other => other } v.asInstanceOf[T] - case VarcharType(len) if value != null && value != DEFAULT_PARTITION_NAME => + case VarcharType(len, _) if value != null && value != DEFAULT_PARTITION_NAME => val v = value match { case Some(str: String) => Some(varcharTypeWriteSideCheck(str, len)) case str: String => varcharTypeWriteSideCheck(str, len) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala index 5d90c2d730e02..dff100582694f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.parser import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.sql.catalyst.plans.SQLHelper +import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.TimestampTypes 
import org.apache.spark.sql.types._ @@ -68,10 +69,14 @@ class DataTypeParserSuite extends SparkFunSuite with SQLHelper { checkDataType("timestamp_ntz", TimestampNTZType) checkDataType("timestamp_ltz", TimestampType) checkDataType("string", StringType) - checkDataType("ChaR(5)", CharType(5)) - checkDataType("ChaRacter(5)", CharType(5)) + checkDataType("ChaR(5)", DefaultCharType(5)) + checkDataType("ChaRacter(5)", DefaultCharType(5)) + checkDataType("cHaR(27)", DefaultCharType(27)) + checkDataType("chAr(5) coLLate UTf8_binary", + CharType(5, CollationFactory.UTF8_BINARY_COLLATION_ID)) + checkDataType("chAr(5) coLLate UTF8_lcaSE", + CharType(5, CollationFactory.UTF8_LCASE_COLLATION_ID)) checkDataType("varchAr(20)", VarcharType(20)) - checkDataType("cHaR(27)", CharType(27)) checkDataType("BINARY", BinaryType) checkDataType("void", NullType) checkDataType("interval", CalendarIntervalType) diff --git a/sql/connect/common/src/main/protobuf/spark/connect/types.proto b/sql/connect/common/src/main/protobuf/spark/connect/types.proto index db82cbe64a9bb..49c6c52eca41c 100644 --- a/sql/connect/common/src/main/protobuf/spark/connect/types.proto +++ b/sql/connect/common/src/main/protobuf/spark/connect/types.proto @@ -147,11 +147,13 @@ message DataType { message Char { int32 length = 1; uint32 type_variation_reference = 2; + string collation = 3; // NOLINT } message VarChar { int32 length = 1; uint32 type_variation_reference = 2; + string collation = 3; // NOLINT } message Decimal { diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala index 8c83ad3d1f550..d4e9b79896686 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala @@ -175,16 +175,26 @@ object 
DataTypeProtoConverter { proto.DataType.Decimal.newBuilder().setPrecision(precision).setScale(scale).build()) .build() - case CharType(length) => + case CharType(length, collationId) => proto.DataType .newBuilder() - .setChar(proto.DataType.Char.newBuilder().setLength(length).build()) + .setChar( + proto.DataType.Char + .newBuilder() + .setLength(length) + .setCollation(CollationFactory.fetchCollation(collationId).collationName) + .build()) .build() - case VarcharType(length) => + case VarcharType(length, collationId) => proto.DataType .newBuilder() - .setVarChar(proto.DataType.VarChar.newBuilder().setLength(length).build()) + .setVarChar( + proto.DataType.VarChar + .newBuilder() + .setLength(length) + .setCollation(CollationFactory.fetchCollation(collationId).collationName) + .build()) .build() // StringType must be matched after CharType and VarcharType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index 66548404684a8..f6aebdd849eee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -142,7 +142,7 @@ case class AnalyzeColumnCommand( case DoubleType | FloatType => true case BooleanType => true case _: DatetimeType => true - case CharType(_) | VarcharType(_) => false + case CharType(_, _) | VarcharType(_, _) => false case BinaryType | _: StringType => true case _ => false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 0077012e2b0e4..3d2916b9cfd81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -153,10 +153,10 @@ object JdbcUtils extends Logging with SQLConfHelper { case ShortType => Option(JdbcType("INTEGER", java.sql.Types.SMALLINT)) case ByteType => Option(JdbcType("BYTE", java.sql.Types.TINYINT)) case BooleanType => Option(JdbcType("BIT(1)", java.sql.Types.BIT)) + case CharType(n, _) => Option(JdbcType(s"CHAR($n)", java.sql.Types.CHAR)) + case VarcharType(n, _) => Option(JdbcType(s"VARCHAR($n)", java.sql.Types.VARCHAR)) case StringType => Option(JdbcType("TEXT", java.sql.Types.CLOB)) case BinaryType => Option(JdbcType("BLOB", java.sql.Types.BLOB)) - case CharType(n) => Option(JdbcType(s"CHAR($n)", java.sql.Types.CHAR)) - case VarcharType(n) => Option(JdbcType(s"VARCHAR($n)", java.sql.Types.VARCHAR)) case TimestampType => Option(JdbcType("TIMESTAMP", java.sql.Types.TIMESTAMP)) // This is a common case of timestamp without time zone. Most of the databases either only // support TIMESTAMP type or use TIMESTAMP as an alias for TIMESTAMP WITHOUT TIME ZONE. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala index a9f6a727a7241..9ce4c5ab4aca8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala @@ -150,7 +150,7 @@ private case class OracleDialect() extends JdbcDialect with SQLConfHelper with N case ByteType => Some(JdbcType("NUMBER(3)", java.sql.Types.SMALLINT)) case ShortType => Some(JdbcType("NUMBER(5)", java.sql.Types.SMALLINT)) case StringType => Some(JdbcType("VARCHAR2(255)", java.sql.Types.VARCHAR)) - case VarcharType(n) => Some(JdbcType(s"VARCHAR2($n)", java.sql.Types.VARCHAR)) + case VarcharType(n, _) => Some(JdbcType(s"VARCHAR2($n)", java.sql.Types.VARCHAR)) case TimestampType if !conf.legacyOracleTimestampMappingEnabled => Some(JdbcType("TIMESTAMP WITH LOCAL TIME ZONE", TIMESTAMP_LTZ)) case _ => None diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 2e99a9afa3042..3e58e9de32053 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -190,9 +190,9 @@ trait CreatableRelationProvider { case MapType(k, v, _) => supportsDataType(k) && supportsDataType(v) case StructType(fields) => fields.forall(f => supportsDataType(f.dataType)) case udt: UserDefinedType[_] => supportsDataType(udt.sqlType) - case BinaryType | BooleanType | ByteType | CharType(_) | DateType | _ : DecimalType | - DoubleType | FloatType | IntegerType | LongType | NullType | ObjectType(_) | ShortType | - _: StringType | TimestampNTZType | TimestampType | VarcharType(_) => true + case BinaryType | BooleanType | ByteType | CharType(_, _) | VarcharType(_, _) | DateType | + _ : DecimalType | DoubleType | FloatType | IntegerType | 
LongType | NullType | + ObjectType(_) | ShortType | _: StringType | TimestampNTZType | TimestampType => true case _ => false } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala index 6139d0e987676..887459a02b5a7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala @@ -46,13 +46,13 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils { val dataType = CatalystSqlParser.parseDataType(dt) checkColType(df.schema(1), dataType) dataType match { - case CharType(len) => + case CharType(len, _) => // char value will be padded if (<= len) or trimmed if (> len) val fixLenStr = if (insertVal != null) { insertVal.take(len).padTo(len, " ").mkString } else null checkAnswer(df, Row("1", fixLenStr)) - case VarcharType(len) => + case VarcharType(len, _) => // varchar value will be remained if (<= len) or trimmed if (> len) val varLenStrWithUpperBound = if (insertVal != null) { insertVal.take(len) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GenTPCDSData.scala b/sql/core/src/test/scala/org/apache/spark/sql/GenTPCDSData.scala index 6cd8ade41da14..1e9e759cd49d9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/GenTPCDSData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/GenTPCDSData.scala @@ -163,7 +163,7 @@ class TPCDSTables(spark: SparkSession, dsdgenDir: String, scaleFactor: Int) val columns = schema.fields.map { f => val c = f.dataType match { // Needs right-padding for char types - case CharType(n) => rpad(Column(f.name), n, " ") + case CharType(n, _) => rpad(Column(f.name), n, " ") // Don't need a cast for varchar types case _: VarcharType => col(f.name) case _ => col(f.name).cast(f.dataType) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetColumnsOperation.scala 
b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetColumnsOperation.scala index 4560856cb0634..7d39be2045266 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetColumnsOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetColumnsOperation.scala @@ -134,7 +134,7 @@ private[hive] class SparkGetColumnsOperation( case dt @ (BooleanType | _: NumericType | DateType | TimestampType | TimestampNTZType | CalendarIntervalType | NullType | _: AnsiIntervalType) => Some(dt.defaultSize) - case CharType(n) => Some(n) + case CharType(n, _) => Some(n) case StructType(fields) => val sizeArr = fields.map(f => getColumnSize(f.dataType)) if (sizeArr.contains(None)) {