80 changes: 80 additions & 0 deletions native/proto/src/proto/datatype.proto
@@ -0,0 +1,80 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.



syntax = "proto3";

package spark.spark_expression;

option java_package = "org.apache.comet.serde";

message DataType {
  enum DataTypeId {
    BOOL = 0;
    INT8 = 1;
    INT16 = 2;
    INT32 = 3;
    INT64 = 4;
    FLOAT = 5;
    DOUBLE = 6;
    STRING = 7;
    BYTES = 8;
    TIMESTAMP = 9;
    DECIMAL = 10;
    TIMESTAMP_NTZ = 11;
    DATE = 12;
    NULL = 13;
    LIST = 14;
    MAP = 15;
    STRUCT = 16;
  }
  DataTypeId type_id = 1;

  message DataTypeInfo {
    oneof datatype_struct {
      DecimalInfo decimal = 2;
      ListInfo list = 3;
      MapInfo map = 4;
      StructInfo struct = 5;
    }
  }

  message DecimalInfo {
    int32 precision = 1;
    int32 scale = 2;
  }

  message ListInfo {
    DataType element_type = 1;
    bool contains_null = 2;
  }

  message MapInfo {
    DataType key_type = 1;
    DataType value_type = 2;
    bool value_contains_null = 3;
  }

  message StructInfo {
    repeated string field_names = 1;
    repeated DataType field_datatypes = 2;
    repeated bool field_nullable = 3;
  }

  DataTypeInfo type_info = 2;
}
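For orientation, the generated Java outer class for this file is Datatype (the Scala imports later in this diff reference org.apache.comet.serde.Datatype). A minimal sketch of building DECIMAL(10, 2) against these messages, assuming the standard protobuf-java builder API:

import org.apache.comet.serde.Datatype.DataType
import org.apache.comet.serde.Datatype.DataType.{DataTypeId, DataTypeInfo, DecimalInfo}

// Sketch only: setter names follow standard protobuf-java codegen conventions.
val decimalType: DataType = DataType
  .newBuilder()
  .setTypeId(DataTypeId.DECIMAL)
  .setTypeInfo(
    DataTypeInfo
      .newBuilder()
      .setDecimal(DecimalInfo.newBuilder().setPrecision(10).setScale(2)))
  .build()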
79 changes: 2 additions & 77 deletions native/proto/src/proto/expr.proto
@@ -21,6 +21,8 @@ syntax = "proto3";

package spark.spark_expression;

import "datatype.proto";
import "literal.proto";
import "types.proto";

option java_package = "org.apache.comet.serde";
@@ -203,27 +205,6 @@ message BloomFilterAgg {
  DataType datatype = 4;
}

message Literal {
  oneof value {
    bool bool_val = 1;
    // Protobuf doesn't provide int8 and int16, we put them into int32 and convert
    // to int8 and int16 when deserializing.
    int32 byte_val = 2;
    int32 short_val = 3;
    int32 int_val = 4;
    int64 long_val = 5;
    float float_val = 6;
    double double_val = 7;
    string string_val = 8;
    bytes bytes_val = 9;
    bytes decimal_val = 10;
    ListLiteral list_val = 11;
  }

  DataType datatype = 12;
  bool is_null = 13;
}

enum EvalMode {
LEGACY = 0;
TRY = 1;
@@ -426,59 +407,3 @@ message ArrayJoin {
message Rand {
  int64 seed = 1;
}

message DataType {
  enum DataTypeId {
    BOOL = 0;
    INT8 = 1;
    INT16 = 2;
    INT32 = 3;
    INT64 = 4;
    FLOAT = 5;
    DOUBLE = 6;
    STRING = 7;
    BYTES = 8;
    TIMESTAMP = 9;
    DECIMAL = 10;
    TIMESTAMP_NTZ = 11;
    DATE = 12;
    NULL = 13;
    LIST = 14;
    MAP = 15;
    STRUCT = 16;
  }
  DataTypeId type_id = 1;

  message DataTypeInfo {
    oneof datatype_struct {
      DecimalInfo decimal = 2;
      ListInfo list = 3;
      MapInfo map = 4;
      StructInfo struct = 5;
    }
  }

  message DecimalInfo {
    int32 precision = 1;
    int32 scale = 2;
  }

  message ListInfo {
    DataType element_type = 1;
    bool contains_null = 2;
  }

  message MapInfo {
    DataType key_type = 1;
    DataType value_type = 2;
    bool value_contains_null = 3;
  }

  message StructInfo {
    repeated string field_names = 1;
    repeated DataType field_datatypes = 2;
    repeated bool field_nullable = 3;
  }

  DataTypeInfo type_info = 2;
}
46 changes: 46 additions & 0 deletions native/proto/src/proto/literal.proto
@@ -0,0 +1,46 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.



syntax = "proto3";

package spark.spark_expression;

import "datatype.proto";

option java_package = "org.apache.comet.serde";

message Literal {
  oneof value {
    bool bool_val = 1;
    // Protobuf doesn't provide int8 and int16, so we store them in int32 and
    // convert back to int8 and int16 when deserializing.
    int32 byte_val = 2;
    int32 short_val = 3;
    int32 int_val = 4;
    int64 long_val = 5;
    float float_val = 6;
    double double_val = 7;
    string string_val = 8;
    bytes bytes_val = 9;
    bytes decimal_val = 10;
  }

  DataType datatype = 11;
  bool is_null = 12;
}
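A minimal sketch of building a non-null INT64 literal against the relocated message, mirroring the LiteralOuterClass.Literal usage in hash.scala further down (serializeDataType is the existing helper from QueryPlanSerde; the 42L value is illustrative):

import org.apache.spark.sql.types.LongType

import org.apache.comet.serde.LiteralOuterClass
import org.apache.comet.serde.QueryPlanSerde.serializeDataType

// The datatype field now carries the DataType message defined in datatype.proto.
val longLiteral = LiteralOuterClass.Literal
  .newBuilder()
  .setDatatype(serializeDataType(LongType).get)
  .setLongVal(42L)
  .build()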
1 change: 1 addition & 0 deletions native/proto/src/proto/operator.proto
@@ -21,6 +21,7 @@ syntax = "proto3";

package spark.spark_operator;

import "datatype.proto";
import "expr.proto";
import "partitioning.proto";

@@ -29,13 +29,14 @@ import org.apache.spark.sql.types._

import org.apache.comet.serde.ExprOuterClass
import org.apache.comet.serde.ExprOuterClass.Expr
import org.apache.comet.serde.LiteralOuterClass
import org.apache.comet.serde.QueryPlanSerde.serializeDataType

object SourceFilterSerde extends Logging {

  def createNameExpr(
      name: String,
      schema: StructType): Option[(DataType, ExprOuterClass.Expr)] = {
      schema: StructType): Option[(org.apache.spark.sql.types.DataType, ExprOuterClass.Expr)] = {
    val filedWithIndex = schema.fields.zipWithIndex.find { case (field, _) =>
      field.name == name
    }
@@ -66,8 +67,10 @@ object SourceFilterSerde extends Logging {
  /**
   * Create a native literal expression for a source filter value; the value is a Scala value.
   */
  def createValueExpr(value: Any, dataType: DataType): Option[ExprOuterClass.Expr] = {
    val exprBuilder = ExprOuterClass.Literal.newBuilder()
  def createValueExpr(
      value: Any,
      dataType: org.apache.spark.sql.types.DataType): Option[ExprOuterClass.Expr] = {
    val exprBuilder = LiteralOuterClass.Literal.newBuilder()
    var valueIsSet = true
    if (value == null) {
      exprBuilder.setIsNull(true)
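Hypothetical usage of the widened createValueExpr signature above (the value and type are illustrative, and SourceFilterSerde is assumed to be in scope since its file path is truncated in this view):

import org.apache.spark.sql.types.LongType

import org.apache.comet.serde.ExprOuterClass

// Returns Some(literal expr) when the value/type pair is supported, None otherwise.
val valueExpr: Option[ExprOuterClass.Expr] =
  SourceFilterSerde.createValueExpr(42L, LongType)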
@@ -52,8 +52,9 @@ import org.apache.comet.CometSparkSessionExtensions.{isCometScan, withInfo}
import org.apache.comet.DataTypeSupport.isComplexType
import org.apache.comet.expressions._
import org.apache.comet.objectstore.NativeConfig
import org.apache.comet.serde.ExprOuterClass.{AggExpr, DataType => ProtoDataType, Expr, ScalarFunc}
import org.apache.comet.serde.ExprOuterClass.DataType._
import org.apache.comet.serde.Datatype.{DataType => ProtoDataType}
import org.apache.comet.serde.Datatype.DataType._
import org.apache.comet.serde.ExprOuterClass.{AggExpr, Expr, ScalarFunc}
import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, BuildSide, JoinType, Operator}
import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto}
import org.apache.comet.serde.Types.ListLiteral
@@ -213,7 +214,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
   * doesn't mean it is supported by Comet native execution, i.e., `supportedDataType` may return
   * false for it.
   */
  def serializeDataType(dt: DataType): Option[ExprOuterClass.DataType] = {
  def serializeDataType(dt: org.apache.spark.sql.types.DataType): Option[Datatype.DataType] = {
    val typeId = dt match {
      case _: BooleanType => 0
      case _: ByteType => 1
@@ -728,7 +729,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
.contains(CometConf.COMET_NATIVE_SCAN_IMPL.get()) && dataType
.isInstanceOf[ArrayType]) && !isComplexType(
dataType.asInstanceOf[ArrayType].elementType)) =>
val exprBuilder = ExprOuterClass.Literal.newBuilder()
val exprBuilder = LiteralOuterClass.Literal.newBuilder()

if (value == null) {
exprBuilder.setIsNull(true)
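A quick sketch of the renamed return type of serializeDataType, assuming LongType still maps to INT64 (type_id 4) as in the typeId match above:

import org.apache.spark.sql.types.LongType

import org.apache.comet.serde.Datatype
import org.apache.comet.serde.QueryPlanSerde.serializeDataType

// serializeDataType now yields the proto DataType defined in datatype.proto.
val protoType: Option[Datatype.DataType] = serializeDataType(LongType)
assert(protoType.exists(_.getTypeId == Datatype.DataType.DataTypeId.INT64))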
4 changes: 2 additions & 2 deletions spark/src/main/scala/org/apache/comet/serde/hash.scala
@@ -34,7 +34,7 @@ object CometXxHash64 extends CometExpressionSerde[XxHash64] {
      return None
    }
    val exprs = expr.children.map(exprToProtoInternal(_, inputs, binding))
    val seedBuilder = ExprOuterClass.Literal
    val seedBuilder = LiteralOuterClass.Literal
      .newBuilder()
      .setDatatype(serializeDataType(LongType).get)
      .setLongVal(expr.seed)
@@ -53,7 +53,7 @@ object CometMurmur3Hash extends CometExpressionSerde[Murmur3Hash] {
      return None
    }
    val exprs = expr.children.map(exprToProtoInternal(_, inputs, binding))
    val seedBuilder = ExprOuterClass.Literal
    val seedBuilder = LiteralOuterClass.Literal
      .newBuilder()
      .setDatatype(serializeDataType(IntegerType).get)
      .setIntVal(expr.seed)