Skip to content

[DO NOT REVIEW] temp #51501

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1342,8 +1342,8 @@ collateClause

nonTrivialPrimitiveType
: STRING collateClause?
| (CHARACTER | CHAR) (LEFT_PAREN length=INTEGER_VALUE RIGHT_PAREN)?
| VARCHAR (LEFT_PAREN length=INTEGER_VALUE RIGHT_PAREN)?
| (CHARACTER | CHAR) (LEFT_PAREN length=INTEGER_VALUE RIGHT_PAREN)? collateClause?
| VARCHAR (LEFT_PAREN length=INTEGER_VALUE RIGHT_PAREN)? collateClause?
| (DECIMAL | DEC | NUMERIC)
(LEFT_PAREN precision=INTEGER_VALUE (COMMA scale=INTEGER_VALUE)? RIGHT_PAREN)?
| INTERVAL
Expand Down
35 changes: 31 additions & 4 deletions sql/api/src/main/java/org/apache/spark/sql/types/DataTypes.java
Original file line number Diff line number Diff line change
Expand Up @@ -261,20 +261,47 @@ public static StructType createStructType(StructField[] fields) {
}

/**
* Creates a CharType with the given length.
* Creates a StringType with the given collation.
*
* @since 4.1.0
*/
public static StringType createStringType(int collation) {
return new StringType(collation, NoConstraint$.MODULE$);
}

/**
* Creates a CharType with the given length and `UTF8_BINARY` collation.
*
* @since 4.0.0
*/
public static CharType createCharType(int length) {
return new CharType(length);
return new DefaultCharType(length);
}

/**
* Creates a VarcharType with the given length.
* Creates a CharType with the given length and collation.
*
* @since 4.1.0
*/
public static CharType createCharType(int length, int collationId) {
return new CharType(length, collationId);
}

/**
* Creates a VarcharType with the given length and `UTF8_BINARY` collation.
*
* @since 4.0.0
*/
public static VarcharType createVarcharType(int length) {
return new VarcharType(length);
return new DefaultVarcharType(length);
}

/**
* Creates a VarcharType with the given length and collation.
*
* @since 4.1.0
*/
public static VarcharType createVarcharType(int length, int collationId) {
return new VarcharType(length, collationId);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,9 @@ object RowEncoder extends DataTypeErrorsBase {
case DoubleType => BoxedDoubleEncoder
case dt: DecimalType => JavaDecimalEncoder(dt, lenientSerialization = true)
case BinaryType => BinaryEncoder
case CharType(length) if SqlApiConf.get.preserveCharVarcharTypeInfo =>
case CharType(length, _) if SqlApiConf.get.preserveCharVarcharTypeInfo =>
CharEncoder(length)
case VarcharType(length) if SqlApiConf.get.preserveCharVarcharTypeInfo =>
case VarcharType(length, _) if SqlApiConf.get.preserveCharVarcharTypeInfo =>
VarcharEncoder(length)
case s: StringType if StringHelper.isPlainString(s) => StringEncoder
case TimestampType if SqlApiConf.get.datetimeJava8ApiEnabled => InstantEncoder(lenient)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,17 @@
package org.apache.spark.sql.catalyst.parser

import java.util.Locale

import scala.jdk.CollectionConverters._

import org.antlr.v4.runtime.Token
import org.antlr.v4.runtime.tree.ParseTree

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.parser.SqlBaseParser._
import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.sql.catalyst.util.SparkParserUtils.{string, withOrigin}
import org.apache.spark.sql.connector.catalog.IdentityColumnSpec
import org.apache.spark.sql.errors.QueryParsingErrors
import org.apache.spark.sql.internal.SqlApiConf
import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, CalendarIntervalType, CharType, DataType, DateType, DayTimeIntervalType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, MetadataBuilder, NullType, ShortType, StringType, StructField, StructType, TimestampNTZType, TimestampType, TimeType, VarcharType, VariantType, YearMonthIntervalType}
import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, CalendarIntervalType, CharType, DataType, DateType, DayTimeIntervalType, DecimalType, DefaultCharType, DefaultVarcharType, DoubleType, FloatType, IntegerType, LongType, MapType, MetadataBuilder, NullType, ShortType, StringType, StructField, StructType, TimeType, TimestampNTZType, TimestampType, VarcharType, VariantType, YearMonthIntervalType}

