diff --git a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Meta.scala b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Meta.scala index 74d13f85ee..7ffec3b5af 100644 --- a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Meta.scala +++ b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Meta.scala @@ -43,6 +43,7 @@ case class SparkArrayType(elementType: SparkType) extends SparkType case class SparkMapType(keyType: SparkType, valueType: SparkType) extends SparkType case class SparkStructType(fields: Seq[SparkType]) extends SparkType case object SparkAnyType extends SparkType +case class SparkTypeWithValues(sparkType: SparkType, validValues: Seq[String]) extends SparkType case class FunctionSignature(inputTypes: Seq[SparkType]) @@ -302,7 +303,12 @@ object Meta { createFunctionWithInputTypes("hour", Seq(SparkDateOrTimestampType)), createFunctionWithInputTypes("minute", Seq(SparkDateOrTimestampType)), createFunctionWithInputTypes("second", Seq(SparkDateOrTimestampType)), - createFunctionWithInputTypes("trunc", Seq(SparkDateOrTimestampType, SparkStringType)), + createFunctionWithInputTypes( "date_trunc", Seq(SparkTypeWithValues(SparkStringType, Seq("year", "week")), SparkDateOrTimestampType)), + createFunctionWithInputTypes( "trunc", Seq(SparkDateOrTimestampType, SparkTypeWithValues(SparkStringType, Seq("year", "week")))), createFunctionWithInputTypes("year", Seq(SparkDateOrTimestampType)), createFunctionWithInputTypes("month", Seq(SparkDateOrTimestampType)), createFunctionWithInputTypes("day", Seq(SparkDateOrTimestampType)), diff --git a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryGen.scala b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryGen.scala index d9e3c147d2..4b53b33897 100644 --- a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryGen.scala +++ b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryGen.scala @@ -103,7 +103,7 @@ object QueryGen { val func = Utils.randomChoice(Meta.scalarFunc, r) try { val signature =
Utils.randomChoice(func.signatures, r) - val args = signature.inputTypes.map(x => pickRandomColumn(r, table, x)) + val args = signature.inputTypes.map(x => pickRandomColumnOrLiteral(r, table, x)) // Example SELECT c0, log(c0) as x FROM test0 s"SELECT ${args.mkString(", ")}, ${func.name}(${args.mkString(", ")}) AS x " + @@ -117,6 +117,40 @@ object QueryGen { } } + private def pickRandomColumnOrLiteral( + r: Random, + df: DataFrame, + targetType: SparkType): String = { + if (r.nextBoolean()) { + targetType match { + case SparkIntegralType => + r.nextInt(10).toString + case SparkStringType => + formatLiteral(r.nextString(4), SparkStringType) + case SparkTimestampType => + "now()" + case SparkTypeWithValues(sparkType, values) => + // choose between known valid input and random input + if (r.nextBoolean()) { + formatLiteral(Utils.randomChoice(values, r), sparkType) + } else { + pickRandomColumnOrLiteral(r, df, sparkType) + } + case _ => + pickRandomColumn(r, df, targetType) + } + } else { + pickRandomColumn(r, df, targetType) + } + } + + private def formatLiteral(validValue: Any, sparkType: SparkType): String = { + sparkType match { + case SparkStringType => s""""$validValue"""" + case _ => validValue.toString + } + } + private def pickRandomColumn(r: Random, df: DataFrame, targetType: SparkType): String = { targetType match { case SparkAnyType =>