diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index 655559dc9d..7fa7bfe905 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -85,6 +85,13 @@ use datafusion::physical_expr::window::WindowExpr;
 use datafusion::physical_expr::LexOrdering;
 
 use crate::parquet::parquet_exec::init_datasource_exec;
+use arrow::array::{
+    BinaryBuilder, BooleanArray, Date32Array, Decimal128Array, Float32Array, Float64Array,
+    Int16Array, Int32Array, Int64Array, Int8Array, NullArray, StringBuilder,
+    TimestampMicrosecondArray,
+};
+use arrow::buffer::BooleanBuffer;
+use datafusion::common::utils::SingleRowListArrayBuilder;
 use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
 use datafusion::physical_plan::filter::FilterExec as DataFusionFilterExec;
 use datafusion_comet_proto::spark_operator::SparkFilePartition;
@@ -474,6 +481,125 @@ impl PhysicalPlanner {
                            )))
                        }
                    }
+                },
+                Value::ListVal(values) => {
+                    if let DataType::List(f) = data_type {
+                        match f.data_type() {
+                            DataType::Null => {
+                                SingleRowListArrayBuilder::new(Arc::new(NullArray::new(values.clone().null_mask.len())))
+                                    .build_list_scalar()
+                            }
+                            DataType::Boolean => {
+                                let vals = values.clone();
+                                SingleRowListArrayBuilder::new(Arc::new(BooleanArray::new(BooleanBuffer::from(vals.boolean_values), Some(vals.null_mask.into()))))
+                                    .build_list_scalar()
+                            }
+                            DataType::Int8 => {
+                                let vals = values.clone();
+                                SingleRowListArrayBuilder::new(Arc::new(Int8Array::new(vals.byte_values.iter().map(|&x| x as i8).collect::<Vec<_>>().into(), Some(vals.null_mask.into()))))
+                                    .build_list_scalar()
+                            }
+                            DataType::Int16 => {
+                                let vals = values.clone();
+                                SingleRowListArrayBuilder::new(Arc::new(Int16Array::new(vals.short_values.iter().map(|&x| x as i16).collect::<Vec<_>>().into(), Some(vals.null_mask.into()))))
+                                    .build_list_scalar()
+                            }
+                            DataType::Int32 => {
+                                let vals = values.clone();
+                                SingleRowListArrayBuilder::new(Arc::new(Int32Array::new(vals.int_values.into(), Some(vals.null_mask.into()))))
+                                    .build_list_scalar()
+                            }
+                            DataType::Int64 => {
+                                let vals = values.clone();
+                                SingleRowListArrayBuilder::new(Arc::new(Int64Array::new(vals.long_values.into(), Some(vals.null_mask.into()))))
+                                    .build_list_scalar()
+                            }
+                            DataType::Float32 => {
+                                let vals = values.clone();
+                                SingleRowListArrayBuilder::new(Arc::new(Float32Array::new(vals.float_values.into(), Some(vals.null_mask.into()))))
+                                    .build_list_scalar()
+                            }
+                            DataType::Float64 => {
+                                let vals = values.clone();
+                                SingleRowListArrayBuilder::new(Arc::new(Float64Array::new(vals.double_values.into(), Some(vals.null_mask.into()))))
+                                    .build_list_scalar()
+                            }
+                            DataType::Timestamp(TimeUnit::Microsecond, None) => {
+                                let vals = values.clone();
+                                SingleRowListArrayBuilder::new(Arc::new(TimestampMicrosecondArray::new(vals.long_values.into(), Some(vals.null_mask.into()))))
+                                    .build_list_scalar()
+                            }
+                            DataType::Timestamp(TimeUnit::Microsecond, Some(tz)) => {
+                                let vals = values.clone();
+                                SingleRowListArrayBuilder::new(Arc::new(TimestampMicrosecondArray::new(vals.long_values.into(), Some(vals.null_mask.into())).with_timezone(Arc::clone(tz))))
+                                    .build_list_scalar()
+                            }
+                            DataType::Date32 => {
+                                let vals = values.clone();
+                                SingleRowListArrayBuilder::new(Arc::new(Date32Array::new(vals.int_values.into(), Some(vals.null_mask.into()))))
+                                    .build_list_scalar()
+                            }
+                            DataType::Binary => {
+                                // Using a builder as it is cumbersome to create BinaryArray from a vector with nulls
+                                // and calculate correct offsets
+                                let vals = values.clone();
+                                let item_capacity = vals.bytes_values.len();
+                                let data_capacity = vals.bytes_values.first().map(|s| s.len() * item_capacity).unwrap_or(0);
+                                let mut arr = BinaryBuilder::with_capacity(item_capacity, data_capacity);
+
+                                for (i, v) in vals.bytes_values.into_iter().enumerate() {
+                                    if vals.null_mask[i] {
+                                        arr.append_value(v);
+                                    } else {
+                                        arr.append_null();
+                                    }
+                                }
+
+                                SingleRowListArrayBuilder::new(Arc::new(arr.finish()))
+                                    .build_list_scalar()
+                            }
+                            DataType::Utf8 => {
+                                // Using a builder as it is cumbersome to create StringArray from a vector with nulls
+                                // and calculate correct offsets
+                                let vals = values.clone();
+                                let item_capacity = vals.string_values.len();
+                                let data_capacity = vals.string_values.first().map(|s| s.len() * item_capacity).unwrap_or(0);
+                                let mut arr = StringBuilder::with_capacity(item_capacity, data_capacity);
+
+                                for (i, v) in vals.string_values.into_iter().enumerate() {
+                                    if vals.null_mask[i] {
+                                        arr.append_value(v);
+                                    } else {
+                                        arr.append_null();
+                                    }
+                                }
+
+                                SingleRowListArrayBuilder::new(Arc::new(arr.finish()))
+                                    .build_list_scalar()
+                            }
+                            DataType::Decimal128(p, s) => {
+                                let vals = values.clone();
+                                SingleRowListArrayBuilder::new(Arc::new(Decimal128Array::new(vals.decimal_values.into_iter().map(|v| {
+                                    let big_integer = BigInt::from_signed_bytes_be(&v);
+                                    big_integer.to_i128().ok_or_else(|| {
+                                        GeneralError(format!(
+                                            "Cannot parse {big_integer:?} as i128 for Decimal literal"
+                                        ))
+                                    }).unwrap()
+                                }).collect::<Vec<_>>().into(), Some(vals.null_mask.into())).with_precision_and_scale(*p, *s)?)).build_list_scalar()
+                            }
+                            dt => {
+                                return Err(GeneralError(format!(
+                                    "DataType::List literal does not support {dt:?} type"
+                                )))
+                            }
+                        }
+
+                    } else {
+                        return Err(GeneralError(format!(
+                            "Expected DataType::List but got {data_type:?}"
+                        )))
+                    }
                }
            }
        };
