diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
index 4bcd75a731059..59c23064858d0 100644
--- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
+++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
@@ -22,7 +22,6 @@
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.function.Function;
 import java.util.function.BiFunction;
-import java.util.function.ToLongFunction;
 import java.util.stream.Stream;
 
 import com.ibm.icu.text.CollationKey;
@@ -125,10 +124,19 @@ public static class Collation {
     public final String version;
 
     /**
-     * Collation sensitive hash function. Output for two UTF8Strings will be the same if they are
-     * equal according to the collation.
+     * Returns the sort key of the input UTF8String. Two UTF8String values are equal iff their
+     * sort keys are equal (compared as byte arrays).
+     * For collations without the RTRIM modifier, the sort key is defined as follows:
+     * - UTF8_BINARY: the bytes of the string.
+     * - UTF8_LCASE: the byte array obtained by replacing all invalid UTF8 sequences with the
+     *   Unicode replacement character and then mapping every character of the resulting string
+     *   to its lowercase equivalent (the Greek capital and Greek small sigma both map to
+     *   the Greek final sigma).
+     * - ICU collations: the byte array returned by the ICU library for the collated string.
+     * For strings with the RTRIM modifier, we right-trim the string and return the sort key
+     * of the resulting right-trimmed string.
      */
-    public final ToLongFunction<UTF8String> hashFunction;
+    public final Function<UTF8String, byte[]> sortKeyFunction;
 
     /**
      * Potentially faster way than using comparator to compare two UTF8Strings for equality.
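
To make the new contract concrete, here is a minimal sketch (not part of the patch) of how `sortKeyFunction` behaves once the change is applied, using the existing `CollationFactory.fetchCollation` lookup by collation name. Strings that are equal under the collation must produce byte-identical sort keys:

```scala
import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.unsafe.types.UTF8String

val lcase = CollationFactory.fetchCollation("UTF8_LCASE")
// "aaa" and "AAA" are equal under UTF8_LCASE, so their sort keys match byte-for-byte.
val k1 = lcase.sortKeyFunction.apply(UTF8String.fromString("aaa"))
val k2 = lcase.sortKeyFunction.apply(UTF8String.fromString("AAA"))
assert(k1.sameElements(k2))
```
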
@@ -182,7 +190,7 @@ public Collation(
       Collator collator,
       Comparator<UTF8String> comparator,
       String version,
-      ToLongFunction<UTF8String> hashFunction,
+      Function<UTF8String, byte[]> sortKeyFunction,
       BiFunction<UTF8String, UTF8String, Boolean> equalsFunction,
       boolean isUtf8BinaryType,
       boolean isUtf8LcaseType,
@@ -192,7 +200,7 @@ public Collation(
       this.collator = collator;
       this.comparator = comparator;
       this.version = version;
-      this.hashFunction = hashFunction;
+      this.sortKeyFunction = sortKeyFunction;
       this.isUtf8BinaryType = isUtf8BinaryType;
       this.isUtf8LcaseType = isUtf8LcaseType;
       this.equalsFunction = equalsFunction;
@@ -581,18 +589,18 @@ private static boolean isValidCollationId(int collationId) {
     protected Collation buildCollation() {
       if (caseSensitivity == CaseSensitivity.UNSPECIFIED) {
         Comparator<UTF8String> comparator;
-        ToLongFunction<UTF8String> hashFunction;
+        Function<UTF8String, byte[]> sortKeyFunction;
         BiFunction<UTF8String, UTF8String, Boolean> equalsFunction;
         boolean supportsSpaceTrimming = spaceTrimming != SpaceTrimming.NONE;
 
         if (spaceTrimming == SpaceTrimming.NONE) {
           comparator = UTF8String::binaryCompare;
-          hashFunction = s -> (long) s.hashCode();
+          sortKeyFunction = s -> s.getBytes();
           equalsFunction = UTF8String::equals;
         } else {
           comparator = (s1, s2) -> applyTrimmingPolicy(s1, spaceTrimming).binaryCompare(
             applyTrimmingPolicy(s2, spaceTrimming));
-          hashFunction = s -> (long) applyTrimmingPolicy(s, spaceTrimming).hashCode();
+          sortKeyFunction = s -> applyTrimmingPolicy(s, spaceTrimming).getBytes();
           equalsFunction = (s1, s2) -> applyTrimmingPolicy(s1, spaceTrimming).equals(
             applyTrimmingPolicy(s2, spaceTrimming));
         }
@@ -603,25 +611,25 @@ protected Collation buildCollation() {
           null,
           comparator,
           CollationSpecICU.ICU_VERSION,
-          hashFunction,
+          sortKeyFunction,
           equalsFunction,
           /* isUtf8BinaryType = */ true,
           /* isUtf8LcaseType = */ false,
           spaceTrimming != SpaceTrimming.NONE);
       } else {
         Comparator<UTF8String> comparator;
-        ToLongFunction<UTF8String> hashFunction;
+        Function<UTF8String, byte[]> sortKeyFunction;
 
         if (spaceTrimming == SpaceTrimming.NONE) {
           comparator = CollationAwareUTF8String::compareLowerCase;
-          hashFunction = s ->
-            (long) CollationAwareUTF8String.lowerCaseCodePoints(s).hashCode();
+          sortKeyFunction = s ->
+            CollationAwareUTF8String.lowerCaseCodePoints(s).getBytes();
         } else {
           comparator = (s1, s2) -> CollationAwareUTF8String.compareLowerCase(
             applyTrimmingPolicy(s1, spaceTrimming),
             applyTrimmingPolicy(s2, spaceTrimming));
-          hashFunction = s -> (long) CollationAwareUTF8String.lowerCaseCodePoints(
-            applyTrimmingPolicy(s, spaceTrimming)).hashCode();
+          sortKeyFunction = s -> CollationAwareUTF8String.lowerCaseCodePoints(
+            applyTrimmingPolicy(s, spaceTrimming)).getBytes();
         }
 
         return new Collation(
@@ -630,7 +638,7 @@ protected Collation buildCollation() {
           null,
           comparator,
           CollationSpecICU.ICU_VERSION,
-          hashFunction,
+          sortKeyFunction,
           (s1, s2) -> comparator.compare(s1, s2) == 0,
           /* isUtf8BinaryType = */ false,
           /* isUtf8LcaseType = */ true,
@@ -1013,19 +1021,18 @@ protected Collation buildCollation() {
       collator.freeze();
 
       Comparator<UTF8String> comparator;
-      ToLongFunction<UTF8String> hashFunction;
+      Function<UTF8String, byte[]> sortKeyFunction;
 
       if (spaceTrimming == SpaceTrimming.NONE) {
-        hashFunction = s -> (long) collator.getCollationKey(
-          s.toValidString()).hashCode();
         comparator = (s1, s2) ->
           collator.compare(s1.toValidString(), s2.toValidString());
+        sortKeyFunction = s -> collator.getCollationKey(s.toValidString()).toByteArray();
       } else {
         comparator = (s1, s2) -> collator.compare(
           applyTrimmingPolicy(s1, spaceTrimming).toValidString(),
           applyTrimmingPolicy(s2, spaceTrimming).toValidString());
-        hashFunction = s -> (long) collator.getCollationKey(
-          applyTrimmingPolicy(s, spaceTrimming).toValidString()).hashCode();
+        sortKeyFunction = s -> collator.getCollationKey(
+          applyTrimmingPolicy(s, spaceTrimming).toValidString()).toByteArray();
       }
 
       return new Collation(
@@ -1034,7 +1041,7 @@ protected Collation buildCollation() {
         collator,
         comparator,
         ICU_VERSION,
-        hashFunction,
+        sortKeyFunction,
         (s1, s2) -> comparator.compare(s1, s2) == 0,
         /* isUtf8BinaryType = */ false,
         /* isUtf8LcaseType = */ false,
diff --git a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala
index 8e9d33efe7a6d..ef1687f4376d3 100644
--- a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala
+++ b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala
@@ -139,7 +139,7 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig
 
   case class CollationTestCase[R](collationName: String, s1: String, s2: String, expectedResult: R)
 
-  test("collation aware equality and hash") {
+  test("collation aware equality and sort key") {
     val checks = Seq(
       CollationTestCase("UTF8_BINARY", "aaa", "aaa", true),
       CollationTestCase("UTF8_BINARY", "aaa", "AAA", false),
@@ -194,9 +194,9 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig
       assert(collation.equalsFunction(toUTF8(testCase.s1), toUTF8(testCase.s2)) ==
         testCase.expectedResult)
 
-      val hash1 = collation.hashFunction.applyAsLong(toUTF8(testCase.s1))
-      val hash2 = collation.hashFunction.applyAsLong(toUTF8(testCase.s2))
-      assert((hash1 == hash2) == testCase.expectedResult)
+      val sortKey1 = collation.sortKeyFunction.apply(toUTF8(testCase.s1)).asInstanceOf[Array[Byte]]
+      val sortKey2 = collation.sortKeyFunction.apply(toUTF8(testCase.s2)).asInstanceOf[Array[Byte]]
+      assert(sortKey1.sameElements(sortKey2) == testCase.expectedResult)
     })
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
index 7cb645e601d36..103b881005eb9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
@@ -38,6 +38,7 @@
 import org.apache.spark.sql.catalyst.util.DateTimeConstants._
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.SchemaUtils
 import org.apache.spark.unsafe.Platform
 import org.apache.spark.unsafe.hash.Murmur3_x86_32
 import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
@@ -275,6 +276,8 @@ abstract class HashExpression[E] extends Expression {
 
   override def nullable: Boolean = false
 
+  protected def isCollationAware: Boolean
+
   private def hasMapType(dt: DataType): Boolean = {
     dt.existsRecursively(_.isInstanceOf[MapType])
   }
@@ -421,6 +424,9 @@ abstract class HashExpression[E] extends Expression {
     s"$result = $hasherClassName.hashInt($input.months, $microsecondsHash);"
   }
 
+  private lazy val legacyCollationAwareHashing: Boolean =
+    SQLConf.get.getConf(SQLConf.COLLATION_AWARE_HASHING_ENABLED)
+
   protected def genHashString(
       ctx: CodegenContext, stringType: StringType, input: String, result: String): String = {
     if (stringType.supportsBinaryEquality) {
@@ -429,14 +435,43 @@ abstract class HashExpression[E] extends Expression {
       val numBytes = s"$input.numBytes()"
s"$result = $hasherClassName.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $result);" } else { - val stringHash = ctx.freshName("stringHash") - s""" - long $stringHash = CollationFactory.fetchCollation(${stringType.collationId}) - .hashFunction.applyAsLong($input); - $result = $hasherClassName.hashLong($stringHash, $result); - """ + if (isCollationAware) { + val key = ctx.freshName("key") + val offset = "Platform.BYTE_ARRAY_OFFSET" + s""" + byte[] $key = (byte[]) CollationFactory.fetchCollation(${stringType.collationId}) + .sortKeyFunction.apply($input); + $result = $hasherClassName.hashUnsafeBytes($key, $offset, $key.length, $result); + """ + } else if (legacyCollationAwareHashing) { + val collation = CollationFactory.fetchCollation(stringType.collationId) + val stringHash = ctx.freshName("stringHash") + if (collation.isUtf8BinaryType || collation.isUtf8LcaseType) { + s""" + long $stringHash = UTF8String.fromBytes((byte[]) CollationFactory + .fetchCollation(${stringType.collationId}).sortKeyFunction.apply($input)).hashCode(); + $result = $hasherClassName.hashLong($stringHash, $result); + """ + } else if (collation.supportsSpaceTrimming) { + s""" + long $stringHash = CollationFactory.fetchCollation(${stringType.collationId}) + .getCollator().getCollationKey($input.trimRight().toValidString()).hashCode(); + $result = $hasherClassName.hashLong($stringHash, $result); + """ + } else { + s""" + long $stringHash = CollationFactory.fetchCollation(${stringType.collationId}) + .getCollator().getCollationKey($input.toValidString()).hashCode(); + $result = $hasherClassName.hashLong($stringHash, $result); + """ + } + } else { + val baseObject = s"$input.getBaseObject()" + val baseOffset = s"$input.getBaseOffset()" + val numBytes = s"$input.numBytes()" + s"$result = $hasherClassName.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $result);" + } } - } protected def genHashForMap( @@ -545,11 +580,32 @@ abstract class InterpretedHashFunction { protected def hashUnsafeBytes(base: AnyRef, offset: Long, length: Int, seed: Long): Long + private lazy val legacyCollationAwareHashing: Boolean = + SQLConf.get.getConf(SQLConf.COLLATION_AWARE_HASHING_ENABLED) + /** - * Computes hash of a given `value` of type `dataType`. The caller needs to check the validity - * of input `value`. + * This method is intended for callers using the old hash API and preserves compatibility for + * supported data types. It must only be used for data types that do not include collated strings + * or complex types (e.g., structs, arrays, maps) that may contain collated strings. + * + * The caller is responsible for ensuring that `dataType` does not involve collation-aware fields. + * This is validated via an internal assertion. + * + * @throws IllegalArgumentException if `dataType` contains non-UTF8 binary collation. */ def hash(value: Any, dataType: DataType, seed: Long): Long = { + require(!SchemaUtils.hasNonUTF8BinaryCollation(dataType)) + // For UTF8_BINARY, hashing behavior is the same regardless of the isCollationAware flag. + hash(value = value, dataType = dataType, seed = seed, isCollationAware = false) + } + + /** + * Computes hash of a given `value` of type `dataType`. The caller needs to check the validity + * of input `value`. The `isCollationAware` boolean flag indicates whether hashing should take + * a string's collation into account. If not, the bytes of the string are hashed, otherwise the + * collation key of the string is hashed. 
+  /**
+   * Computes hash of a given `value` of type `dataType`. The caller needs to check the validity
+   * of input `value`. The `isCollationAware` flag indicates whether hashing should take a
+   * string's collation into account: if not, the raw bytes of the string are hashed; otherwise
+   * the collation sort key of the string is hashed.
+   */
+  def hash(value: Any, dataType: DataType, seed: Long, isCollationAware: Boolean): Long = {
     value match {
       case null => seed
       case b: Boolean => hashInt(if (b) 1 else 0, seed)
@@ -575,12 +631,25 @@ abstract class InterpretedHashFunction {
       case s: UTF8String =>
         val st = dataType.asInstanceOf[StringType]
         if (st.supportsBinaryEquality) {
-          hashUnsafeBytes(s.getBaseObject, s.getBaseOffset, s.numBytes(), seed)
+          hashUnsafeBytes(s.getBaseObject, s.getBaseOffset, s.numBytes, seed)
         } else {
-          val stringHash = CollationFactory
-            .fetchCollation(st.collationId)
-            .hashFunction.applyAsLong(s)
-          hashLong(stringHash, seed)
+          if (isCollationAware) {
+            val key = CollationFactory.fetchCollation(st.collationId).sortKeyFunction.apply(s)
+              .asInstanceOf[Array[Byte]]
+            hashUnsafeBytes(key, Platform.BYTE_ARRAY_OFFSET, key.length, seed)
+          } else if (legacyCollationAwareHashing) {
+            val collation = CollationFactory.fetchCollation(st.collationId)
+            val stringHash = if (collation.isUtf8BinaryType || collation.isUtf8LcaseType) {
+              UTF8String.fromBytes(collation.sortKeyFunction.apply(s)).hashCode
+            } else if (collation.supportsSpaceTrimming) {
+              collation.getCollator.getCollationKey(s.trimRight.toValidString).hashCode
+            } else {
+              collation.getCollator.getCollationKey(s.toValidString).hashCode
+            }
+            hashLong(stringHash, seed)
+          } else {
+            hashUnsafeBytes(s.getBaseObject, s.getBaseOffset, s.numBytes, seed)
+          }
         }
 
       case array: ArrayData =>
@@ -591,7 +660,7 @@ abstract class InterpretedHashFunction {
         var result = seed
         var i = 0
         while (i < array.numElements()) {
-          result = hash(array.get(i, elementType), elementType, result)
+          result = hash(array.get(i, elementType), elementType, result, isCollationAware)
           i += 1
         }
         result
@@ -608,8 +677,8 @@ abstract class InterpretedHashFunction {
         var result = seed
         var i = 0
         while (i < map.numElements()) {
-          result = hash(keys.get(i, kt), kt, result)
-          result = hash(values.get(i, vt), vt, result)
+          result = hash(keys.get(i, kt), kt, result, isCollationAware)
+          result = hash(values.get(i, vt), vt, result, isCollationAware)
           i += 1
         }
         result
@@ -624,7 +693,7 @@ abstract class InterpretedHashFunction {
         var i = 0
         val len = struct.numFields
         while (i < len) {
-          result = hash(struct.get(i, types(i)), types(i), result)
+          result = hash(struct.get(i, types(i)), types(i), result, isCollationAware)
           i += 1
         }
         result
@@ -656,8 +725,10 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends HashExpress
 
   override protected def hasherClassName: String = classOf[Murmur3_x86_32].getName
 
+  override protected def isCollationAware: Boolean = false
+
   override protected def computeHash(value: Any, dataType: DataType, seed: Int): Int = {
-    Murmur3HashFunction.hash(value, dataType, seed).toInt
+    Murmur3HashFunction.hash(value, dataType, seed, isCollationAware).toInt
   }
 
   override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Murmur3Hash =
@@ -678,6 +749,27 @@ object Murmur3HashFunction extends InterpretedHashFunction {
   }
 }
 
+case class CollationAwareMurmur3Hash(children: Seq[Expression], seed: Int)
+    extends HashExpression[Int] {
+
+  def this(arguments: Seq[Expression]) = this(arguments, 42)
+
+  override def dataType: DataType = IntegerType
+
+  override def prettyName: String = "collation_aware_hash"
+
+  override protected def hasherClassName: String = classOf[Murmur3_x86_32].getName
+
+  override protected def isCollationAware: Boolean = true
+
+  override protected def computeHash(value: Any, dataType: DataType, seed: Int): Int = {
+    Murmur3HashFunction.hash(value, dataType, seed, isCollationAware).toInt
+  }
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]):
+    CollationAwareMurmur3Hash = copy(children = newChildren)
+}
+
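
The new expression is a drop-in sibling of `Murmur3Hash` whose only difference is `isCollationAware = true`. A hedged usage sketch, mirroring the suite changes further below (`ResolvedCollation` is assumed to be in scope, as in `HashExpressionsSuite`):

```scala
import org.apache.spark.sql.catalyst.expressions._

val lower = Collate(Literal("aaa"), ResolvedCollation("UTF8_LCASE"))
val upper = Collate(Literal("AAA"), ResolvedCollation("UTF8_LCASE"))
// Equal under UTF8_LCASE, so the collation-aware hashes agree...
assert(CollationAwareMurmur3Hash(Seq(lower), 42).eval() ==
  CollationAwareMurmur3Hash(Seq(upper), 42).eval())
// ...while the user-facing Murmur3Hash now hashes raw bytes, so these differ.
assert(Murmur3Hash(Seq(lower), 42).eval() != Murmur3Hash(Seq(upper), 42).eval())
```
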
 /**
  * A xxHash64 64-bit hash expression.
  */
@@ -700,8 +792,10 @@ case class XxHash64(children: Seq[Expression], seed: Long) extends HashExpressio
 
   override protected def hasherClassName: String = classOf[XXH64].getName
 
+  override protected def isCollationAware: Boolean = false
+
   override protected def computeHash(value: Any, dataType: DataType, seed: Long): Long = {
-    XxHash64Function.hash(value, dataType, seed)
+    XxHash64Function.hash(value, dataType, seed, isCollationAware)
   }
 
   override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): XxHash64 =
@@ -718,6 +812,27 @@ object XxHash64Function extends InterpretedHashFunction {
   }
 }
 
+case class CollationAwareXxHash64(children: Seq[Expression], seed: Long)
+    extends HashExpression[Long] {
+
+  def this(arguments: Seq[Expression]) = this(arguments, 42L)
+
+  override def dataType: DataType = LongType
+
+  override def prettyName: String = "collation_aware_xxhash64"
+
+  override protected def hasherClassName: String = classOf[XXH64].getName
+
+  override protected def isCollationAware: Boolean = true
+
+  override protected def computeHash(value: Any, dataType: DataType, seed: Long): Long = {
+    XxHash64Function.hash(value, dataType, seed, isCollationAware)
+  }
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]):
+    CollationAwareXxHash64 = copy(children = newChildren)
+}
+
 /**
  * Simulates Hive's hashing function from Hive v1.2.1 at
  * org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#hashcode()
@@ -738,8 +853,10 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
 
   override protected def hasherClassName: String = classOf[HiveHasher].getName
 
+  override protected def isCollationAware: Boolean = true
+
   override protected def computeHash(value: Any, dataType: DataType, seed: Int): Int = {
-    HiveHashFunction.hash(value, dataType, this.seed).toInt
+    HiveHashFunction.hash(value, dataType, this.seed, isCollationAware).toInt
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
@@ -825,17 +942,18 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
   override protected def genHashString(
       ctx: CodegenContext, stringType: StringType, input: String, result: String): String = {
-    if (stringType.supportsBinaryEquality) {
+    if (stringType.supportsBinaryEquality || !isCollationAware) {
       val baseObject = s"$input.getBaseObject()"
       val baseOffset = s"$input.getBaseOffset()"
       val numBytes = s"$input.numBytes()"
       s"$result = $hasherClassName.hashUnsafeBytes($baseObject, $baseOffset, $numBytes);"
     } else {
-      val stringHash = ctx.freshName("stringHash")
+      val key = ctx.freshName("key")
+      val offset = Platform.BYTE_ARRAY_OFFSET
       s"""
-        long $stringHash = CollationFactory.fetchCollation(${stringType.collationId})
-          .hashFunction.applyAsLong($input);
-        $result = $hasherClassName.hashLong($stringHash);
+        byte[] $key = (byte[]) CollationFactory.fetchCollation(${stringType.collationId})
+          .sortKeyFunction.apply($input);
+        $result = $hasherClassName.hashUnsafeBytes($key, $offset, $key.length, $result);
       """
     }
   }
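
Note that `HiveHash` keeps `isCollationAware = true`, so it still hashes the sort key for collated strings. A sketch of the resulting behavior, under the same `StringType(collationId)` assumption as above:

```scala
import org.apache.spark.sql.catalyst.expressions.HiveHashFunction
import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.sql.types.StringType
import org.apache.spark.unsafe.types.UTF8String

val lcase = StringType(CollationFactory.collationNameToId("UTF8_LCASE"))
// Strings that are equal under the collation collide by design.
val h1 = HiveHashFunction.hash(UTF8String.fromString("aaa"), lcase, 0, isCollationAware = true)
val h2 = HiveHashFunction.hash(UTF8String.fromString("AAA"), lcase, 0, isCollationAware = true)
assert(h1 == h2)
```
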
@@ -1018,7 +1136,7 @@ object HiveHashFunction extends InterpretedHashFunction {
     (result * 37) + nanoSeconds
   }
 
-  override def hash(value: Any, dataType: DataType, seed: Long): Long = {
+  override def hash(value: Any, dataType: DataType, seed: Long, isCollationAware: Boolean): Long = {
     value match {
       case null => 0
       case array: ArrayData =>
@@ -1031,7 +1149,8 @@ object HiveHashFunction extends InterpretedHashFunction {
         var i = 0
         val length = array.numElements()
         while (i < length) {
-          result = (31 * result) + hash(array.get(i, elementType), elementType, 0).toInt
+          result = (31 * result) +
+            hash(array.get(i, elementType), elementType, 0, isCollationAware).toInt
           i += 1
         }
         result
@@ -1050,7 +1169,8 @@ object HiveHashFunction extends InterpretedHashFunction {
         var i = 0
         val length = map.numElements()
         while (i < length) {
-          result += hash(keys.get(i, kt), kt, 0).toInt ^ hash(values.get(i, vt), vt, 0).toInt
+          result += hash(keys.get(i, kt), kt, 0, isCollationAware).toInt ^
+            hash(values.get(i, vt), vt, 0, isCollationAware).toInt
           i += 1
         }
         result
@@ -1066,7 +1186,8 @@ object HiveHashFunction extends InterpretedHashFunction {
         var i = 0
         val length = struct.numFields
         while (i < length) {
-          result = (31 * result) + hash(struct.get(i, types(i)), types(i), 0).toInt
+          result = (31 * result) +
+            hash(struct.get(i, types(i)), types(i), 0, isCollationAware).toInt
           i += 1
         }
         result
@@ -1074,7 +1195,7 @@ object HiveHashFunction extends InterpretedHashFunction {
       case d: Decimal => normalizeDecimal(d.toJavaBigDecimal).hashCode()
       case timestamp: Long if dataType.isInstanceOf[TimestampType] => hashTimestamp(timestamp)
       case calendarInterval: CalendarInterval => hashCalendarInterval(calendarInterval)
-      case _ => super.hash(value, dataType, 0)
+      case _ => super.hash(value, dataType, 0, isCollationAware)
     }
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 6e19a1d6bbc8c..038105f9bfdf7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -316,7 +316,9 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
    * Returns an expression that will produce a valid partition ID (i.e. non-negative and less
    * than numPartitions) based on hashing expressions.
    */
-  def partitionIdExpression: Expression = Pmod(new Murmur3Hash(expressions), Literal(numPartitions))
+  def partitionIdExpression: Expression = Pmod(
+    new CollationAwareMurmur3Hash(expressions), Literal(numPartitions)
+  )
 
   override protected def withNewChildrenInternal(
       newChildren: IndexedSeq[Expression]): HashPartitioning = copy(expressions = newChildren)
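
This is the change that actually guarantees correct shuffles over collated keys: rows that compare equal under their collation must be routed to the same partition. A sketch (again assuming `ResolvedCollation` is in scope):

```scala
import org.apache.spark.sql.catalyst.expressions.{Collate, Literal}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning

def partitionId(s: String): Any =
  HashPartitioning(Seq(Collate(Literal(s), ResolvedCollation("UTF8_LCASE"))), 10)
    .partitionIdExpression.eval()

// Equal under UTF8_LCASE, so both rows land in the same shuffle partition.
assert(partitionId("aaa") == partitionId("AAA"))
```
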
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/HyperLogLogPlusPlusHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/HyperLogLogPlusPlusHelper.scala
index fc947386487a1..727a490640f74 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/HyperLogLogPlusPlusHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/HyperLogLogPlusPlusHelper.scala
@@ -24,7 +24,6 @@
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.XxHash64Function
 import org.apache.spark.sql.catalyst.optimizer.NormalizeFloatingNumbers.{DOUBLE_NORMALIZER, FLOAT_NORMALIZER}
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
 
 // A helper class for HyperLogLogPlusPlus.
 class HyperLogLogPlusPlusHelper(relativeSD: Double) extends Serializable {
@@ -94,12 +93,10 @@ class HyperLogLogPlusPlusHelper(relativeSD: Double) extends Serializable {
     val value = dataType match {
       case FloatType => FLOAT_NORMALIZER.apply(_value)
       case DoubleType => DOUBLE_NORMALIZER.apply(_value)
-      case st: StringType if !st.supportsBinaryEquality =>
-        CollationFactory.getCollationKeyBytes(_value.asInstanceOf[UTF8String], st.collationId)
       case _ => _value
     }
     // Create the hashed value 'x'.
-    val x = XxHash64Function.hash(value, dataType, 42L)
+    val x = XxHash64Function.hash(value, dataType, 42L, isCollationAware = true)
 
     // Determine the index of the register we are going to use.
     val idx = (x >>> idxShift).toInt
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala
index d2bdad2d880de..2c30fc7c8f5c9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala
@@ -39,7 +39,8 @@ class InternalRowComparableWrapper(val row: InternalRow, val dataTypes: Seq[Data
   private val structType = structTypeCache.get(dataTypes)
   private val ordering = orderingCache.get(dataTypes)
 
-  override def hashCode(): Int = Murmur3HashFunction.hash(row, structType, 42L).toInt
+  override def hashCode(): Int =
+    Murmur3HashFunction.hash(row, structType, 42L, isCollationAware = true).toInt
 
   override def equals(other: Any): Boolean = {
     if (!other.isInstanceOf[InternalRowComparableWrapper]) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 4d2982e91f769..812ffd64975b5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -982,6 +982,14 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
+  lazy val COLLATION_AWARE_HASHING_ENABLED =
+    buildConf("spark.sql.legacy.collationAwareHashFunctions")
+      .doc("Enables collation-aware hashing (the legacy behavior) for collated strings in the " +
+        "user-facing Murmur3Hash and XxHash64 expressions.")
+      .version("4.0.1")
+      .booleanConf
+      .createWithDefault(false)
+
   val ICU_CASE_MAPPINGS_ENABLED =
     buildConf("spark.sql.icu.caseMappings.enabled")
       .doc("When enabled we use the ICU library (instead of the JVM) to implement case mappings" +
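
A usage sketch of the new escape hatch, assuming a running `SparkSession` named `spark`:

```scala
// Default (false): hash()/xxhash64() ignore collation and hash the raw bytes.
spark.conf.set("spark.sql.legacy.collationAwareHashFunctions", "true")
// With the legacy flag on, the user-facing expressions hash the collation key again,
// so this comparison evaluates to true:
spark.sql("SELECT hash('aaa' COLLATE UTF8_LCASE) = hash('AAA' COLLATE UTF8_LCASE)").show()
```
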
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala
index 7cb4d5f123253..4f3efca4ad0f0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst
 import org.apache.spark.SparkFunSuite
 /* Implicit conversions */
 import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, Murmur3Hash, Pmod}
+import org.apache.spark.sql.catalyst.expressions.{CollationAwareMurmur3Hash, Expression, Literal, Pmod}
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.types.IntegerType
@@ -322,7 +322,7 @@ class DistributionSuite extends SparkFunSuite {
     val expressions = Seq($"a", $"b", $"c")
     val hashPartitioning = HashPartitioning(expressions, 10)
     hashPartitioning.partitionIdExpression match {
-      case Pmod(Murmur3Hash(es, 42), Literal(10, IntegerType), _) =>
+      case Pmod(CollationAwareMurmur3Hash(es, 42), Literal(10, IntegerType), _) =>
         assert(es.length == expressions.length && es.zip(expressions).forall {
           case (l, r) => l.semanticEquals(r)
         })
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
index c64b947032885..c4e93b564deb2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, ExpressionEncoder}
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, CollationFactory, DateTimeUtils, GenericArrayData, IntervalUtils}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{ArrayType, StructType, _}
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.ArrayImplicits._
@@ -91,7 +92,7 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
   def checkHiveHash(input: Any, dataType: DataType, expected: Long): Unit = {
     // Note: All expected hashes need to be computed using Hive 1.2.1
-    val actual = HiveHashFunction.hash(input, dataType, seed = 0)
+    val actual = HiveHashFunction.hash(input, dataType, seed = 0, isCollationAware = true)
 
     withClue(s"hash mismatch for input = `$input` of type `$dataType`.") {
       assert(actual == expected)
@@ -621,12 +622,18 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
   }
 
   for (collation <- Seq("UTF8_LCASE", "UNICODE_CI", "UTF8_BINARY")) {
-    test(s"hash check for collated $collation strings") {
+    test(s"hash check for collated $collation strings - collation aware") {
       val s1 = "aaa"
       val s2 = "AAA"
 
-      val murmur3Hash1 = Murmur3Hash(Seq(Collate(Literal(s1), ResolvedCollation(collation))), 42)
-      val murmur3Hash2 = Murmur3Hash(Seq(Collate(Literal(s2), ResolvedCollation(collation))), 42)
+      val murmur3Hash1 = CollationAwareMurmur3Hash(
+        Seq(Collate(Literal(s1), ResolvedCollation(collation))),
+        42
+      )
+      val murmur3Hash2 = CollationAwareMurmur3Hash(
+        Seq(Collate(Literal(s2), ResolvedCollation(collation))),
+        42
+      )
 
       // Interpreted hash values for s1 and s2
       val interpretedHash1 = murmur3Hash1.eval()
@@ -644,6 +651,115 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     }
   }
 
+  for (collation <- Seq("UTF8_LCASE", "UNICODE_CI", "UTF8_BINARY")) {
+    test(s"hash check for collated $collation strings - collation agnostic") {
+      val s1 = "aaa"
+      val s2 = "AAA"
+
+      val murmur3Hash1 = Murmur3Hash(Seq(Collate(Literal(s1), ResolvedCollation(collation))), 42)
+      val murmur3Hash2 = Murmur3Hash(Seq(Collate(Literal(s2), ResolvedCollation(collation))), 42)
+
+      // Interpreted hash values for s1 and s2
+      val interpretedHash1 = murmur3Hash1.eval()
+      val interpretedHash2 = murmur3Hash2.eval()
+
+      // Check that interpreted and codegen hashes are equal
+      checkEvaluation(murmur3Hash1, interpretedHash1)
+      checkEvaluation(murmur3Hash2, interpretedHash2)
+
+      assert(interpretedHash1 != interpretedHash2)
+
+      // Check that the hash computed is the same as the UTF8_BINARY version of it.
+      if (!CollationFactory.fetchCollation(collation).isUtf8BinaryType) {
+        Seq[String](s1, s2).foreach { s =>
+          val utf8BinaryStringExpr = Collate(Literal(s), ResolvedCollation("UTF8_BINARY"))
+          val murmur3HashBinary = Murmur3Hash(Seq(utf8BinaryStringExpr), 42)
+          val hashBinary = murmur3HashBinary.eval()
+          val murmur3Hash = Murmur3Hash(Seq(Collate(Literal(s), ResolvedCollation(collation))), 42)
+          val interpretedHash = murmur3Hash.eval()
+          assert(interpretedHash == hashBinary)
+        }
+      }
+    }
+  }
+
+  // Below we test the `Murmur3Hash` and `XxHash64` expressions under the legacy (pre-fix)
+  // behavior. The expected values were computed using the old implementation of the expressions.
+  test("SPARK-52828: always collation aware hash expression") {
+    withSQLConf(SQLConf.COLLATION_AWARE_HASHING_ENABLED.key -> "true") {
+      val testCases = Seq[(String, String, Int, Long)](
+        // UTF8_BINARY
+        ("AAA", "UTF8_BINARY", 22125783, 3965631622972380050L),
+        ("AAA ", "UTF8_BINARY", 399014599, 196039582279068044L),
+        ("aaa", "UTF8_BINARY", -1689629761, 2465751751477118478L),
+        ("aaa ", "UTF8_BINARY", -1721438718, -2249763606958050730L),
+        // UTF8_BINARY_RTRIM
+        ("AAA", "UTF8_BINARY_RTRIM", -1493064582, 982928955165138586L),
+        ("AAA ", "UTF8_BINARY_RTRIM", -1493064582, 982928955165138586L),
+        ("aaa", "UTF8_BINARY_RTRIM", 2132077201, -4940759280126763524L),
+        ("aaa ", "UTF8_BINARY_RTRIM", 2132077201, -4940759280126763524L),
+        // UTF8_LCASE
+        ("AAA", "UTF8_LCASE", 2132077201, -4940759280126763524L),
+        ("AAA ", "UTF8_LCASE", -619073595, -1146641051608991690L),
+        ("aaa", "UTF8_LCASE", 2132077201, -4940759280126763524L),
+        ("aaa ", "UTF8_LCASE", -1498994355, -739345240752106297L),
+        // UTF8_LCASE_RTRIM
+        ("AAA", "UTF8_LCASE_RTRIM", 2132077201, -4940759280126763524L),
+        ("AAA ", "UTF8_LCASE_RTRIM", 2132077201, -4940759280126763524L),
+        ("aaa", "UTF8_LCASE_RTRIM", 2132077201, -4940759280126763524L),
+        ("aaa ", "UTF8_LCASE_RTRIM", 2132077201, -4940759280126763524L),
+        // UNICODE
+        ("AAA", "UNICODE", 128537619, 49663227161197117L),
+        ("AAA ", "UNICODE", 82814175, 3618364417906061797L),
+        ("aaa", "UNICODE", -1822783942, 290910714161494507L),
+        ("aaa ", "UNICODE", -896289340, 1025563887784400925L),
+        // UNICODE_RTRIM
+        ("AAA", "UNICODE_RTRIM", 128537619, 49663227161197117L),
+        ("AAA ", "UNICODE_RTRIM", 128537619, 49663227161197117L),
+        ("aaa", "UNICODE_RTRIM", -1822783942, 290910714161494507L),
+        ("aaa ", "UNICODE_RTRIM", -1822783942, 290910714161494507L),
+        // UNICODE_CI
+        ("AAA", "UNICODE_CI", -443043098, -6629915645815515868L),
+        ("AAA ", "UNICODE_CI", 667473856, -3263604567598338200L),
+        ("aaa", "UNICODE_CI", -443043098, -6629915645815515868L),
+        ("aaa ", "UNICODE_CI", -390983808, -5159733933636691741L),
+        // UNICODE_CI_RTRIM
+        ("AAA", "UNICODE_CI_RTRIM", -443043098, -6629915645815515868L),
+        ("AAA ", "UNICODE_CI_RTRIM", -443043098, -6629915645815515868L),
+        ("aaa", "UNICODE_CI_RTRIM", -443043098, -6629915645815515868L),
+        ("aaa ", "UNICODE_CI_RTRIM", -443043098, -6629915645815515868L)
+      )
+      testCases.foreach { case (str, collationName, expectedMurmur3, expectedXxHash64) =>
+        val stringExpr = Collate(Literal(str), ResolvedCollation(collationName))
+        val murmur3Expr = Murmur3Hash(Seq(stringExpr), 42)
+        checkEvaluation(murmur3Expr, expectedMurmur3)
+        val xxHash64Expr = XxHash64(Seq(stringExpr), 42L)
+        checkEvaluation(xxHash64Expr, expectedXxHash64)
+      }
+    }
+  }
+
+  test("SPARK-52828: backward-compatible hash API should reject UTF8_LCASE collation") {
collation") { + // This test verifies that the legacy hash API throws an exception when used with + // collation-aware strings such as UTF8_LCASE. The assertion ensures we catch unsupported + // usage early via the internal assertion (SchemaUtils.hasNonUTF8BinaryCollation). + val expr_lcase = Collate(Literal("AAA"), ResolvedCollation("UTF8_LCASE")) + intercept[IllegalArgumentException] { + Murmur3HashFunction.hash(expr_lcase.eval(null), expr_lcase.dataType, 42) + } + intercept[IllegalArgumentException] { + XxHash64Function.hash(expr_lcase.eval(null), expr_lcase.dataType, 42) + } + intercept[IllegalArgumentException] { + HiveHashFunction.hash(expr_lcase.eval(null), expr_lcase.dataType, 42) + } + + val expr_utf8bin = Collate(Literal("AAA"), ResolvedCollation("UTF8_BINARY")) + Murmur3HashFunction.hash(expr_utf8bin.eval(null), expr_utf8bin.dataType, 42) + XxHash64Function.hash(expr_utf8bin.eval(null), expr_utf8bin.dataType, 42) + HiveHashFunction.hash(expr_utf8bin.eval(null), expr_utf8bin.dataType, 42) + } + test("SPARK-18207: Compute hash for a lot of expressions") { def checkResult(schema: StructType, input: InternalRow): Unit = { val exprs = schema.fields.zipWithIndex.map { case (f, i) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/CollationBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/CollationBenchmark.scala index 6069127a0df9d..d8416351079a9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/CollationBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/CollationBenchmark.scala @@ -19,7 +19,9 @@ package org.apache.spark.sql.execution.benchmark import scala.concurrent.duration._ import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} +import org.apache.spark.sql.catalyst.expressions.Murmur3HashFunction import org.apache.spark.sql.catalyst.util.{CollationFactory, CollationSupport} +import org.apache.spark.sql.types.StringType import org.apache.spark.unsafe.types.UTF8String abstract class CollationBenchmarkBase extends BenchmarkBase { @@ -92,7 +94,7 @@ abstract class CollationBenchmarkBase extends BenchmarkBase { sublistStrings.foreach { _ => utf8Strings.foreach { s => (0 to 3).foreach { _ => - collation.hashFunction.applyAsLong(s) + Murmur3HashFunction.hash(s, StringType(collationType), 42L, true).toInt } } }