|
29 | 29 | Optional,
|
30 | 30 | )
|
31 | 31 |
|
32 |
| -import json |
33 |
| -import decimal |
34 | 32 | import datetime
|
| 33 | +import decimal |
| 34 | +import json |
35 | 35 | import warnings
|
36 | 36 | from threading import Lock
|
37 | 37 |
|
@@ -377,6 +377,52 @@ def _infer_type(cls, value: Any) -> DataType:
|
377 | 377 | def _from_value(cls, value: Any) -> "LiteralExpression":
|
378 | 378 | return LiteralExpression(value=value, dataType=LiteralExpression._infer_type(value))
|
379 | 379 |
|
| 380 | + @classmethod |
| 381 | + def _infer_type_from_literal(cls, literal: "proto.Expression.Literal") -> Optional[DataType]: |
| 382 | + if literal.HasField("null"): |
| 383 | + return NullType() |
| 384 | + elif literal.HasField("binary"): |
| 385 | + return BinaryType() |
| 386 | + elif literal.HasField("boolean"): |
| 387 | + return BooleanType() |
| 388 | + elif literal.HasField("byte"): |
| 389 | + return ByteType() |
| 390 | + elif literal.HasField("short"): |
| 391 | + return ShortType() |
| 392 | + elif literal.HasField("integer"): |
| 393 | + return IntegerType() |
| 394 | + elif literal.HasField("long"): |
| 395 | + return LongType() |
| 396 | + elif literal.HasField("float"): |
| 397 | + return FloatType() |
| 398 | + elif literal.HasField("double"): |
| 399 | + return DoubleType() |
| 400 | + elif literal.HasField("date"): |
| 401 | + return DateType() |
| 402 | + elif literal.HasField("timestamp"): |
| 403 | + return TimestampType() |
| 404 | + elif literal.HasField("timestamp_ntz"): |
| 405 | + return TimestampNTZType() |
| 406 | + elif literal.HasField("array"): |
| 407 | + if literal.array.HasField("element_type"): |
| 408 | + return ArrayType( |
| 409 | + proto_schema_to_pyspark_data_type(literal.array.element_type), True |
| 410 | + ) |
| 411 | + element_type = None |
| 412 | + if len(literal.array.elements) > 0: |
| 413 | + element_type = LiteralExpression._infer_type_from_literal(literal.array.elements[0]) |
| 414 | + |
| 415 | + if element_type is None: |
| 416 | + raise PySparkTypeError( |
| 417 | + errorClass="CANNOT_INFER_ARRAY_ELEMENT_TYPE", |
| 418 | + messageParameters={}, |
| 419 | + ) |
| 420 | + return ArrayType(element_type, True) |
| 421 | + # Not all data types support inferring the data type from the literal at the moment. |
| 422 | + # e.g. the type of DayTimeInterval contains extra information like start_field and |
| 423 | + # end_field and cannot be inferred from the literal. |
| 424 | + return None |
| 425 | + |
380 | 426 | @classmethod
|
381 | 427 | def _to_value(
|
382 | 428 | cls, literal: "proto.Expression.Literal", dataType: Optional[DataType] = None
|
@@ -426,10 +472,20 @@ def _to_value(
|
426 | 472 | assert dataType is None or isinstance(dataType, DayTimeIntervalType)
|
427 | 473 | return DayTimeIntervalType().fromInternal(literal.day_time_interval)
|
428 | 474 | elif literal.HasField("array"):
|
429 |
| - elementType = proto_schema_to_pyspark_data_type(literal.array.element_type) |
430 |
| - if dataType is not None: |
431 |
| - assert isinstance(dataType, ArrayType) |
432 |
| - assert elementType == dataType.elementType |
| 475 | + elementType = None |
| 476 | + if literal.array.HasField("element_type"): |
| 477 | + elementType = proto_schema_to_pyspark_data_type(literal.array.element_type) |
| 478 | + if dataType is not None: |
| 479 | + assert isinstance(dataType, ArrayType) |
| 480 | + assert elementType == dataType.elementType |
| 481 | + elif len(literal.array.elements) > 0: |
| 482 | + elementType = LiteralExpression._infer_type_from_literal(literal.array.elements[0]) |
| 483 | + |
| 484 | + if elementType is None: |
| 485 | + raise PySparkTypeError( |
| 486 | + errorClass="CANNOT_INFER_ARRAY_ELEMENT_TYPE", |
| 487 | + messageParameters={}, |
| 488 | + ) |
433 | 489 | return [LiteralExpression._to_value(v, elementType) for v in literal.array.elements]
|
434 | 490 |
|
435 | 491 | raise PySparkTypeError(
|
@@ -475,11 +531,17 @@ def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
|
475 | 531 | elif isinstance(self._dataType, DayTimeIntervalType):
|
476 | 532 | expr.literal.day_time_interval = int(self._value)
|
477 | 533 | elif isinstance(self._dataType, ArrayType):
|
478 |
| - element_type = self._dataType.elementType |
479 |
| - expr.literal.array.element_type.CopyFrom(pyspark_types_to_proto_types(element_type)) |
480 | 534 | for v in self._value:
|
481 | 535 | expr.literal.array.elements.append(
|
482 |
| - LiteralExpression(v, element_type).to_plan(session).literal |
| 536 | + LiteralExpression(v, self._dataType.elementType).to_plan(session).literal |
| 537 | + ) |
| 538 | + if ( |
| 539 | + len(self._value) == 0 |
| 540 | + or LiteralExpression._infer_type_from_literal(expr.literal.array.elements[0]) |
| 541 | + is None |
| 542 | + ): |
| 543 | + expr.literal.array.element_type.CopyFrom( |
| 544 | + pyspark_types_to_proto_types(self._dataType.elementType) |
483 | 545 | )
|
484 | 546 | else:
|
485 | 547 | raise PySparkTypeError(
|
|
0 commit comments