Commit 8b3b77c

feat: Support Array Literal (#2057)
* feat: support literal for ARRAY top level
1 parent cec9bf5 commit 8b3b77c

File tree: 13 files changed (+468 -46 lines)
native/core/src/execution/planner.rs

Lines changed: 127 additions & 1 deletion
@@ -85,6 +85,13 @@ use datafusion::physical_expr::window::WindowExpr;
 use datafusion::physical_expr::LexOrdering;
 
 use crate::parquet::parquet_exec::init_datasource_exec;
+use arrow::array::{
+    BinaryBuilder, BooleanArray, Date32Array, Decimal128Array, Float32Array, Float64Array,
+    Int16Array, Int32Array, Int64Array, Int8Array, NullArray, StringBuilder,
+    TimestampMicrosecondArray,
+};
+use arrow::buffer::BooleanBuffer;
+use datafusion::common::utils::SingleRowListArrayBuilder;
 use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
 use datafusion::physical_plan::filter::FilterExec as DataFusionFilterExec;
 use datafusion_comet_proto::spark_operator::SparkFilePartition;
@@ -474,6 +481,125 @@ impl PhysicalPlanner {
                     )))
                 }
             }
+            },
+            Value::ListVal(values) => {
+                if let DataType::List(f) = data_type {
+                    match f.data_type() {
+                        DataType::Null => {
+                            SingleRowListArrayBuilder::new(Arc::new(NullArray::new(values.clone().null_mask.len())))
+                                .build_list_scalar()
+                        }
+                        DataType::Boolean => {
+                            let vals = values.clone();
+                            SingleRowListArrayBuilder::new(Arc::new(BooleanArray::new(BooleanBuffer::from(vals.boolean_values), Some(vals.null_mask.into()))))
+                                .build_list_scalar()
+                        }
+                        DataType::Int8 => {
+                            let vals = values.clone();
+                            SingleRowListArrayBuilder::new(Arc::new(Int8Array::new(vals.byte_values.iter().map(|&x| x as i8).collect::<Vec<_>>().into(), Some(vals.null_mask.into()))))
+                                .build_list_scalar()
+                        }
+                        DataType::Int16 => {
+                            let vals = values.clone();
+                            SingleRowListArrayBuilder::new(Arc::new(Int16Array::new(vals.short_values.iter().map(|&x| x as i16).collect::<Vec<_>>().into(), Some(vals.null_mask.into()))))
+                                .build_list_scalar()
+                        }
+                        DataType::Int32 => {
+                            let vals = values.clone();
+                            SingleRowListArrayBuilder::new(Arc::new(Int32Array::new(vals.int_values.into(), Some(vals.null_mask.into()))))
+                                .build_list_scalar()
+                        }
+                        DataType::Int64 => {
+                            let vals = values.clone();
+                            SingleRowListArrayBuilder::new(Arc::new(Int64Array::new(vals.long_values.into(), Some(vals.null_mask.into()))))
+                                .build_list_scalar()
+                        }
+                        DataType::Float32 => {
+                            let vals = values.clone();
+                            SingleRowListArrayBuilder::new(Arc::new(Float32Array::new(vals.float_values.into(), Some(vals.null_mask.into()))))
+                                .build_list_scalar()
+                        }
+                        DataType::Float64 => {
+                            let vals = values.clone();
+                            SingleRowListArrayBuilder::new(Arc::new(Float64Array::new(vals.double_values.into(), Some(vals.null_mask.into()))))
+                                .build_list_scalar()
+                        }
+                        DataType::Timestamp(TimeUnit::Microsecond, None) => {
+                            let vals = values.clone();
+                            SingleRowListArrayBuilder::new(Arc::new(TimestampMicrosecondArray::new(vals.long_values.into(), Some(vals.null_mask.into()))))
+                                .build_list_scalar()
+                        }
+                        DataType::Timestamp(TimeUnit::Microsecond, Some(tz)) => {
+                            let vals = values.clone();
+                            SingleRowListArrayBuilder::new(Arc::new(TimestampMicrosecondArray::new(vals.long_values.into(), Some(vals.null_mask.into())).with_timezone(Arc::clone(tz))))
+                                .build_list_scalar()
+                        }
+                        DataType::Date32 => {
+                            let vals = values.clone();
+                            SingleRowListArrayBuilder::new(Arc::new(Date32Array::new(vals.int_values.into(), Some(vals.null_mask.into()))))
+                                .build_list_scalar()
+                        }
+                        DataType::Binary => {
+                            // Using a builder as it is cumbersome to create BinaryArray from a vector with nulls
+                            // and calculate correct offsets
+                            let vals = values.clone();
+                            let item_capacity = vals.string_values.len();
+                            let data_capacity = vals.string_values.first().map(|s| s.len() * item_capacity).unwrap_or(0);
+                            let mut arr = BinaryBuilder::with_capacity(item_capacity, data_capacity);
+
+                            for (i, v) in vals.bytes_values.into_iter().enumerate() {
+                                if vals.null_mask[i] {
+                                    arr.append_value(v);
+                                } else {
+                                    arr.append_null();
+                                }
+                            }
+
+                            SingleRowListArrayBuilder::new(Arc::new(arr.finish()))
+                                .build_list_scalar()
+                        }
+                        DataType::Utf8 => {
+                            // Using a builder as it is cumbersome to create StringArray from a vector with nulls
+                            // and calculate correct offsets
+                            let vals = values.clone();
+                            let item_capacity = vals.string_values.len();
+                            let data_capacity = vals.string_values.first().map(|s| s.len() * item_capacity).unwrap_or(0);
+                            let mut arr = StringBuilder::with_capacity(item_capacity, data_capacity);
+
+                            for (i, v) in vals.string_values.into_iter().enumerate() {
+                                if vals.null_mask[i] {
+                                    arr.append_value(v);
+                                } else {
+                                    arr.append_null();
+                                }
+                            }
+
+                            SingleRowListArrayBuilder::new(Arc::new(arr.finish()))
+                                .build_list_scalar()
+                        }
+                        DataType::Decimal128(p, s) => {
+                            let vals = values.clone();
+                            SingleRowListArrayBuilder::new(Arc::new(Decimal128Array::new(vals.decimal_values.into_iter().map(|v| {
+                                let big_integer = BigInt::from_signed_bytes_be(&v);
+                                big_integer.to_i128().ok_or_else(|| {
+                                    GeneralError(format!(
+                                        "Cannot parse {big_integer:?} as i128 for Decimal literal"
+                                    ))
+                                }).unwrap()
+                            }).collect::<Vec<_>>().into(), Some(vals.null_mask.into())).with_precision_and_scale(*p, *s)?)).build_list_scalar()
+                        }
+                        dt => {
+                            return Err(GeneralError(format!(
+                                "DataType::List literal does not support {dt:?} type"
+                            )))
+                        }
+                    }
+
+                } else {
+                    return Err(GeneralError(format!(
+                        "Expected DataType::List but got {data_type:?}"
+                    )))
+                }
+            }
         }
     };
@@ -1300,6 +1426,7 @@ impl PhysicalPlanner {
         // The `ScanExec` operator will take actual arrays from Spark during execution
         let scan =
             ScanExec::new(self.exec_context_id, input_source, &scan.source, data_types)?;
+
         Ok((
             vec![scan.clone()],
             Arc::new(SparkPlan::new(spark_plan.plan_id, Arc::new(scan), vec![])),
@@ -2322,7 +2449,6 @@ impl PhysicalPlanner {
             other => other,
         };
         let func = self.session_ctx.udf(fun_name)?;
-
         let coerced_types = func
             .coerce_types(&input_expr_types)
             .unwrap_or_else(|_| input_expr_types.clone());

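Every arm of the new `Value::ListVal` handling follows the same pattern: materialize the list elements as an Arrow array, re-apply the protobuf `null_mask` as the array's null buffer (a `false` entry marks a NULL element), and wrap the result with `SingleRowListArrayBuilder` to obtain a one-row list `ScalarValue`. A minimal standalone sketch of that pattern for an `Int32` list, with made-up values, assuming the same arrow and datafusion crates the planner already imports:

```rust
use std::sync::Arc;

use arrow::array::Int32Array;
use datafusion::common::utils::SingleRowListArrayBuilder;
use datafusion::common::ScalarValue;

/// Builds the list literal `[1, NULL, 3]` as a single-row list scalar.
fn int32_list_literal() -> ScalarValue {
    // Element values and the per-element validity mask, mirroring the
    // `int_values` and `null_mask` fields of the ListLiteral proto message.
    let int_values: Vec<i32> = vec![1, 2, 3];
    let null_mask: Vec<bool> = vec![true, false, true]; // false => NULL

    // The null mask becomes the element array's null buffer.
    let elements = Int32Array::new(int_values.into(), Some(null_mask.into()));

    // Wrap the element array into a one-row ListArray and return it as a ScalarValue.
    SingleRowListArrayBuilder::new(Arc::new(elements)).build_list_scalar()
}

fn main() {
    // Debug-print the resulting list scalar.
    println!("{:?}", int32_list_literal());
}
```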
native/proto/src/lib.rs

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@
 
 // Include generated modules from .proto files.
 #[allow(missing_docs)]
+#[allow(clippy::large_enum_variant)]
 pub mod spark_expression {
     include!(concat!("generated", "/spark.spark_expression.rs"));
 }

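The new `clippy::large_enum_variant` allow is needed because prost turns the `Literal.value` oneof into a Rust enum, and the `ListLiteral` field added to `Literal` in the next file is a struct holding a dozen `Vec`s (roughly 300 bytes inline), far larger than the scalar variants. A hypothetical, simplified illustration of the shape — not the generated code:

```rust
// Not the generated code: a cut-down stand-in showing why the lint fires.
#[allow(dead_code)]
#[allow(clippy::large_enum_variant)]
enum Value {
    BoolVal(bool),        // a single byte of payload
    LongVal(i64),         // 8 bytes of payload
    ListVal(ListLiteral), // many Vecs stored inline -> a few hundred bytes
}

// A cut-down stand-in for the prost-generated ListLiteral struct.
#[allow(dead_code)]
struct ListLiteral {
    int_values: Vec<i32>,
    long_values: Vec<i64>,
    double_values: Vec<f64>,
    string_values: Vec<String>,
    null_mask: Vec<bool>,
}

fn main() {
    // Each Vec is 24 bytes on a 64-bit target, so the ListVal variant dominates the enum size.
    println!("{}", std::mem::size_of::<Value>());
}
```

Boxing the large variant would be the usual fix, but that is not practical for generated code, so the lint is allowed for the whole module instead.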
native/proto/src/proto/expr.proto

Lines changed: 11 additions & 9 deletions
@@ -21,6 +21,8 @@ syntax = "proto3";
 
 package spark.spark_expression;
 
+import "types.proto";
+
 option java_package = "org.apache.comet.serde";
 
 // The basic message representing a Spark expression.
@@ -112,13 +114,13 @@ enum StatisticsType {
 }
 
 message Count {
-    repeated Expr children = 1;
+  repeated Expr children = 1;
 }
 
 message Sum {
-    Expr child = 1;
-    DataType datatype = 2;
-    bool fail_on_error = 3;
+  Expr child = 1;
+  DataType datatype = 2;
+  bool fail_on_error = 3;
 }
 
 message Min {
@@ -215,10 +217,11 @@ message Literal {
     string string_val = 8;
     bytes bytes_val = 9;
     bytes decimal_val = 10;
-  }
+    ListLiteral list_val = 11;
+  }
 
-  DataType datatype = 11;
-  bool is_null = 12;
+  DataType datatype = 12;
+  bool is_null = 13;
 }
 
 enum EvalMode {
@@ -478,5 +481,4 @@ message DataType {
   }
 
   DataTypeInfo type_info = 2;
-}
-
+}

native/proto/src/proto/types.proto

Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+
+
+syntax = "proto3";
+
+package spark.spark_expression;
+
+option java_package = "org.apache.comet.serde";
+
+message ListLiteral {
+  // Only one of these fields should be populated based on the array type
+  repeated bool boolean_values = 1;
+  repeated int32 byte_values = 2;
+  repeated int32 short_values = 3;
+  repeated int32 int_values = 4;
+  repeated int64 long_values = 5;
+  repeated float float_values = 6;
+  repeated double double_values = 7;
+  repeated string string_values = 8;
+  repeated bytes bytes_values = 9;
+  repeated bytes decimal_values = 10;
+  repeated ListLiteral list_values = 11;
+
+  repeated bool null_mask = 12;
+}

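On the wire, only the repeated field matching the array's element type is populated, and `null_mask` carries one flag per element (`false` meaning NULL), which is how the planner code above rebuilds the null buffer. A small sketch of the decoded message for a Spark literal like `array(1, null, 3)`, assuming the prost-generated struct lives at `datafusion_comet_proto::spark_expression::ListLiteral` as the module layout in `native/proto/src/lib.rs` suggests:

```rust
use datafusion_comet_proto::spark_expression::ListLiteral;

/// What `array(1, null, 3)` looks like once deserialized on the native side.
fn int_array_literal() -> ListLiteral {
    ListLiteral {
        // Only the field matching the element type is filled in; the slot for
        // the NULL element still needs a placeholder value (0 here) so that
        // the values and the null mask have the same length.
        int_values: vec![1, 0, 3],
        // One entry per element; `false` marks the NULL.
        null_mask: vec![true, false, true],
        // All other repeated fields stay empty.
        ..Default::default()
    }
}

fn main() {
    let lit = int_array_literal();
    assert_eq!(lit.int_values.len(), lit.null_mask.len());
}
```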
native/spark-expr/src/conversion_funcs/cast.rs

Lines changed: 5 additions & 1 deletion
@@ -20,6 +20,7 @@ use crate::utils::array_with_timezone;
 use crate::{EvalMode, SparkError, SparkResult};
 use arrow::array::builder::StringBuilder;
 use arrow::array::{DictionaryArray, StringArray, StructArray};
+use arrow::compute::can_cast_types;
 use arrow::datatypes::{DataType, Schema};
 use arrow::{
     array::{
@@ -968,6 +969,9 @@ fn cast_array(
             to_type,
             cast_options,
         )?),
+        (List(_), List(_)) if can_cast_types(from_type, to_type) => {
+            Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?)
+        }
         (UInt8 | UInt16 | UInt32 | UInt64, Int8 | Int16 | Int32 | Int64)
             if cast_options.allow_cast_unsigned_ints =>
         {
@@ -1018,7 +1022,7 @@ fn is_datafusion_spark_compatible(
         DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
             // note that the cast from Int32/Int64 -> Decimal128 here is actually
             // not compatible with Spark (no overflow checks) but we have tests that
-            // rely on this cast working so we have to leave it here for now
+            // rely on this cast working, so we have to leave it here for now
             matches!(
                 to_type,
                 DataType::Boolean

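The new `(List(_), List(_))` arm in `cast_array` simply defers to Arrow's cast kernel whenever `can_cast_types` reports the two list types as compatible. A self-contained sketch of that check-then-cast pattern using plain arrow-rs; the `CAST_OPTIONS` constant from `cast.rs` is replaced with default options here for illustration:

```rust
use std::sync::Arc;

use arrow::array::{Array, ArrayRef, ListArray};
use arrow::compute::{can_cast_types, cast_with_options, CastOptions};
use arrow::datatypes::{DataType, Field, Int32Type};
use arrow::error::ArrowError;

/// Cast a list array to another list type only if Arrow reports the types as
/// castable, mirroring the guard on the new match arm.
fn cast_list(array: ArrayRef, to_type: &DataType) -> Result<ArrayRef, ArrowError> {
    if can_cast_types(array.data_type(), to_type) {
        cast_with_options(&array, to_type, &CastOptions::default())
    } else {
        Err(ArrowError::CastError(format!(
            "cannot cast {:?} to {to_type:?}",
            array.data_type()
        )))
    }
}

fn main() -> Result<(), ArrowError> {
    // One row of List<Int32>: [1, NULL, 3]
    let list = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![Some(vec![
        Some(1),
        None,
        Some(3),
    ])]);

    // Cast it to List<Int64>.
    let target = DataType::List(Arc::new(Field::new("item", DataType::Int64, true)));
    let casted = cast_list(Arc::new(list), &target)?;

    assert!(matches!(
        casted.data_type(),
        DataType::List(f) if f.data_type() == &DataType::Int64
    ));
    Ok(())
}
```

Guarding with `can_cast_types` means list combinations Arrow cannot handle fall through to the existing fallback handling in `cast_array` rather than failing inside the Arrow kernel.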
spark/src/main/scala/org/apache/comet/DataTypeSupport.scala

Lines changed: 5 additions & 0 deletions
@@ -73,4 +73,9 @@ object DataTypeSupport {
   val ARRAY_ELEMENT = "array element"
   val MAP_KEY = "map key"
   val MAP_VALUE = "map value"
+
+  def isComplexType(dt: DataType): Boolean = dt match {
+    case _: StructType | _: ArrayType | _: MapType => true
+    case _ => false
+  }
 }

spark/src/main/scala/org/apache/comet/expressions/CometCast.scala

Lines changed: 4 additions & 1 deletion
@@ -19,7 +19,7 @@
 
 package org.apache.comet.expressions
 
-import org.apache.spark.sql.types.{DataType, DataTypes, DecimalType, StructType}
+import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, DecimalType, NullType, StructType}
 
 sealed trait SupportLevel
 
@@ -62,6 +62,9 @@ object CometCast {
     }
 
     (fromType, toType) match {
+      case (dt: ArrayType, _: ArrayType) if dt.elementType == NullType => Compatible()
+      case (dt: ArrayType, dt1: ArrayType) =>
+        isSupported(dt.elementType, dt1.elementType, timeZoneId, evalMode)
       case (dt: DataType, _) if dt.typeName == "timestamp_ntz" =>
         // https://github.com/apache/datafusion-comet/issues/378
         toType match {

spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala

Lines changed: 1 addition & 5 deletions
@@ -37,6 +37,7 @@ import org.apache.spark.sql.types._
 import org.apache.comet.{CometConf, CometSparkSessionExtensions, DataTypeSupport}
 import org.apache.comet.CometConf._
 import org.apache.comet.CometSparkSessionExtensions.{isCometLoaded, isCometScanEnabled, withInfo, withInfos}
+import org.apache.comet.DataTypeSupport.isComplexType
 import org.apache.comet.parquet.{CometParquetScan, SupportsComet}
 
 /**
@@ -277,11 +278,6 @@ case class CometScanRule(session: SparkSession) extends Rule[SparkPlan] {
     val partitionSchemaSupported =
       typeChecker.isSchemaSupported(partitionSchema, fallbackReasons)
 
-    def isComplexType(dt: DataType): Boolean = dt match {
-      case _: StructType | _: ArrayType | _: MapType => true
-      case _ => false
-    }
-
     def hasMapsContainingStructs(dataType: DataType): Boolean = {
       dataType match {
         case s: StructType => s.exists(field => hasMapsContainingStructs(field.dataType))
