Skip to content

Commit 7c215ed

Browse files
pepijnvealamb
andauthored
Short circuit complex case evaluation modes as soon as possible (#17898)
## Which issue does this PR close? Improvement in the context of #18075 ## Rationale for this change Speculative performance improvements for case evaluation ## What changes are included in this PR? Short circuit case evaluation loop when as soon as a value has been calculated for each input rows ## Are these changes tested? (Hopefully) covered by SQL logic tests ## Are there any user-facing changes? No --------- Co-authored-by: Andrew Lamb <[email protected]>
1 parent 35b2e35 commit 7c215ed

File tree

2 files changed

+120
-28
lines changed
  • datafusion

2 files changed

+120
-28
lines changed

datafusion/physical-expr/src/expressions/case.rs

Lines changed: 44 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -205,10 +205,15 @@ impl CaseExpr {
205205
let mut current_value = new_null_array(&return_type, batch.num_rows());
206206
// We only consider non-null values while comparing with whens
207207
let mut remainder = not(&base_nulls)?;
208+
let mut non_null_remainder_count = remainder.true_count();
208209
for i in 0..self.when_then_expr.len() {
209-
let when_value = self.when_then_expr[i]
210-
.0
211-
.evaluate_selection(batch, &remainder)?;
210+
// If there are no rows left to process, break out of the loop early
211+
if non_null_remainder_count == 0 {
212+
break;
213+
}
214+
215+
let when_predicate = &self.when_then_expr[i].0;
216+
let when_value = when_predicate.evaluate_selection(batch, &remainder)?;
212217
let when_value = when_value.into_array(batch.num_rows())?;
213218
// build boolean array representing which rows match the "when" value
214219
let when_match = compare_with_eq(
@@ -224,41 +229,46 @@ impl CaseExpr {
224229
_ => Cow::Owned(prep_null_mask_filter(&when_match)),
225230
};
226231
// Make sure we only consider rows that have not been matched yet
227-
let when_match = and(&when_match, &remainder)?;
232+
let when_value = and(&when_match, &remainder)?;
228233

229-
// When no rows available for when clause, skip then clause
230-
if when_match.true_count() == 0 {
234+
// If the predicate did not match any rows, continue to the next branch immediately
235+
let when_match_count = when_value.true_count();
236+
if when_match_count == 0 {
231237
continue;
232238
}
233239

234-
let then_value = self.when_then_expr[i]
235-
.1
236-
.evaluate_selection(batch, &when_match)?;
240+
let then_expression = &self.when_then_expr[i].1;
241+
let then_value = then_expression.evaluate_selection(batch, &when_value)?;
237242

238243
current_value = match then_value {
239244
ColumnarValue::Scalar(ScalarValue::Null) => {
240-
nullif(current_value.as_ref(), &when_match)?
245+
nullif(current_value.as_ref(), &when_value)?
241246
}
242247
ColumnarValue::Scalar(then_value) => {
243-
zip(&when_match, &then_value.to_scalar()?, &current_value)?
248+
zip(&when_value, &then_value.to_scalar()?, &current_value)?
244249
}
245250
ColumnarValue::Array(then_value) => {
246-
zip(&when_match, &then_value, &current_value)?
251+
zip(&when_value, &then_value, &current_value)?
247252
}
248253
};
249254

250-
remainder = and_not(&remainder, &when_match)?;
255+
remainder = and_not(&remainder, &when_value)?;
256+
non_null_remainder_count -= when_match_count;
251257
}
252258

253259
if let Some(e) = self.else_expr() {
254-
// keep `else_expr`'s data type and return type consistent
255-
let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
256260
// null and unmatched tuples should be assigned else value
257261
remainder = or(&base_nulls, &remainder)?;
258-
let else_ = expr
259-
.evaluate_selection(batch, &remainder)?
260-
.into_array(batch.num_rows())?;
261-
current_value = zip(&remainder, &else_, &current_value)?;
262+
263+
if remainder.true_count() > 0 {
264+
// keep `else_expr`'s data type and return type consistent
265+
let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
266+
267+
let else_ = expr
268+
.evaluate_selection(batch, &remainder)?
269+
.into_array(batch.num_rows())?;
270+
current_value = zip(&remainder, &else_, &current_value)?;
271+
}
262272
}
263273

264274
Ok(ColumnarValue::Array(current_value))
@@ -277,10 +287,15 @@ impl CaseExpr {
277287
// start with nulls as default output
278288
let mut current_value = new_null_array(&return_type, batch.num_rows());
279289
let mut remainder = BooleanArray::from(vec![true; batch.num_rows()]);
290+
let mut remainder_count = batch.num_rows();
280291
for i in 0..self.when_then_expr.len() {
281-
let when_value = self.when_then_expr[i]
282-
.0
283-
.evaluate_selection(batch, &remainder)?;
292+
// If there are no rows left to process, break out of the loop early
293+
if remainder_count == 0 {
294+
break;
295+
}
296+
297+
let when_predicate = &self.when_then_expr[i].0;
298+
let when_value = when_predicate.evaluate_selection(batch, &remainder)?;
284299
let when_value = when_value.into_array(batch.num_rows())?;
285300
let when_value = as_boolean_array(&when_value).map_err(|_| {
286301
internal_datafusion_err!("WHEN expression did not return a BooleanArray")
@@ -293,14 +308,14 @@ impl CaseExpr {
293308
// Make sure we only consider rows that have not been matched yet
294309
let when_value = and(&when_value, &remainder)?;
295310

296-
// When no rows available for when clause, skip then clause
297-
if when_value.true_count() == 0 {
311+
// If the predicate did not match any rows, continue to the next branch immediately
312+
let when_match_count = when_value.true_count();
313+
if when_match_count == 0 {
298314
continue;
299315
}
300316

301-
let then_value = self.when_then_expr[i]
302-
.1
303-
.evaluate_selection(batch, &when_value)?;
317+
let then_expression = &self.when_then_expr[i].1;
318+
let then_value = then_expression.evaluate_selection(batch, &when_value)?;
304319

305320
current_value = match then_value {
306321
ColumnarValue::Scalar(ScalarValue::Null) => {
@@ -317,10 +332,11 @@ impl CaseExpr {
317332
// Succeed tuples should be filtered out for short-circuit evaluation,
318333
// null values for the current when expr should be kept
319334
remainder = and_not(&remainder, &when_value)?;
335+
remainder_count -= when_match_count;
320336
}
321337

322338
if let Some(e) = self.else_expr() {
323-
if remainder.true_count() > 0 {
339+
if remainder_count > 0 {
324340
// keep `else_expr`'s data type and return type consistent
325341
let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
326342
let else_ = expr

datafusion/sqllogictest/test_files/case.slt

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,3 +519,79 @@ query I
519519
SELECT case when false then 1 / 0 else 1 / 1 end;
520520
----
521521
1
522+
523+
# Else branch evaluation with case expression, 1 when branch, null input
524+
query I
525+
SELECT CASE a WHEN 'a' THEN 0 ELSE 1 END FROM (VALUES (NULL)) t(a)
526+
----
527+
1
528+
529+
# Else branch evaluation with case expression, 2 when branches, null input
530+
query I
531+
SELECT CASE a WHEN 'a' THEN 0 WHEN 'b' THEN 1 ELSE 2 END FROM (VALUES (NULL)) t(a)
532+
----
533+
2
534+
535+
# Else branch evaluation without case expression, 1 when branch, null input
536+
query I
537+
SELECT CASE WHEN a = 'a' THEN 0 ELSE 1 END FROM (VALUES (NULL)) t(a)
538+
----
539+
1
540+
541+
# Else branch evaluation without case expression, 2 when branches, null input
542+
query I
543+
SELECT CASE WHEN a = 'a' THEN 0 WHEN a = 'b' THEN 1 ELSE 2 END FROM (VALUES (NULL)) t(a)
544+
----
545+
2
546+
547+
# Else branch evaluation with case expression, 1 when branch, non-null input
548+
query I
549+
SELECT CASE a WHEN 'a' THEN 0 ELSE 1 END FROM (VALUES ('z')) t(a)
550+
----
551+
1
552+
553+
# Else branch evaluation with case expression, 2 when branches, non-null input
554+
query I
555+
SELECT CASE a WHEN 'a' THEN 0 WHEN 'b' THEN 1 ELSE 2 END FROM (VALUES ('z')) t(a)
556+
----
557+
2
558+
559+
# Else branch evaluation without case expression, 1 when branch, non-null input
560+
query I
561+
SELECT CASE WHEN a = 'a' THEN 0 ELSE 1 END FROM (VALUES ('z')) t(a)
562+
----
563+
1
564+
565+
# Else branch evaluation without case expression, 2 when branches, non-null input
566+
query I
567+
SELECT CASE WHEN a = 'a' THEN 0 WHEN a = 'b' THEN 1 ELSE 2 END FROM (VALUES ('z')) t(a)
568+
----
569+
2
570+
571+
# Else branch evaluation with case expression, 1 when branch, mixed input
572+
query I
573+
SELECT CASE a WHEN 'a' THEN 0 ELSE 1 END FROM (VALUES (NULL), ('z')) t(a)
574+
----
575+
1
576+
1
577+
578+
# Else branch evaluation with case expression, 2 when branches, mixed input
579+
query I
580+
SELECT CASE a WHEN 'a' THEN 0 WHEN 'b' THEN 1 ELSE 2 END FROM (VALUES (NULL), ('z')) t(a)
581+
----
582+
2
583+
2
584+
585+
# Else branch evaluation without case expression, 1 when branch, mixed input
586+
query I
587+
SELECT CASE WHEN a = 'a' THEN 0 ELSE 1 END FROM (VALUES (NULL), ('z')) t(a)
588+
----
589+
1
590+
1
591+
592+
# Else branch evaluation without case expression, 2 when branches, mixed input
593+
query I
594+
SELECT CASE WHEN a = 'a' THEN 0 WHEN a = 'b' THEN 1 ELSE 2 END FROM (VALUES (NULL), ('z')) t(a)
595+
----
596+
2
597+
2

0 commit comments

Comments
 (0)