
Commit 2c299d5

chore: Various improvements to checkSparkAnswer* methods in CometTestBase (#2656)
1 parent 2ed0967 commit 2c299d5

11 files changed (+301, -179 lines)

common/src/main/scala/org/apache/comet/CometConf.scala

Lines changed: 7 additions & 0 deletions
@@ -702,6 +702,13 @@ object CometConf extends ShimCometConf {
       .bytesConf(ByteUnit.BYTE)
       .createWithDefault(100L * 1024 * 1024 * 1024) // 100 GB
 
+  val COMET_STRICT_TESTING: ConfigEntry[Boolean] = conf(s"$COMET_PREFIX.testing.strict")
+    .category(CATEGORY_TESTING)
+    .doc("Experimental option to enable strict testing, which will fail tests that could be " +
+      "more comprehensive, such as checking for a specific fallback reason")
+    .booleanConf
+    .createWithDefault(sys.env.getOrElse("ENABLE_COMET_STRICT_TESTING", "false").toBoolean)
+
   /** Create a config to enable a specific operator */
   private def createExecEnabledConfig(
       exec: String,
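Reviewer note: the new entry reads like any other CometConf flag, so test utilities can branch on it at runtime. A minimal sketch of that pattern, assuming Comet is on the classpath; failOnStrict is a hypothetical helper invented here for illustration, not part of this diff:

import org.apache.comet.CometConf

def failOnStrict(message: String): Unit = {
  if (CometConf.COMET_STRICT_TESTING.get()) {
    // strict mode: a test that could make a stronger assertion fails outright
    throw new AssertionError(message)
  } else {
    // lenient mode (the default): warn and continue
    println(s"[warn] $message")
  }
}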

docs/source/user-guide/latest/compatibility.md

Lines changed: 1 addition & 1 deletion
@@ -97,7 +97,7 @@ because they are handled well in Spark (e.g., `SQLOrderingUtil.compareFloats`).
 functions of arrow-rs used by DataFusion do not normalize NaN and zero (e.g., [arrow::compute::kernels::cmp::eq](https://docs.rs/arrow/latest/arrow/compute/kernels/cmp/fn.eq.html#)).
 So Comet will add additional normalization expression of NaN and zero for comparison.
 
-Sorting on floating-point data types (or complex types containing floating-point values) is not compatible with
+Sorting on floating-point data types (or complex types containing floating-point values) is not compatible with
 Spark if the data contains both zero and negative zero. This is likely an edge case that is not of concern for many users
 and sorting on floating-point data can be enabled by setting `spark.comet.expression.SortOrder.allowIncompatible=true`.
docs/source/user-guide/latest/configs.md

Lines changed: 2 additions & 0 deletions
@@ -130,6 +130,7 @@ These settings can be used to determine which parts of the plan are accelerated
 | `spark.comet.exec.onHeap.enabled` | Whether to allow Comet to run in on-heap mode. Required for running Spark SQL tests. | false |
 | `spark.comet.exec.onHeap.memoryPool` | The type of memory pool to be used for Comet native execution when running Spark in on-heap mode. Available pool types are `greedy`, `fair_spill`, `greedy_task_shared`, `fair_spill_task_shared`, `greedy_global`, `fair_spill_global`, and `unbounded`. | greedy_task_shared |
 | `spark.comet.memoryOverhead` | The amount of additional memory to be allocated per executor process for Comet, in MiB, when running Spark in on-heap mode. | 1024 MiB |
+| `spark.comet.testing.strict` | Experimental option to enable strict testing, which will fail tests that could be more comprehensive, such as checking for a specific fallback reason | false |
 <!--END:CONFIG_TABLE-->
 
 ## Enabling or Disabling Individual Operators
@@ -274,6 +275,7 @@ These settings can be used to determine which parts of the plan are accelerated
 | `spark.comet.expression.ShiftRight.enabled` | Enable Comet acceleration for `ShiftRight` | true |
 | `spark.comet.expression.Signum.enabled` | Enable Comet acceleration for `Signum` | true |
 | `spark.comet.expression.Sin.enabled` | Enable Comet acceleration for `Sin` | true |
+| `spark.comet.expression.SortOrder.enabled` | Enable Comet acceleration for `SortOrder` | true |
 | `spark.comet.expression.SparkPartitionID.enabled` | Enable Comet acceleration for `SparkPartitionID` | true |
 | `spark.comet.expression.Sqrt.enabled` | Enable Comet acceleration for `Sqrt` | true |
 | `spark.comet.expression.StartsWith.enabled` | Enable Comet acceleration for `StartsWith` | true |
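A short sketch of the two documented ways to turn the new flag on, assuming the conf is honored at the session level (per CometConf.scala, the env var only seeds the default):

// per-session, via the SQL conf documented in the table above
spark.conf.set("spark.comet.testing.strict", "true")
// or process-wide, before launching the JVM:
//   export ENABLE_COMET_STRICT_TESTING=true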

spark/src/main/scala/org/apache/comet/ExtendedExplainInfo.scala

Lines changed: 2 additions & 2 deletions
@@ -38,12 +38,12 @@ class ExtendedExplainInfo extends ExtendedExplainGenerator {
     if (CometConf.COMET_EXPLAIN_VERBOSE_ENABLED.get()) {
       generateVerboseExtendedInfo(plan)
     } else {
-      val info = extensionInfo(plan)
+      val info = getFallbackReasons(plan)
       info.toSeq.sorted.mkString("\n").trim
     }
   }
 
-  private[comet] def extensionInfo(node: TreeNode[_]): Set[String] = {
+  def getFallbackReasons(node: TreeNode[_]): Set[String] = {
     var info = mutable.Seq[String]()
     val sorted = sortup(node)
     sorted.foreach { p =>
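The rename also widens visibility from private[comet] to public, so callers outside the package can query fallback reasons directly. A sketch, assuming an already-executed DataFrame `df` (executedPlan is a SparkPlan, which satisfies the TreeNode[_] parameter):

import org.apache.comet.ExtendedExplainInfo

val reasons: Set[String] =
  new ExtendedExplainInfo().getFallbackReasons(df.queryExecution.executedPlan)
// e.g., assert that some expected reason was recorded
assert(reasons.exists(_.contains("not supported")))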

spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala

Lines changed: 1 addition & 1 deletion
@@ -649,7 +649,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
     // config is enabled)
     if (CometConf.COMET_EXPLAIN_FALLBACK_ENABLED.get()) {
       val info = new ExtendedExplainInfo()
-      if (info.extensionInfo(newPlan).nonEmpty) {
+      if (info.getFallbackReasons(newPlan).nonEmpty) {
         logWarning(
           "Comet cannot execute some parts of this plan natively " +
             s"(set ${CometConf.COMET_EXPLAIN_FALLBACK_ENABLED.key}=false " +

spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala

Lines changed: 2 additions & 5 deletions
@@ -128,12 +128,9 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       .createOrReplaceTempView("t2")
     val expectedFallbackReasons = HashSet(
       "data type not supported: ArrayType(StructType(StructField(_1,BooleanType,true),StructField(_2,ByteType,true)),false)")
-    // note that checkExtended is disabled here due to an unrelated issue
-    // https://github.com/apache/datafusion-comet/issues/1313
-    checkSparkAnswerAndCompareExplainPlan(
+    checkSparkAnswerAndFallbackReasons(
       sql("SELECT array_remove(a, b) FROM t2"),
-      expectedFallbackReasons,
-      checkExplainString = false)
+      expectedFallbackReasons)
   }
 }
}
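The real checkSparkAnswerAndFallbackReasons lives in CometTestBase, which this commit page does not show. A hypothetical sketch of its shape, inferred purely from the call sites in this commit (the tuple return of checkSparkAnswer matches its other call sites here):

import org.apache.spark.sql.DataFrame
import org.apache.comet.ExtendedExplainInfo

def checkSparkAnswerAndFallbackReasons(
    df: DataFrame,
    expectedReasons: Set[String]): Unit = {
  // first compare results against vanilla Spark (helper from CometTestBase)
  val (_, cometPlan) = checkSparkAnswer(df)
  // then require every expected fallback reason to appear in the plan
  val actual = new ExtendedExplainInfo().getFallbackReasons(cometPlan)
  expectedReasons.foreach { r =>
    assert(actual.exists(_.contains(r)), s"missing fallback reason: $r")
  }
}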

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 29 additions & 19 deletions
@@ -846,10 +846,16 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
           "SECOND",
           "MILLISECOND",
           "MICROSECOND").foreach { format =>
-          checkSparkAnswer(
-            "SELECT " +
-              s"date_trunc('$format', ts )" +
-              " from int96timetbl")
+          val sql = "SELECT " +
+            s"date_trunc('$format', ts )" +
+            " from int96timetbl"
+
+          if (conversionEnabled) {
+            // plugin is disabled if PARQUET_INT96_TIMESTAMP_CONVERSION is true
+            checkSparkAnswer(sql)
+          } else {
+            checkSparkAnswerAndOperator(sql)
+          }
         }
       }
     }
@@ -978,7 +984,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       sql(s"create table $table(id int, name varchar(20)) using parquet")
       sql(s"insert into $table values(1,'James Smith')")
       val query = sql(s"select cast(id as string) from $table")
-      val (_, cometPlan) = checkSparkAnswer(query)
+      val (_, cometPlan) = checkSparkAnswerAndOperator(query)
       val project = cometPlan
         .asInstanceOf[WholeStageCodegenExec]
         .child
@@ -1343,17 +1349,19 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
             "sin",
             "sqrt",
             "tan")) {
-        val df = checkSparkAnswerWithTol(s"SELECT $expr(_1), $expr(_2) FROM tbl")
-        val cometProjectExecs = collect(df.queryExecution.executedPlan) {
-          case op: CometProjectExec => op
+        val (_, cometPlan) =
+          checkSparkAnswerAndOperatorWithTol(sql(s"SELECT $expr(_1), $expr(_2) FROM tbl"))
+        val cometProjectExecs = collect(cometPlan) { case op: CometProjectExec =>
+          op
         }
         assert(cometProjectExecs.length == 1, expr)
       }
       // expressions with two args
       for (expr <- Seq("atan2", "pow")) {
-        val df = checkSparkAnswerWithTol(s"SELECT $expr(_1, _2) FROM tbl")
-        val cometProjectExecs = collect(df.queryExecution.executedPlan) {
-          case op: CometProjectExec => op
+        val (_, cometPlan) =
+          checkSparkAnswerAndOperatorWithTol(sql(s"SELECT $expr(_1, _2) FROM tbl"))
+        val cometProjectExecs = collect(cometPlan) { case op: CometProjectExec =>
+          op
         }
         assert(cometProjectExecs.length == 1, expr)
       }
@@ -1364,8 +1372,8 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     val testValuesRepeated = doubleValues.flatMap(v => Seq.fill(1000)(v))
     for (withDictionary <- Seq(true, false)) {
       withParquetTable(testValuesRepeated.map(n => (n, n)), "tbl", withDictionary) {
-        val df = checkSparkAnswerWithTol(s"SELECT $expr(_1) FROM tbl")
-        val projections = collect(df.queryExecution.executedPlan) { case p: CometProjectExec =>
+        val (_, cometPlan) = checkSparkAnswerAndOperatorWithTol(sql(s"SELECT $expr(_1) FROM tbl"))
+        val projections = collect(cometPlan) { case p: CometProjectExec =>
           p
         }
         assert(projections.length == 1)
@@ -1381,10 +1389,12 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
     withParquetTable(Seq(0, 1, 2).map(n => (n, n)), "tbl") {
       val sql = "select _1+_2 from tbl"
-      val (_, cometPlan) = checkSparkAnswer(sql)
+      val (_, cometPlan) = checkSparkAnswerAndOperator(sql)
       assert(0 == countSparkProjectExec(cometPlan))
       withSQLConf(CometConf.getExprEnabledConfigKey("Add") -> "false") {
-        val (_, cometPlan) = checkSparkAnswer(sql)
+        val (_, cometPlan) = checkSparkAnswerAndFallbackReason(
+          sql,
+          "Expression support is disabled. Set spark.comet.expression.Add.enabled=true to enable it.")
         assert(1 == countSparkProjectExec(cometPlan))
       }
     }
@@ -1401,7 +1411,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       val (_, cometPlan) = checkSparkAnswer(sql)
       assert(1 == countSparkProjectExec(cometPlan))
       withSQLConf(CometConf.getExprAllowIncompatConfigKey("InitCap") -> "true") {
-        val (_, cometPlan) = checkSparkAnswer(sql)
+        val (_, cometPlan) = checkSparkAnswerAndOperator(sql)
         assert(0 == countSparkProjectExec(cometPlan))
       }
     }
@@ -1677,7 +1687,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
         s"SELECT * FROM $table WHERE name in ('Smith', 'Brown', NULL)")
 
       // TODO: why with not in, the plan is only `LocalTableScan`?
-      checkSparkAnswer(s"SELECT * FROM $table WHERE id not in (1)")
+      checkSparkAnswerAndOperator(s"SELECT * FROM $table WHERE id not in (1)")
       checkSparkAnswer(s"SELECT * FROM $table WHERE name not in ('Smith', 'Brown', NULL)")
     }
   }
@@ -2005,7 +2015,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
         val expected = test._2
         val df = sql(qry)
         df.collect() // force an execution
-        checkSparkAnswerAndCompareExplainPlan(df, expected)
+        checkSparkAnswerAndFallbackReasons(df, expected)
       })
     }
   }
@@ -2030,7 +2040,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
         val expected = test._2
         val df = sql(qry)
         df.collect() // force an execution
-        checkSparkAnswerAndCompareExplainPlan(df, expected)
+        checkSparkAnswerAndFallbackReasons(df, expected)
       })
     }
   }
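The single-reason variant used in the Add hunk above is likewise defined in CometTestBase, which is not shown here. A hypothetical sketch of its shape, inferred from its call sites (it is destructured into a plan tuple, like checkSparkAnswer):

import org.apache.spark.sql.execution.SparkPlan
import org.apache.comet.ExtendedExplainInfo

def checkSparkAnswerAndFallbackReason(
    sql: String,
    expectedReason: String): (SparkPlan, SparkPlan) = {
  // compare answers, then require the one expected fallback reason
  val (sparkPlan, cometPlan) = checkSparkAnswer(sql)
  val reasons = new ExtendedExplainInfo().getFallbackReasons(cometPlan)
  assert(reasons.exists(_.contains(expectedReason)), s"missing: $expectedReason")
  (sparkPlan, cometPlan)
}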

spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala

Lines changed: 2 additions & 2 deletions
@@ -85,7 +85,7 @@ class CometStringExpressionSuite extends CometTestBase {
       if (isLiteralStr && isLiteralLen && isLiteralPad) {
         // all arguments are literal, so Spark constant folding will kick in
         // and pad function will not be evaluated by Comet
-        checkSparkAnswer(sql)
+        checkSparkAnswerAndOperator(sql)
       } else if (isLiteralStr) {
         checkSparkAnswerAndFallbackReason(
           sql,
@@ -135,7 +135,7 @@ class CometStringExpressionSuite extends CometTestBase {
       if (isLiteralStr && isLiteralLen && isLiteralPad) {
         // all arguments are literal, so Spark constant folding will kick in
         // and pad function will not be evaluated by Comet
-        checkSparkAnswer(sql)
+        checkSparkAnswerAndOperator(sql)
       } else {
         // Comet will fall back to Spark because the plan contains a staticinvoke instruction
         // which is not supported
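A quick illustration of the constant-folding note in the comments above: with all-literal arguments, Spark's optimizer evaluates the pad call at planning time, so Comet only ever sees a literal (example uses the standard lpad builtin; output shape is approximate):

val df = spark.sql("SELECT lpad('abc', 6, 'x')")
// optimized plan is roughly: Project [xxxabc AS lpad(abc, 6, x)]
println(df.queryExecution.optimizedPlan)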

spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala

Lines changed: 3 additions & 3 deletions
@@ -1105,7 +1105,7 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       "SELECT _g2, AVG(_7) FROM tbl GROUP BY _g2",
       expectedNumOfCometAggregates)
 
-    checkSparkAnswerWithTol("SELECT _g3, AVG(_8) FROM tbl GROUP BY _g3")
+    checkSparkAnswerWithTolerance("SELECT _g3, AVG(_8) FROM tbl GROUP BY _g3")
     assert(getNumCometHashAggregate(
       sql("SELECT _g3, AVG(_8) FROM tbl GROUP BY _g3")) == expectedNumOfCometAggregates)
 
@@ -1117,7 +1117,7 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       "SELECT AVG(_7) FROM tbl",
       expectedNumOfCometAggregates)
 
-    checkSparkAnswerWithTol("SELECT AVG(_8) FROM tbl")
+    checkSparkAnswerWithTolerance("SELECT AVG(_8) FROM tbl")
     assert(getNumCometHashAggregate(
       sql("SELECT AVG(_8) FROM tbl")) == expectedNumOfCometAggregates)
 
@@ -1505,7 +1505,7 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       numAggregates: Int,
      absTol: Double = 1e-6): Unit = {
     val df = sql(query)
-    checkSparkAnswerWithTol(df, absTol)
+    checkSparkAnswerWithTolerance(df, absTol)
     val actualNumAggregates = getNumCometHashAggregate(df)
     assert(
       actualNumAggregates == numAggregates,
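The rename from checkSparkAnswerWithTol to checkSparkAnswerWithTolerance is mechanical; the helper compares floating-point results within an absolute tolerance (default 1e-6 per the signature above). A minimal sketch of that style of comparison, assuming the real helper in CometTestBase does something equivalent per value:

// absolute-tolerance comparison for a pair of doubles (NaN treated as equal)
def approxEqual(expected: Double, actual: Double, absTol: Double = 1e-6): Boolean =
  (expected.isNaN && actual.isNaN) || math.abs(expected - actual) <= absTol

assert(approxEqual(0.3, 0.1 + 0.2)) // true: the error is far below 1e-6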

spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala

Lines changed: 8 additions & 8 deletions
@@ -2076,33 +2076,33 @@ class CometExecSuite extends CometTestBase {
         List(s"COUNT(_$col)", s"MAX(_$col)", s"MIN(_$col)", s"SUM(_$col)")
       aggregateFunctions.foreach { function =>
         val df1 = sql(s"SELECT $function OVER() FROM tbl")
-        checkSparkAnswerWithTol(df1, 1e-6)
+        checkSparkAnswerWithTolerance(df1, 1e-6)
 
         val df2 = sql(s"SELECT $function OVER(order by _2) FROM tbl")
-        checkSparkAnswerWithTol(df2, 1e-6)
+        checkSparkAnswerWithTolerance(df2, 1e-6)
 
         val df3 = sql(s"SELECT $function OVER(order by _2 desc) FROM tbl")
-        checkSparkAnswerWithTol(df3, 1e-6)
+        checkSparkAnswerWithTolerance(df3, 1e-6)
 
         val df4 = sql(s"SELECT $function OVER(partition by _2 order by _2) FROM tbl")
-        checkSparkAnswerWithTol(df4, 1e-6)
+        checkSparkAnswerWithTolerance(df4, 1e-6)
       }
     }
 
     // SUM doesn't work for Date type. org.apache.spark.sql.AnalysisException will be thrown.
     val aggregateFunctionsWithoutSum = List("COUNT(_12)", "MAX(_12)", "MIN(_12)")
     aggregateFunctionsWithoutSum.foreach { function =>
       val df1 = sql(s"SELECT $function OVER() FROM tbl")
-      checkSparkAnswerWithTol(df1, 1e-6)
+      checkSparkAnswerWithTolerance(df1, 1e-6)
 
       val df2 = sql(s"SELECT $function OVER(order by _2) FROM tbl")
-      checkSparkAnswerWithTol(df2, 1e-6)
+      checkSparkAnswerWithTolerance(df2, 1e-6)
 
       val df3 = sql(s"SELECT $function OVER(order by _2 desc) FROM tbl")
-      checkSparkAnswerWithTol(df3, 1e-6)
+      checkSparkAnswerWithTolerance(df3, 1e-6)
 
       val df4 = sql(s"SELECT $function OVER(partition by _2 order by _2) FROM tbl")
-      checkSparkAnswerWithTol(df4, 1e-6)
+      checkSparkAnswerWithTolerance(df4, 1e-6)
     }
   }
 }