class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] {
protected def typedVisit[T](ctx: ParseTree): T = {
Expand Down Expand Up @@ -84,13 +81,35 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] {
StringType(collationId)
}
case CHARACTER | CHAR =>
if (currentCtx.length == null) {
throw QueryParsingErrors.charVarcharTypeMissingLengthError(typeCtx.getText, ctx)
} else CharType(currentCtx.length.getText.toInt)
(Option(currentCtx.length), Option(currentCtx.collateClause)) match {
case (None, _) =>
throw QueryParsingErrors.charVarcharTypeMissingLengthError(typeCtx.getText, ctx)
case (_, Some(_)) if !SqlApiConf.get.charVarcharCollationsEnabled =>
// TODO: throw correct error
throw QueryParsingErrors.charVarcharTypeMissingLengthError(typeCtx.getText, ctx)
case (Some(length), None) =>
DefaultCharType(length.getText.toInt)
case (Some(length), Some(collateClause)) =>
val collationNameParts = visitCollateClause(collateClause).toArray
val collationId = CollationFactory.collationNameToId(
CollationFactory.resolveFullyQualifiedName(collationNameParts))
CharType(length.getText.toInt, collationId)
}
case VARCHAR =>
if (currentCtx.length == null) {
throw QueryParsingErrors.charVarcharTypeMissingLengthError(typeCtx.getText, ctx)
} else VarcharType(currentCtx.length.getText.toInt)
(Option(currentCtx.length), Option(currentCtx.collateClause)) match {
case (None, _) =>
throw QueryParsingErrors.charVarcharTypeMissingLengthError(typeCtx.getText, ctx)
case (_, Some(_)) if !SqlApiConf.get.charVarcharCollationsEnabled =>
// TODO: throw correct error
throw QueryParsingErrors.charVarcharTypeMissingLengthError(typeCtx.getText, ctx)
case (Some(length), None) =>
DefaultVarcharType(length.getText.toInt)
case (Some(length), Some(collateClause)) =>
val collationNameParts = visitCollateClause(collateClause).toArray
val collationId = CollationFactory.collationNameToId(
CollationFactory.resolveFullyQualifiedName(collationNameParts))
VarcharType(length.getText.toInt, collationId)
}
case DECIMAL | DEC | NUMERIC =>
if (currentCtx.precision == null) {
DecimalType.USER_DEFAULT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.util

import org.apache.spark.sql.errors.DataTypeErrors
import org.apache.spark.sql.internal.SqlApiConf
import org.apache.spark.sql.types.{ArrayType, CharType, DataType, MapType, StringType, StructType, VarcharType}
import org.apache.spark.sql.types.{ArrayType, CharType, DataType, DefaultCharType, DefaultVarcharType, MapType, StringType, StructType, VarcharType}

trait SparkCharVarcharUtils {

Expand Down Expand Up @@ -54,7 +54,28 @@ trait SparkCharVarcharUtils {
StructType(fields.map { field =>
field.copy(dataType = replaceCharVarcharWithString(field.dataType))
})
case CharType(_) | VarcharType(_) if !SqlApiConf.get.preserveCharVarcharTypeInfo => StringType
case DefaultCharType(_) if !SqlApiConf.get.preserveCharVarcharTypeInfo =>
if (!QueryTagger.getActiveTagsInLocalThread.contains(QueryTag.CHAR_TAG)) {
QueryTagger.addTags(Seq(QueryTag.CHAR_TAG))
}
StringType
case CharType(_, collationId) if !SqlApiConf.get.preserveCharVarcharTypeInfo =>
if (!QueryTagger.getActiveTagsInLocalThread.contains(QueryTag.CHAR_TAG)) {
QueryTagger.addTags(Seq(QueryTag.CHAR_TAG))
}
if (SqlApiConf.get.charVarcharCollationsEnabled) StringType(collationId)
else StringType
case DefaultVarcharType(_) if !SqlApiConf.get.preserveCharVarcharTypeInfo =>
if (!QueryTagger.getActiveTagsInLocalThread.contains(QueryTag.VARCHAR_TAG)) {
QueryTagger.addTags(Seq(QueryTag.VARCHAR_TAG))
}
StringType
case VarcharType(_, collationId) if !SqlApiConf.get.preserveCharVarcharTypeInfo =>
if (!QueryTagger.getActiveTagsInLocalThread.contains(QueryTag.VARCHAR_TAG)) {
QueryTagger.addTags(Seq(QueryTag.VARCHAR_TAG))
}
if (SqlApiConf.get.charVarcharCollationsEnabled) StringType(collationId)
else StringType
case _ => dt
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ private[sql] trait SqlApiConf {
def allowNegativeScaleOfDecimalEnabled: Boolean
def charVarcharAsString: Boolean
def preserveCharVarcharTypeInfo: Boolean
def charVarcharCollationsEnabled: Boolean
def datetimeJava8ApiEnabled: Boolean
def sessionLocalTimeZone: String
def legacyTimeParserPolicy: LegacyBehaviorPolicy.Value
Expand Down Expand Up @@ -82,6 +83,7 @@ private[sql] object DefaultSqlApiConf extends SqlApiConf {
override def allowNegativeScaleOfDecimalEnabled: Boolean = false
override def charVarcharAsString: Boolean = false
override def preserveCharVarcharTypeInfo: Boolean = false
override def charVarcharCollationsEnabled: Boolean = false
override def datetimeJava8ApiEnabled: Boolean = false
override def sessionLocalTimeZone: String = TimeZone.getDefault.getID
override def legacyTimeParserPolicy: LegacyBehaviorPolicy.Value = LegacyBehaviorPolicy.CORRECTED
Expand Down
30 changes: 25 additions & 5 deletions sql/api/src/main/scala/org/apache/spark/sql/types/CharType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,33 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.catalyst.util.CollationFactory

@Experimental
case class CharType(length: Int)
extends StringType(CollationFactory.UTF8_BINARY_COLLATION_ID, FixedLength(length)) {
case class CharType(length: Int, override val collationId: Int = 0)
extends StringType(collationId, FixedLength(length)) {
require(length >= 0, "The length of char type cannot be negative.")

override def defaultSize: Int = length
override def typeName: String = s"char($length)"
override def jsonValue: JValue = JString(typeName)
override def toString: String = s"CharType($length)"
override def typeName: String =
if (isUTF8BinaryCollation) s"char($length)"
else s"char($length) collate $collationName"
override def jsonValue: JValue = JString(s"char($length)")
override def toString: String =
if (isUTF8BinaryCollation) s"CharType($length)"
else s"CharType($length, $collationName)"
private[spark] override def asNullable: CharType = this
}

/**
 * A variant of [[CharType]] for columns declared without an explicit collation.
 * It implicitly carries the `UTF8_BINARY` collation, so its textual forms omit
 * the `COLLATE` clause — unlike a [[CharType]] whose collation was spelled out
 * by the user.
 */
@Experimental
class DefaultCharType(override val length: Int)
  extends CharType(length, CollationFactory.UTF8_BINARY_COLLATION_ID) {
  // No collation was written in the source type, so render the plain forms.
  // (Previously these overrides appended "collate UTF8_BINARY", contradicting
  // both this class's stated purpose and CharType's UTF8_BINARY formatting.)
  override def typeName: String = s"char($length)"
  override def toString: String = s"CharType($length)"
}

@Experimental
object DefaultCharType {
  /** Builds a [[DefaultCharType]] of the given length. */
  def apply(length: Int): DefaultCharType = new DefaultCharType(length)

  /** Extractor yielding the length of a [[DefaultCharType]]. */
  def unapply(charType: DefaultCharType): Option[Int] = Option(charType.length)
}
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,11 @@ object DataType {
fieldPath: String,
fieldType: String,
collationMap: Map[String, String]): Unit = {
if (collationMap.contains(fieldPath) && fieldType != "string") {
val isValidType = fieldType match {
case "string" | CHAR_TYPE(_) | VARCHAR_TYPE(_) => true
case _ => false
}
if (collationMap.contains(fieldPath) && !isValidType) {
throw new SparkIllegalArgumentException(
errorClass = "INVALID_JSON_DATA_TYPE_FOR_COLLATIONS",
messageParameters = Map("jsonType" -> fieldType))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,8 @@ case object StringHelper extends PartialOrdering[StringConstraint] {
}

def removeCollation(s: StringType): StringType = s match {
case CharType(length) => CharType(length)
case VarcharType(length) => VarcharType(length)
case CharType(length, _) => DefaultCharType(length)
case VarcharType(length, _) => DefaultVarcharType(length)
case _: StringType => StringType
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,34 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.catalyst.util.CollationFactory

@Experimental
case class VarcharType(length: Int)
extends StringType(CollationFactory.UTF8_BINARY_COLLATION_ID, MaxLength(length)) {
case class VarcharType(length: Int, override val collationId: Int = 0)
extends StringType(collationId, MaxLength(length)) {
require(length >= 0, "The length of varchar type cannot be negative.")

override def defaultSize: Int = length
override def typeName: String = s"varchar($length)"
override def jsonValue: JValue = JString(typeName)
override def toString: String = s"VarcharType($length)"
override def typeName: String =
if (isUTF8BinaryCollation) s"varchar($length)"
else s"varchar($length) collate $collationName"
/** [[jsonValue]] does not have collation, same as for [[StringType]] */
override def jsonValue: JValue = JString(s"varchar($length)")
override def toString: String =
if (isUTF8BinaryCollation) s"VarcharType($length)"
else s"VarcharType($length, $collationName)"
private[spark] override def asNullable: VarcharType = this
}

/**
 * A variant of [[VarcharType]] for columns declared without an explicit
 * collation. It implicitly carries the `UTF8_BINARY` collation, so its textual
 * forms omit the `COLLATE` clause — unlike a [[VarcharType]] whose collation
 * was spelled out by the user.
 */
@Experimental
class DefaultVarcharType(override val length: Int)
  extends VarcharType(length, CollationFactory.UTF8_BINARY_COLLATION_ID) {
  // No collation was written in the source type, so render the plain forms.
  // (Previously these overrides appended "collate UTF8_BINARY", contradicting
  // both this class's stated purpose and VarcharType's UTF8_BINARY formatting.)
  override def typeName: String = s"varchar($length)"
  override def toString: String = s"VarcharType($length)"
}

@Experimental
object DefaultVarcharType {
  /** Builds a [[DefaultVarcharType]] of the given length. */
  def apply(length: Int): DefaultVarcharType = new DefaultVarcharType(length)

  /** Extractor yielding the length of a [[DefaultVarcharType]]. */
  def unapply(varcharType: DefaultVarcharType): Option[Int] =
    Option(varcharType.length)
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ object CatalystTypeConverters {
case arrayType: ArrayType => ArrayConverter(arrayType.elementType)
case mapType: MapType => MapConverter(mapType.keyType, mapType.valueType)
case structType: StructType => StructConverter(structType)
case CharType(length) => new CharConverter(length)
case VarcharType(length) => new VarcharConverter(length)
case CharType(length, _) => new CharConverter(length)
case VarcharType(length, _) => new VarcharConverter(length)
case _: StringType => StringConverter
case DateType if SQLConf.get.datetimeJava8ApiEnabled => LocalDateConverter
case DateType => DateConverter
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ object ApplyCharTypePaddingHelper {
CharVarcharUtils
.getRawType(attr.metadata)
.flatMap {
case CharType(length) =>
case CharType(length, _) =>
val (nulls, literalChars) =
list.map(_.eval().asInstanceOf[UTF8String]).partition(_ == null)
val literalCharLengths = literalChars.map(_.numChars())
Expand Down Expand Up @@ -164,7 +164,7 @@ object ApplyCharTypePaddingHelper {
lit: Expression): Option[Seq[Expression]] = {
if (expr.dataType == StringType) {
CharVarcharUtils.getRawType(metadata).flatMap {
case CharType(length) =>
case CharType(length, _) =>
val str = lit.eval().asInstanceOf[UTF8String]
if (str == null) {
None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1090,9 +1090,9 @@ trait CheckAnalysis extends LookupCatalog with QueryErrorsBase with PlanToString

// We don't need to handle nested types here which shall fail before.
def canAlterColumnType(from: DataType, to: DataType): Boolean = (from, to) match {
case (CharType(l1), CharType(l2)) => l1 == l2
case (CharType(l1), VarcharType(l2)) => l1 <= l2
case (VarcharType(l1), VarcharType(l2)) => l1 <= l2
case (CharType(l1, _), CharType(l2, _)) => l1 == l2
case (CharType(l1, _), VarcharType(l2, _)) => l1 <= l2
case (VarcharType(l1, _), VarcharType(l2, _)) => l1 <= l2
case _ => Cast.canUpCast(from, to)
}
if (!canAlterColumnType(field.dataType, newDataType)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ object Literal {
case _: DayTimeIntervalType if v.isInstanceOf[Duration] =>
Literal(CatalystTypeConverters.createToCatalystConverter(dataType)(v), dataType)
case _: ObjectType => Literal(v, dataType)
case CharType(_) | VarcharType(_) if SQLConf.get.preserveCharVarcharTypeInfo =>
case CharType(_, _) | VarcharType(_, _) if SQLConf.get.preserveCharVarcharTypeInfo =>
Literal(CatalystTypeConverters.createToCatalystConverter(dataType)(v), dataType)
case _ => Literal(CatalystTypeConverters.convertToCatalyst(v), dataType)
}
Expand Down Expand Up @@ -203,10 +203,10 @@ object Literal {
case t: TimeType => create(0L, t)
case it: DayTimeIntervalType => create(0L, it)
case it: YearMonthIntervalType => create(0, it)
case CharType(length) =>
case CharType(length, _) =>
create(CharVarcharCodegenUtils.charTypeWriteSideCheck(UTF8String.fromString(""), length),
dataType)
case VarcharType(length) =>
case VarcharType(length, _) =>
create(CharVarcharCodegenUtils.varcharTypeWriteSideCheck(UTF8String.fromString(""), length),
dataType)
case st: StringType => Literal(UTF8String.fromString(""), st)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ case object VariantGet {
*/
def checkDataType(dataType: DataType, allowStructsAndMaps: Boolean = true): Boolean =
dataType match {
case CharType(_) | VarcharType(_) => false
case CharType(_, _) | VarcharType(_, _) => false
case _: NumericType | BooleanType | _: StringType | BinaryType | _: DatetimeType |
VariantType =>
true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ object PhysicalDataType {
case ShortType => PhysicalShortType
case IntegerType => PhysicalIntegerType
case LongType => PhysicalLongType
case VarcharType(_) => PhysicalStringType(StringType.collationId)
case CharType(_) => PhysicalStringType(StringType.collationId)
case VarcharType(_, collationId) => PhysicalStringType(collationId)
case CharType(_, collationId) => PhysicalStringType(collationId)
case s: StringType => PhysicalStringType(s.collationId)
case FloatType => PhysicalFloatType
case DoubleType => PhysicalDoubleType
Expand Down
Loading