@@ -1300,6 +1426,7 @@ impl PhysicalPlanner {
         // The `ScanExec` operator will take actual arrays from Spark during execution
         let scan =
             ScanExec::new(self.exec_context_id, input_source, &scan.source, data_types)?;
+
         Ok((
             vec![scan.clone()],
             Arc::new(SparkPlan::new(spark_plan.plan_id, Arc::new(scan), vec![])),
@@ -2322,7 +2449,6 @@ impl PhysicalPlanner {
             other => other,
         };
         let func = self.session_ctx.udf(fun_name)?;
-
         let coerced_types = func
             .coerce_types(&input_expr_types)
             .unwrap_or_else(|_| input_expr_types.clone());
diff --git a/native/proto/src/lib.rs b/native/proto/src/lib.rs
index ed24440360..2c213c2514 100644
--- a/native/proto/src/lib.rs
+++ b/native/proto/src/lib.rs
@@ -21,6 +21,7 @@
 
 // Include generated modules from .proto files.
 #[allow(missing_docs)]
+#[allow(clippy::large_enum_variant)]
 pub mod spark_expression {
     include!(concat!("generated", "/spark.spark_expression.rs"));
 }
diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto
index 8b193ba846..1152d7a1b2 100644
--- a/native/proto/src/proto/expr.proto
+++ b/native/proto/src/proto/expr.proto
@@ -21,6 +21,8 @@ syntax = "proto3";
 
 package spark.spark_expression;
 
+import "types.proto";
+
 option java_package = "org.apache.comet.serde";
 
 // The basic message representing a Spark expression.
@@ -112,13 +114,13 @@ enum StatisticsType { } message Count { - repeated Expr children = 1; + repeated Expr children = 1; } message Sum { - Expr child = 1; - DataType datatype = 2; - bool fail_on_error = 3; + Expr child = 1; + DataType datatype = 2; + bool fail_on_error = 3; } message Min { @@ -215,10 +217,11 @@ message Literal { string string_val = 8; bytes bytes_val = 9; bytes decimal_val = 10; - } + ListLiteral list_val = 11; + } - DataType datatype = 11; - bool is_null = 12; + DataType datatype = 12; + bool is_null = 13; } enum EvalMode { @@ -478,5 +481,4 @@ message DataType { } DataTypeInfo type_info = 2; -} - +} \ No newline at end of file diff --git a/native/proto/src/proto/types.proto b/native/proto/src/proto/types.proto new file mode 100644 index 0000000000..cc163522b4 --- /dev/null +++ b/native/proto/src/proto/types.proto @@ -0,0 +1,41 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + + + +syntax = "proto3"; + +package spark.spark_expression; + +option java_package = "org.apache.comet.serde"; + +message ListLiteral { + // Only one of these fields should be populated based on the array type + repeated bool boolean_values = 1; + repeated int32 byte_values = 2; + repeated int32 short_values = 3; + repeated int32 int_values = 4; + repeated int64 long_values = 5; + repeated float float_values = 6; + repeated double double_values = 7; + repeated string string_values = 8; + repeated bytes bytes_values = 9; + repeated bytes decimal_values = 10; + repeated ListLiteral list_values = 11; + + repeated bool null_mask = 12; +} \ No newline at end of file diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs index 7140d2be09..1cf061ab23 100644 --- a/native/spark-expr/src/conversion_funcs/cast.rs +++ b/native/spark-expr/src/conversion_funcs/cast.rs @@ -20,6 +20,7 @@ use crate::utils::array_with_timezone; use crate::{EvalMode, SparkError, SparkResult}; use arrow::array::builder::StringBuilder; use arrow::array::{DictionaryArray, StringArray, StructArray}; +use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, Schema}; use arrow::{ array::{ @@ -968,6 +969,9 @@ fn cast_array( to_type, cast_options, )?), + (List(_), List(_)) if can_cast_types(from_type, to_type) => { + Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?) 
+ } (UInt8 | UInt16 | UInt32 | UInt64, Int8 | Int16 | Int32 | Int64) if cast_options.allow_cast_unsigned_ints => { @@ -1018,7 +1022,7 @@ fn is_datafusion_spark_compatible( DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { // note that the cast from Int32/Int64 -> Decimal128 here is actually // not compatible with Spark (no overflow checks) but we have tests that - // rely on this cast working so we have to leave it here for now + // rely on this cast working, so we have to leave it here for now matches!( to_type, DataType::Boolean diff --git a/spark/src/main/scala/org/apache/comet/DataTypeSupport.scala b/spark/src/main/scala/org/apache/comet/DataTypeSupport.scala index 917aa95697..66eab8c9f0 100644 --- a/spark/src/main/scala/org/apache/comet/DataTypeSupport.scala +++ b/spark/src/main/scala/org/apache/comet/DataTypeSupport.scala @@ -73,4 +73,9 @@ object DataTypeSupport { val ARRAY_ELEMENT = "array element" val MAP_KEY = "map key" val MAP_VALUE = "map value" + + def isComplexType(dt: DataType): Boolean = dt match { + case _: StructType | _: ArrayType | _: MapType => true + case _ => false + } } diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala index e0e89b35fc..337eae11db 100644 --- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala +++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala @@ -19,7 +19,7 @@ package org.apache.comet.expressions -import org.apache.spark.sql.types.{DataType, DataTypes, DecimalType, StructType} +import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, DecimalType, NullType, StructType} sealed trait SupportLevel @@ -62,6 +62,9 @@ object CometCast { } (fromType, toType) match { + case (dt: ArrayType, _: ArrayType) if dt.elementType == NullType => Compatible() + case (dt: ArrayType, dt1: ArrayType) => + isSupported(dt.elementType, dt1.elementType, timeZoneId, evalMode) case (dt: DataType, _) if dt.typeName == "timestamp_ntz" => // https://github.com/apache/datafusion-comet/issues/378 toType match { diff --git a/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala index 592069fcc6..6a328f4be2 100644 --- a/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala +++ b/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql.types._ import org.apache.comet.{CometConf, CometSparkSessionExtensions, DataTypeSupport} import org.apache.comet.CometConf._ import org.apache.comet.CometSparkSessionExtensions.{isCometLoaded, isCometScanEnabled, withInfo, withInfos} +import org.apache.comet.DataTypeSupport.isComplexType import org.apache.comet.parquet.{CometParquetScan, SupportsComet} /** @@ -277,11 +278,6 @@ case class CometScanRule(session: SparkSession) extends Rule[SparkPlan] { val partitionSchemaSupported = typeChecker.isSchemaSupported(partitionSchema, fallbackReasons) - def isComplexType(dt: DataType): Boolean = dt match { - case _: StructType | _: ArrayType | _: MapType => true - case _ => false - } - def hasMapsContainingStructs(dataType: DataType): Boolean = { dataType match { case s: StructType => s.exists(field => hasMapsContainingStructs(field.dataType)) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 1b72521270..3455d85c93 100644 --- 
a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, NormalizeNaNAndZero} import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils +import org.apache.spark.sql.catalyst.util.{CharVarcharCodegenUtils, GenericArrayData} import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.getExistenceDefaultValues import org.apache.spark.sql.comet._ import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec @@ -47,14 +47,18 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String +import com.google.protobuf.ByteString + import org.apache.comet.CometConf import org.apache.comet.CometSparkSessionExtensions.{isCometScan, withInfo} +import org.apache.comet.DataTypeSupport.isComplexType import org.apache.comet.expressions._ import org.apache.comet.objectstore.NativeConfig import org.apache.comet.serde.ExprOuterClass.{AggExpr, DataType => ProtoDataType, Expr, ScalarFunc} import org.apache.comet.serde.ExprOuterClass.DataType._ import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, BuildSide, JoinType, Operator} import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto} +import org.apache.comet.serde.Types.ListLiteral import org.apache.comet.shims.CometExprShim /** @@ -711,8 +715,53 @@ object QueryPlanSerde extends Logging with CometExprShim { binding, (builder, binaryExpr) => builder.setNeqNullSafe(binaryExpr)) + case GreaterThan(left, right) => + createBinaryExpr( + expr, + left, + right, + inputs, + binding, + (builder, binaryExpr) => builder.setGt(binaryExpr)) + + case GreaterThanOrEqual(left, right) => + createBinaryExpr( + expr, + left, + right, + inputs, + binding, + (builder, binaryExpr) => builder.setGtEq(binaryExpr)) + + case LessThan(left, right) => + createBinaryExpr( + expr, + left, + right, + inputs, + binding, + (builder, binaryExpr) => builder.setLt(binaryExpr)) + + case LessThanOrEqual(left, right) => + createBinaryExpr( + expr, + left, + right, + inputs, + binding, + (builder, binaryExpr) => builder.setLtEq(binaryExpr)) + case Literal(value, dataType) - if supportedDataType(dataType, allowComplex = value == null) => + if supportedDataType( + dataType, + allowComplex = value == null || + // Nested literal support for native reader + // can be tracked https://github.com/apache/datafusion-comet/issues/1937 + // now supports only Array of primitive + (Seq(CometConf.SCAN_NATIVE_ICEBERG_COMPAT, CometConf.SCAN_NATIVE_DATAFUSION) + .contains(CometConf.COMET_NATIVE_SCAN_IMPL.get()) && dataType + .isInstanceOf[ArrayType]) && !isComplexType( + dataType.asInstanceOf[ArrayType].elementType)) => val exprBuilder = ExprOuterClass.Literal.newBuilder() if (value == null) { @@ -723,14 +772,13 @@ object QueryPlanSerde extends Logging with CometExprShim { case _: BooleanType => exprBuilder.setBoolVal(value.asInstanceOf[Boolean]) case _: ByteType => exprBuilder.setByteVal(value.asInstanceOf[Byte]) case _: ShortType => exprBuilder.setShortVal(value.asInstanceOf[Short]) - case _: IntegerType => exprBuilder.setIntVal(value.asInstanceOf[Int]) - case _: LongType => 
exprBuilder.setLongVal(value.asInstanceOf[Long]) + case _: IntegerType | _: DateType => exprBuilder.setIntVal(value.asInstanceOf[Int]) + case _: LongType | _: TimestampType | _: TimestampNTZType => + exprBuilder.setLongVal(value.asInstanceOf[Long]) case _: FloatType => exprBuilder.setFloatVal(value.asInstanceOf[Float]) case _: DoubleType => exprBuilder.setDoubleVal(value.asInstanceOf[Double]) case _: StringType => exprBuilder.setStringVal(value.asInstanceOf[UTF8String].toString) - case _: TimestampType => exprBuilder.setLongVal(value.asInstanceOf[Long]) - case _: TimestampNTZType => exprBuilder.setLongVal(value.asInstanceOf[Long]) case _: DecimalType => // Pass decimal literal as bytes. val unscaled = value.asInstanceOf[Decimal].toBigDecimal.underlying.unscaledValue @@ -740,7 +788,87 @@ object QueryPlanSerde extends Logging with CometExprShim { val byteStr = com.google.protobuf.ByteString.copyFrom(value.asInstanceOf[Array[Byte]]) exprBuilder.setBytesVal(byteStr) - case _: DateType => exprBuilder.setIntVal(value.asInstanceOf[Int]) + case a: ArrayType => + val listLiteralBuilder = ListLiteral.newBuilder() + val array = value.asInstanceOf[GenericArrayData].array + a.elementType match { + case NullType => + array.foreach(_ => listLiteralBuilder.addNullMask(true)) + case BooleanType => + array.foreach(v => { + val casted = v.asInstanceOf[java.lang.Boolean] + listLiteralBuilder.addBooleanValues(casted) + listLiteralBuilder.addNullMask(casted != null) + }) + case ByteType => + array.foreach(v => { + val casted = v.asInstanceOf[java.lang.Integer] + listLiteralBuilder.addByteValues(casted) + listLiteralBuilder.addNullMask(casted != null) + }) + case ShortType => + array.foreach(v => { + val casted = v.asInstanceOf[java.lang.Short] + listLiteralBuilder.addShortValues( + if (casted != null) casted.intValue() + else null.asInstanceOf[java.lang.Integer]) + listLiteralBuilder.addNullMask(casted != null) + }) + case IntegerType | DateType => + array.foreach(v => { + val casted = v.asInstanceOf[java.lang.Integer] + listLiteralBuilder.addIntValues(casted) + listLiteralBuilder.addNullMask(casted != null) + }) + case LongType | TimestampType | TimestampNTZType => + array.foreach(v => { + val casted = v.asInstanceOf[java.lang.Long] + listLiteralBuilder.addLongValues(casted) + listLiteralBuilder.addNullMask(casted != null) + }) + case FloatType => + array.foreach(v => { + val casted = v.asInstanceOf[java.lang.Float] + listLiteralBuilder.addFloatValues(casted) + listLiteralBuilder.addNullMask(casted != null) + }) + case DoubleType => + array.foreach(v => { + val casted = v.asInstanceOf[java.lang.Double] + listLiteralBuilder.addDoubleValues(casted) + listLiteralBuilder.addNullMask(casted != null) + }) + case StringType => + array.foreach(v => { + val casted = v.asInstanceOf[org.apache.spark.unsafe.types.UTF8String] + listLiteralBuilder.addStringValues( + if (casted != null) casted.toString else "") + listLiteralBuilder.addNullMask(casted != null) + }) + case _: DecimalType => + array + .foreach(v => { + val casted = + v.asInstanceOf[Decimal] + listLiteralBuilder.addDecimalValues(if (casted != null) { + com.google.protobuf.ByteString + .copyFrom(casted.toBigDecimal.underlying.unscaledValue.toByteArray) + } else ByteString.EMPTY) + listLiteralBuilder.addNullMask(casted != null) + }) + case _: BinaryType => + array + .foreach(v => { + val casted = + v.asInstanceOf[Array[Byte]] + listLiteralBuilder.addBytesValues(if (casted != null) { + com.google.protobuf.ByteString.copyFrom(casted) + } else ByteString.EMPTY) + 
listLiteralBuilder.addNullMask(casted != null) + }) + } + exprBuilder.setListVal(listLiteralBuilder.build()) + exprBuilder.setDatatype(serializeDataType(dataType).get) case dt => logWarning(s"Unexpected datatype '$dt' for literal value '$value'") } diff --git a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala index 9951f4f9d0..9976ecd748 100644 --- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.functions._ import org.apache.comet.CometSparkSessionExtensions.{isSpark35Plus, isSpark40Plus} +import org.apache.comet.DataTypeSupport.isComplexType import org.apache.comet.serde.CometArrayExcept import org.apache.comet.testing.{DataGenOptions, ParquetGenerator} @@ -272,8 +273,7 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp } } - // https://github.com/apache/datafusion-comet/issues/1929 - ignore("array_contains - array literals") { + test("array_contains - array literals") { withTempDir { dir => val path = new Path(dir.toURI.toString, "test.parquet") val filename = path.toString @@ -292,14 +292,13 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp generateMap = false)) } val table = spark.read.parquet(filename) + table.createOrReplaceTempView("t2") for (field <- table.schema.fields) { val typeName = field.dataType.typeName - checkSparkAnswerAndOperator( - sql(s"SELECT array_contains(cast(null as array<$typeName>), b) FROM t2")) checkSparkAnswerAndOperator(sql( - s"SELECT array_contains(cast(array() as array<$typeName>), cast(null as $typeName)) FROM t2")) - checkSparkAnswerAndOperator(sql("SELECT array_contains(array(), 1) FROM t2")) + s"SELECT array_contains(cast(null as array<$typeName>), cast(null as $typeName)) FROM t2")) } + checkSparkAnswerAndOperator(sql("SELECT array_contains(array(), 1) FROM t2")) } } diff --git a/spark/src/test/scala/org/apache/comet/CometFuzzTestSuite.scala b/spark/src/test/scala/org/apache/comet/CometFuzzTestSuite.scala index a1b1812b31..217cd322dd 100644 --- a/spark/src/test/scala/org/apache/comet/CometFuzzTestSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometFuzzTestSuite.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType import org.apache.spark.sql.types._ +import org.apache.comet.DataTypeSupport.isComplexType import org.apache.comet.testing.{DataGenOptions, ParquetGenerator} class CometFuzzTestSuite extends CometTestBase with AdaptiveSparkPlanHelper { diff --git a/spark/src/test/scala/org/apache/comet/exec/CometNativeReaderSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometNativeReaderSuite.scala index f33da3ba71..8f1e7cfdf0 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometNativeReaderSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometNativeReaderSuite.scala @@ -24,6 +24,7 @@ import org.scalatest.Tag import org.apache.spark.sql.CometTestBase import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.functions.{array, col} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{IntegerType, StringType, StructType} @@ -253,18 +254,11 @@ class CometNativeReaderSuite extends CometTestBase with 
AdaptiveSparkPlanHelper } test("native reader - read a STRUCT subfield - field from second") { - withSQLConf( - CometConf.COMET_EXEC_ENABLED.key -> "true", - SQLConf.USE_V1_SOURCE_LIST.key -> "parquet", - CometConf.COMET_ENABLED.key -> "true", - CometConf.COMET_EXPLAIN_FALLBACK_ENABLED.key -> "false", - CometConf.COMET_NATIVE_SCAN_IMPL.key -> "native_datafusion") { - testSingleLineQuery( - """ + testSingleLineQuery( + """ |select 1 a, named_struct('a', 1, 'b', 'n') c0 |""".stripMargin, - "select c0.b from tbl") - } + "select c0.b from tbl") } test("native reader - read a STRUCT subfield from ARRAY of STRUCTS - field from first") { @@ -436,4 +430,131 @@ class CometNativeReaderSuite extends CometTestBase with AdaptiveSparkPlanHelper |""".stripMargin, "select c0['key1'].b from tbl") } + + test("native reader - support ARRAY literal INT fields") { + testSingleLineQuery( + """ + |select 1 a + |""".stripMargin, + "select array(1, 2, null, 3, null) from tbl") + } + + test("native reader - support ARRAY literal BOOL fields") { + testSingleLineQuery( + """ + |select 1 a + |""".stripMargin, + "select array(true, null, false, null) from tbl") + } + + test("native reader - support ARRAY literal NULL fields") { + testSingleLineQuery( + """ + |select 1 a + |""".stripMargin, + "select array(null) from tbl") + } + + test("native reader - support empty ARRAY literal") { + testSingleLineQuery( + """ + |select 1 a + |""".stripMargin, + "select array() from tbl") + } + + test("native reader - support ARRAY literal BYTE fields") { + testSingleLineQuery( + """ + |select 1 a + |""".stripMargin, + "select array(1, 2, null, 3, null) from tbl") + } + + test("native reader - support ARRAY literal SHORT fields") { + testSingleLineQuery( + """ + |select 1 a + |""".stripMargin, + "select array(cast(1 as short), cast(2 as short), null, cast(3 as short), null) from tbl") + } + + test("native reader - support ARRAY literal DATE fields") { + testSingleLineQuery( + """ + |select 1 a + |""".stripMargin, + "select array(CAST('2024-01-01' AS DATE), CAST('2024-02-01' AS DATE), null, CAST('2024-03-01' AS DATE), null) from tbl") + } + + test("native reader - support ARRAY literal LONG fields") { + testSingleLineQuery( + """ + |select 1 a + |""".stripMargin, + "select array(cast(1 as bigint), cast(2 as bigint), null, cast(3 as bigint), null) from tbl") + } + + test("native reader - support ARRAY literal TIMESTAMP fields") { + testSingleLineQuery( + """ + |select 1 a + |""".stripMargin, + "select array(CAST('2024-01-01 10:00:00' AS TIMESTAMP), CAST('2024-01-02 10:00:00' AS TIMESTAMP), null, CAST('2024-01-03 10:00:00' AS TIMESTAMP), null) from tbl") + } + + test("native reader - support ARRAY literal TIMESTAMP TZ fields") { + testSingleLineQuery( + """ + |select 1 a + |""".stripMargin, + "select array(CAST('2024-01-01 10:00:00' AS TIMESTAMP_NTZ), CAST('2024-01-02 10:00:00' AS TIMESTAMP_NTZ), null, CAST('2024-01-03 10:00:00' AS TIMESTAMP_NTZ), null) from tbl") + } + + test("native reader - support ARRAY literal FLOAT fields") { + testSingleLineQuery( + """ + |select 1 a + |""".stripMargin, + "select array(cast(1 as float), cast(2 as float), null, cast(3 as float), null) from tbl") + } + + test("native reader - support ARRAY literal DOUBLE fields") { + testSingleLineQuery( + """ + |select 1 a + |""".stripMargin, + "select array(cast(1 as double), cast(2 as double), null, cast(3 as double), null) from tbl") + } + + test("native reader - support ARRAY literal STRING fields") { + testSingleLineQuery( + """ + |select 1 a + 
|""".stripMargin, + "select array('a', 'bc', null, 'def', null) from tbl") + } + + test("native reader - support ARRAY literal DECIMAL fields") { + testSingleLineQuery( + """ + |select 1 a + |""".stripMargin, + "select array(cast(1 as decimal(10, 2)), cast(2.5 as decimal(10, 2)), null, cast(3.75 as decimal(10, 2)), null) from tbl") + } + + test("native reader - support ARRAY literal BINARY fields") { + testSingleLineQuery( + """ + |select 1 a + |""".stripMargin, + "select array(cast('a' as binary), cast('bc' as binary), null, cast('def' as binary), null) from tbl") + } + + test("SPARK-18053: ARRAY equality is broken") { + withTable("array_tbl") { + spark.range(10).select(array(col("id")).as("arr")).write.saveAsTable("array_tbl") + assert(sql("SELECT * FROM array_tbl where arr = ARRAY(1L)").count == 1) + } + } } diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala index 365003aa8c..cf11bdf590 100644 --- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala @@ -44,7 +44,7 @@ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.internal._ import org.apache.spark.sql.test._ -import org.apache.spark.sql.types.{ArrayType, DataType, DecimalType, MapType, StructType} +import org.apache.spark.sql.types.{DecimalType, StructType} import org.apache.comet._ import org.apache.comet.shims.ShimCometSparkSessionExtensions @@ -1142,9 +1142,4 @@ abstract class CometTestBase usingDataSourceExec(conf) && !CometConf.COMET_SCAN_ALLOW_INCOMPATIBLE.get(conf) } - - def isComplexType(dt: DataType): Boolean = dt match { - case _: StructType | _: ArrayType | _: MapType => true - case _ => false - } }