diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py
index 872770ee22911..4ddf13757db41 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -54,6 +54,7 @@
     DecimalType,
     StringType,
     DataType,
+    TimeType,
     TimestampType,
     TimestampNTZType,
     DayTimeIntervalType,
@@ -248,6 +249,7 @@ def __init__(self, value: Any, dataType: DataType) -> None:
                 DecimalType,
                 StringType,
                 DateType,
+                TimeType,
                 TimestampType,
                 TimestampNTZType,
                 DayTimeIntervalType,
@@ -298,6 +300,9 @@ def __init__(self, value: Any, dataType: DataType) -> None:
                     value = DateType().toInternal(value)
                 else:
                     value = DateType().toInternal(value.date())
+            elif isinstance(dataType, TimeType):
+                assert isinstance(value, datetime.time)
+                value = TimeType().toInternal(value)
             elif isinstance(dataType, TimestampType):
                 assert isinstance(value, datetime.datetime)
                 value = TimestampType().toInternal(value)
@@ -352,6 +357,8 @@ def _infer_type(cls, value: Any) -> DataType:
             return TimestampNTZType() if is_timestamp_ntz_preferred() else TimestampType()
         elif isinstance(value, datetime.date):
             return DateType()
+        elif isinstance(value, datetime.time):
+            return TimeType()
         elif isinstance(value, datetime.timedelta):
             return DayTimeIntervalType()
         elif isinstance(value, np.generic):
@@ -416,6 +423,9 @@ def _to_value(
         elif literal.HasField("date"):
             assert dataType is None or isinstance(dataType, DateType)
             return DateType().fromInternal(literal.date)
+        elif literal.HasField("time"):
+            assert dataType is None or isinstance(dataType, TimeType)
+            return TimeType().fromInternal(literal.time.nano)
         elif literal.HasField("timestamp"):
             assert dataType is None or isinstance(dataType, TimestampType)
             return TimestampType().fromInternal(literal.timestamp)
@@ -468,6 +478,9 @@ def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
             expr.literal.string = str(self._value)
         elif isinstance(self._dataType, DateType):
             expr.literal.date = int(self._value)
+        elif isinstance(self._dataType, TimeType):
+            expr.literal.time.precision = self._dataType.precision
+            expr.literal.time.nano = int(self._value)
         elif isinstance(self._dataType, TimestampType):
             expr.literal.timestamp = int(self._value)
         elif isinstance(self._dataType, TimestampNTZType):
@@ -496,6 +509,10 @@ def __repr__(self) -> str:
             dt = DateType().fromInternal(self._value)
             if dt is not None and isinstance(dt, datetime.date):
                 return dt.strftime("%Y-%m-%d")
+        elif isinstance(self._dataType, TimeType):
+            t = TimeType().fromInternal(self._value)
+            if t is not None and isinstance(t, datetime.time):
+                return t.strftime("%H:%M:%S.%f")
         elif isinstance(self._dataType, TimestampType):
             ts = TimestampType().fromInternal(self._value)
             if ts is not None and isinstance(ts, datetime.datetime):
diff --git a/python/pyspark/sql/connect/types.py b/python/pyspark/sql/connect/types.py
index c2eb5f4e017f0..2fcf1102ee68d 100644
--- a/python/pyspark/sql/connect/types.py
+++ b/python/pyspark/sql/connect/types.py
@@ -29,6 +29,7 @@
     IntegerType,
     FloatType,
     DateType,
+    TimeType,
     TimestampType,
     TimestampNTZType,
     DayTimeIntervalType,
@@ -151,6 +152,8 @@ def pyspark_types_to_proto_types(data_type: DataType) -> pb2.DataType:
         ret.decimal.precision = data_type.precision
     elif isinstance(data_type, DateType):
         ret.date.CopyFrom(pb2.DataType.Date())
+    elif isinstance(data_type, TimeType):
+        ret.time.precision = data_type.precision
     elif isinstance(data_type, TimestampType):
         ret.timestamp.CopyFrom(pb2.DataType.Timestamp())
     elif isinstance(data_type, TimestampNTZType):
@@ -237,6 +240,8 @@ def proto_schema_to_pyspark_data_type(schema: pb2.DataType) -> DataType:
         return VarcharType(schema.var_char.length)
     elif schema.HasField("date"):
         return DateType()
+    elif schema.HasField("time"):
+        return TimeType(schema.time.precision) if schema.time.HasField("precision") else TimeType()
     elif schema.HasField("timestamp"):
         return TimestampType()
     elif schema.HasField("timestamp_ntz"):
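The `__repr__` branch added above renders TIME literals with `strftime("%H:%M:%S.%f")`, so the fractional part is always printed with six digits; this is what the `Column` repr test later in this diff expects. A standard-library-only illustration:

```python
import datetime

# strftime("%f") always zero-pads to six digits, so whole-second times still show
# a fractional part. This matches the Column<'12:34:56.001234'> expectation that
# test_column.py asserts further down in this diff.
assert datetime.time(12, 34, 56, 1234).strftime("%H:%M:%S.%f") == "12:34:56.001234"
assert datetime.time(1, 2, 3).strftime("%H:%M:%S.%f") == "01:02:03.000000"
```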
diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py
index 2f82609c84292..771819d3c9ba6 100644
--- a/python/pyspark/sql/pandas/types.py
+++ b/python/pyspark/sql/pandas/types.py
@@ -37,6 +37,7 @@
     StringType,
     BinaryType,
     DateType,
+    TimeType,
     TimestampType,
     TimestampNTZType,
     DayTimeIntervalType,
@@ -302,6 +303,8 @@ def from_arrow_type(at: "pa.DataType", prefer_timestamp_ntz: bool = False) -> Da
         spark_type = BinaryType()
     elif types.is_date32(at):
         spark_type = DateType()
+    elif types.is_time(at):
+        spark_type = TimeType()
     elif types.is_timestamp(at) and prefer_timestamp_ntz and at.tz is None:
         spark_type = TimestampNTZType()
     elif types.is_timestamp(at):
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 2aa383f39937e..e52d29d76cfa3 100755
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -1508,6 +1508,16 @@ def condition():

         eventually(catch_assertions=True)(condition)()

+    def test_time_lit(self) -> None:
+        # SPARK-52779: Test TimeType lit
+        ndf = self.connect.range(1).select(CF.lit(datetime.time(12, 13, 14)))
+        df = self.spark.sql("SELECT TIME '12:13:14'")
+
+        self.assert_eq(
+            ndf.toPandas(),
+            df.toPandas(),
+        )
+

 if __name__ == "__main__":
     from pyspark.sql.tests.connect.test_connect_basic import *  # noqa: F401
diff --git a/python/pyspark/sql/tests/connect/test_connect_column.py b/python/pyspark/sql/tests/connect/test_connect_column.py
index 4873006fbbb90..8983d45d42d14 100644
--- a/python/pyspark/sql/tests/connect/test_connect_column.py
+++ b/python/pyspark/sql/tests/connect/test_connect_column.py
@@ -25,6 +25,7 @@
     MapType,
     NullType,
     DateType,
+    TimeType,
     TimestampType,
     TimestampNTZType,
     ByteType,
@@ -396,6 +397,7 @@ def test_literal_with_acceptable_type(self):
             ("sss", StringType()),
             (datetime.date(2022, 12, 13), DateType()),
             (datetime.datetime.now(), DateType()),
+            (datetime.time(1, 0, 0), TimeType()),
             (datetime.datetime.now(), TimestampType()),
             (datetime.datetime.now(), TimestampNTZType()),
             (datetime.timedelta(1, 2, 3), DayTimeIntervalType()),
@@ -441,6 +443,7 @@ def test_literal_null(self):
             DoubleType(),
             DecimalType(),
             DateType(),
+            TimeType(),
             TimestampType(),
             TimestampNTZType(),
             DayTimeIntervalType(),
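`from_arrow_type` above maps every Arrow time type to the default `TimeType()`; the Arrow unit is not inspected for precision. A small sketch of what `pyarrow.types.is_time` accepts (assumes a local pyarrow install, independent of Spark):

```python
import pyarrow as pa
from pyarrow import types

# pa.types.is_time matches both 32-bit (s/ms) and 64-bit (us/ns) Arrow time types,
# so all of these would map to the default TimeType() under from_arrow_type above.
for at in [pa.time32("s"), pa.time32("ms"), pa.time64("us"), pa.time64("ns")]:
    assert types.is_time(at)

# Timestamps keep their own branch and do not take this path.
assert not types.is_time(pa.timestamp("us"))
```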
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan.py b/python/pyspark/sql/tests/connect/test_connect_plan.py
index a03cd30c733fb..d25799f0c9f26 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan.py
@@ -893,7 +893,7 @@ def test_float_nan_inf(self):
         self.assertIsNotNone(inf_lit.to_plan(None))

     def test_datetime_literal_types(self):
-        """Test the different timestamp, date, and timedelta types."""
+        """Test the different timestamp, date, time, and timedelta types."""
         datetime_lit = lit(datetime.datetime.now())

         p = datetime_lit.to_plan(None)
@@ -908,6 +908,10 @@ def test_datetime_literal_types(self):
         # (24 * 3600 + 2) * 1000000 + 3
         self.assertEqual(86402000003, time_delta.to_plan(None).literal.day_time_interval)

+        time_lit = lit(datetime.time(23, 59, 59, 999999))
+        self.assertIsNotNone(time_lit.to_plan(None))
+        self.assertEqual(time_lit.to_plan(None).literal.time.nano, 86399999999000)
+
     def test_list_to_literal(self):
         """Test conversion of lists to literals"""
         empty_list = []
@@ -1024,6 +1028,7 @@ def test_literal_to_any_conversion(self):
             decimal.Decimal(1.234567),
             "sss",
             datetime.date(2022, 12, 13),
+            datetime.time(12, 13, 14),
             datetime.datetime.now(),
             datetime.timedelta(1, 2, 3),
             [1, 2, 3, 4, 5, 6],
diff --git a/python/pyspark/sql/tests/test_column.py b/python/pyspark/sql/tests/test_column.py
index 5f1991973d27d..888ab3c531539 100644
--- a/python/pyspark/sql/tests/test_column.py
+++ b/python/pyspark/sql/tests/test_column.py
@@ -361,6 +361,9 @@ def test_lit_time_representation(self):
         ts = datetime.datetime(2021, 3, 4, 12, 34, 56, 1234)
         self.assertEqual(str(sf.lit(ts)), "Column<'2021-03-04 12:34:56.001234'>")

+        ts = datetime.time(12, 34, 56, 1234)
+        self.assertEqual(str(sf.lit(ts)), "Column<'12:34:56.001234'>")
+
     @unittest.skipIf(not have_pandas, pandas_requirement_message)
     def test_lit_delta_representation(self):
         for delta in [
diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py
index e2b3e33756ba3..906fdc673024f 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -1344,6 +1344,11 @@ def test_shiftrightunsigned(self):
             )
         ).collect()

+    def test_lit_time(self):
+        t = datetime.time(12, 34, 56)
+        actual = self.spark.range(1).select(F.lit(t)).first()[0]
+        self.assertEqual(actual, t)
+
     def test_lit_day_time_interval(self):
         td = datetime.timedelta(days=1, hours=12, milliseconds=123)
         actual = self.spark.range(1).select(F.lit(td)).first()[0]
diff --git a/python/pyspark/sql/tests/test_sql.py b/python/pyspark/sql/tests/test_sql.py
index bf50bbc11ac33..e60ad183d1474 100644
--- a/python/pyspark/sql/tests/test_sql.py
+++ b/python/pyspark/sql/tests/test_sql.py
@@ -168,6 +168,12 @@ def test_nested_dataframe(self):
         self.assertEqual(df3.take(1), [Row(id=4)])
         self.assertEqual(df3.tail(1), [Row(id=9)])

+    def test_lit_time(self):
+        import datetime
+
+        actual = self.spark.sql("select TIME '12:34:56'").first()[0]
+        self.assertEqual(actual, datetime.time(12, 34, 56))
+

 class SQLTests(SQLTestsMixin, ReusedSQLTestCase):
     pass
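The new tests above reduce to the following round trip. This is a usage sketch only; it assumes an active SparkSession on a build that already contains this patch, since `lit()` on a `datetime.time` is rejected without it:

```python
import datetime
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()

t = datetime.time(12, 34, 56)

# A Python time literal comes back from the JVM as a datetime.time value.
row = spark.range(1).select(F.lit(t).alias("t")).first()
assert row["t"] == t

# The SQL TIME literal produces the same Python value.
assert spark.sql("SELECT TIME '12:34:56'").first()[0] == t
```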
diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index 185198766b794..72bd6b24c2509 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -42,6 +42,7 @@
     IntegerType,
     FloatType,
     DateType,
+    TimeType,
     TimestampType,
     TimestampNTZType,
     DayTimeIntervalType,
@@ -525,7 +526,7 @@ def test_create_dataframe_from_objects(self):
         self.assertEqual(df.first(), Row(key=1, value="1"))

     def test_apply_schema(self):
-        from datetime import date, datetime, timedelta
+        from datetime import date, time, datetime, timedelta

         rdd = self.sc.parallelize(
             [
@@ -537,6 +538,7 @@ def test_apply_schema(self):
                     2147483647,
                     1.0,
                     date(2010, 1, 1),
+                    time(23, 23, 59, 999999),
                     datetime(2010, 1, 1, 1, 1, 1),
                     timedelta(days=1),
                     {"a": 1},
@@ -555,6 +557,7 @@ def test_apply_schema(self):
                 StructField("int1", IntegerType(), False),
                 StructField("float1", FloatType(), False),
                 StructField("date1", DateType(), False),
+                StructField("time", TimeType(), False),
                 StructField("time1", TimestampType(), False),
                 StructField("daytime1", DayTimeIntervalType(), False),
                 StructField("map1", MapType(StringType(), IntegerType(), False), False),
@@ -573,6 +576,7 @@ def test_apply_schema(self):
                 x.int1,
                 x.float1,
                 x.date1,
+                x.time,
                 x.time1,
                 x.daytime1,
                 x.map1["a"],
@@ -589,6 +593,7 @@ def test_apply_schema(self):
                 2147483647,
                 1.0,
                 date(2010, 1, 1),
+                time(23, 23, 59, 999999),
                 datetime(2010, 1, 1, 1, 1, 1),
                 timedelta(days=1),
                 1,
@@ -1241,6 +1246,7 @@ def test_parse_datatype_json_string(self):
             IntegerType(),
             LongType(),
             DateType(),
+            TimeType(5),
             TimestampType(),
             TimestampNTZType(),
             NullType(),
@@ -1291,6 +1297,8 @@ def test_parse_datatype_string(self):
             _parse_datatype_string("a INT, c DOUBLE"),
         )
         self.assertEqual(VariantType(), _parse_datatype_string("variant"))
+        self.assertEqual(TimeType(5), _parse_datatype_string("time(5)"))
+        self.assertEqual(TimeType(), _parse_datatype_string("time( 6 )"))

     def test_tree_string(self):
         schema1 = DataType.fromDDL("c1 INT, c2 STRUCT>")
@@ -1543,6 +1551,7 @@ def test_tree_string_for_builtin_types(self):
             .add("bin", BinaryType())
             .add("bool", BooleanType())
             .add("date", DateType())
+            .add("time", TimeType())
             .add("ts", TimestampType())
             .add("ts_ntz", TimestampNTZType())
             .add("dec", DecimalType(10, 2))
@@ -1578,6 +1587,7 @@ def test_tree_string_for_builtin_types(self):
             " |-- bin: binary (nullable = true)",
             " |-- bool: boolean (nullable = true)",
             " |-- date: date (nullable = true)",
+            " |-- time: time(6) (nullable = true)",
             " |-- ts: timestamp (nullable = true)",
             " |-- ts_ntz: timestamp_ntz (nullable = true)",
             " |-- dec: decimal(10,2) (nullable = true)",
@@ -1925,6 +1935,7 @@ def test_repr(self):
             BinaryType(),
             BooleanType(),
             DateType(),
+            TimeType(),
             TimestampType(),
             DecimalType(),
             DoubleType(),
@@ -2332,8 +2343,10 @@ def test_to_ddl(self):
         schema = StructType().add("a", ArrayType(DoubleType()), False).add("b", DateType())
         self.assertEqual(schema.toDDL(), "a ARRAY<DOUBLE> NOT NULL,b DATE")

-        schema = StructType().add("a", TimestampType()).add("b", TimestampNTZType())
-        self.assertEqual(schema.toDDL(), "a TIMESTAMP,b TIMESTAMP_NTZ")
+        schema = (
+            StructType().add("a", TimestampType()).add("b", TimestampNTZType()).add("c", TimeType())
+        )
+        self.assertEqual(schema.toDDL(), "a TIMESTAMP,b TIMESTAMP_NTZ,c TIME(6)")

     def test_from_ddl(self):
         self.assertEqual(DataType.fromDDL("long"), LongType())
@@ -2349,6 +2362,10 @@ def test_from_ddl(self):
             DataType.fromDDL("a int, v variant"),
             StructType([StructField("a", IntegerType()), StructField("v", VariantType())]),
         )
+        self.assertEqual(
+            DataType.fromDDL("a time(6)"),
+            StructType([StructField("a", TimeType(6))]),
+        )

     # Ensures that changing the implementation of `DataType.fromDDL` in PR #47253 does not change
     # `fromDDL`'s behavior.
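The DDL-facing behavior exercised by these tests can be summarized in a short sketch, again assuming a PySpark build that includes this change:

```python
from pyspark.sql.types import DataType, StructType, StructField, TimeType

# simpleString and jsonValue always spell out the precision, defaulting to 6.
assert TimeType().simpleString() == "time(6)"
assert TimeType(5).simpleString() == "time(5)"

# DDL parsing accepts an explicit precision and round-trips through toDDL.
assert DataType.fromDDL("a time(6)") == StructType([StructField("a", TimeType(6))])
assert StructType().add("c", TimeType()).toDDL() == "c TIME(6)"
```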
@@ -2602,8 +2619,9 @@ def __init__(self, **kwargs):
             (decimal.Decimal("1.0"), DecimalType()),
             # Binary
             (bytearray([1, 2]), BinaryType()),
-            # Date/Timestamp
+            # Date/Time/Timestamp
             (datetime.date(2000, 1, 2), DateType()),
+            (datetime.time(1, 0, 0), TimeType()),
             (datetime.datetime(2000, 1, 2, 3, 4), DateType()),
             (datetime.datetime(2000, 1, 2, 3, 4), TimestampType()),
             # Array
@@ -2666,8 +2684,9 @@ def __init__(self, **kwargs):
             ("1.0", DecimalType(), TypeError),
             # Binary
             (1, BinaryType(), TypeError),
-            # Date/Timestamp
+            # Date/Time/Timestamp
             ("2000-01-02", DateType(), TypeError),
+            ("23:59:59", TimeType(), TypeError),
             (946811040, TimestampType(), TypeError),
             # Array
             (["1", None], ArrayType(StringType(), containsNull=False), ValueError),
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 744815b751ed2..5dc8e4c78283a 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -80,6 +80,7 @@
     "BinaryType",
     "BooleanType",
     "DateType",
+    "TimeType",
     "TimestampType",
     "TimestampNTZType",
     "DecimalType",
@@ -215,6 +216,7 @@ def _get_jvm_type_name(cls, dataType: "DataType") -> str:
                 VarcharType,
                 DayTimeIntervalType,
                 YearMonthIntervalType,
+                TimeType,
             ),
         ):
             return dataType.simpleString()
@@ -367,14 +369,18 @@ class BooleanType(AtomicType, metaclass=DataTypeSingleton):
     pass


-class DateType(AtomicType, metaclass=DataTypeSingleton):
-    """Date (datetime.date) data type."""
-
-    EPOCH_ORDINAL = datetime.datetime(1970, 1, 1).toordinal()
+class DatetimeType(AtomicType):
+    """Super class of all datetime data type."""

     def needConversion(self) -> bool:
         return True

+
+class DateType(DatetimeType, metaclass=DataTypeSingleton):
+    """Date (datetime.date) data type."""
+
+    EPOCH_ORDINAL = datetime.datetime(1970, 1, 1).toordinal()
+
     def toInternal(self, d: datetime.date) -> int:
         if d is not None:
             return d.toordinal() - self.EPOCH_ORDINAL
@@ -384,11 +390,47 @@ def fromInternal(self, v: int) -> datetime.date:
         return datetime.date.fromordinal(v + self.EPOCH_ORDINAL)


-class TimestampType(AtomicType, metaclass=DataTypeSingleton):
-    """Timestamp (datetime.datetime) data type."""
+class AnyTimeType(DatetimeType):
+    """A TIME type of any valid precision."""

-    def needConversion(self) -> bool:
-        return True
+    pass
+
+
+class TimeType(AnyTimeType):
+    """Time (datetime.time) data type."""
+
+    def __init__(self, precision: int = 6):
+        self.precision = precision
+
+    def toInternal(self, t: datetime.time) -> int:
+        if t is not None:
+            return (
+                t.hour * 3_600_000_000_000
+                + t.minute * 60_000_000_000
+                + t.second * 1_000_000_000
+                + t.microsecond * 1_000
+            )
+
+    def fromInternal(self, nano: int) -> datetime.time:
+        if nano is not None:
+            hours, remainder = divmod(nano, 3_600_000_000_000)
+            minutes, remainder = divmod(remainder, 60_000_000_000)
+            seconds, remainder = divmod(remainder, 1_000_000_000)
+            microseconds = remainder // 1_000
+            return datetime.time(hours, minutes, seconds, microseconds)
+
+    def simpleString(self) -> str:
+        return "time(%d)" % (self.precision)
+
+    def jsonValue(self) -> str:
+        return "time(%d)" % (self.precision)
+
+    def __repr__(self) -> str:
+        return "TimeType(%d)" % (self.precision)
+
+
+class TimestampType(DatetimeType, metaclass=DataTypeSingleton):
+    """Timestamp (datetime.datetime) data type."""

     def toInternal(self, dt: datetime.datetime) -> int:
         if dt is not None:
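A standard-library-only sketch of the conversion pair defined above, mirroring `TimeType.toInternal`/`fromInternal`: the internal value is nanoseconds since midnight, and decoding truncates anything below a microsecond because `datetime.time` cannot represent it:

```python
import datetime

def to_nanos(t: datetime.time) -> int:
    # Mirrors TimeType.toInternal above: nanoseconds since midnight.
    return (
        t.hour * 3_600_000_000_000
        + t.minute * 60_000_000_000
        + t.second * 1_000_000_000
        + t.microsecond * 1_000
    )

def from_nanos(nano: int) -> datetime.time:
    # Mirrors TimeType.fromInternal above; sub-microsecond digits are dropped.
    hours, remainder = divmod(nano, 3_600_000_000_000)
    minutes, remainder = divmod(remainder, 60_000_000_000)
    seconds, remainder = divmod(remainder, 1_000_000_000)
    return datetime.time(hours, minutes, seconds, remainder // 1_000)

t = datetime.time(23, 59, 59, 999999)
assert to_nanos(t) == 86_399_999_999_000      # the value asserted in test_connect_plan.py
assert from_nanos(to_nanos(t)) == t            # round trip at microsecond resolution
assert from_nanos(86_399_999_999_123) == t     # 123 ns of extra precision is truncated
```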
@@ -403,12 +445,9 @@ def fromInternal(self, ts: int) -> datetime.datetime:
         return datetime.datetime.fromtimestamp(ts // 1000000).replace(microsecond=ts % 1000000)


-class TimestampNTZType(AtomicType, metaclass=DataTypeSingleton):
+class TimestampNTZType(DatetimeType, metaclass=DataTypeSingleton):
     """Timestamp (datetime.datetime) data type without timezone information."""

-    def needConversion(self) -> bool:
-        return True
-
     @classmethod
     def typeName(cls) -> str:
         return "timestamp_ntz"
@@ -1846,6 +1885,7 @@ def parseJson(cls, json_str: str) -> "VariantVal":
 _FIXED_DECIMAL = re.compile(r"decimal\(\s*(\d+)\s*,\s*(-?\d+)\s*\)")
 _INTERVAL_DAYTIME = re.compile(r"interval (day|hour|minute|second)( to (day|hour|minute|second))?")
 _INTERVAL_YEARMONTH = re.compile(r"interval (year|month)( to (year|month))?")
+_TIME = re.compile(r"time\(\s*(\d+)\s*\)")

 _COLLATIONS_METADATA_KEY = "__COLLATIONS"
@@ -1987,6 +2027,9 @@ def _parse_datatype_json_value(
     elif _FIXED_DECIMAL.match(json_value):
         m = _FIXED_DECIMAL.match(json_value)
         return DecimalType(int(m.group(1)), int(m.group(2)))  # type: ignore[union-attr]
+    elif _TIME.match(json_value):
+        m = _TIME.match(json_value)
+        return TimeType(int(m.group(1)))  # type: ignore[union-attr]
     elif _INTERVAL_DAYTIME.match(json_value):
         m = _INTERVAL_DAYTIME.match(json_value)
         inverted_fields = DayTimeIntervalType._inverted_fields
@@ -2603,6 +2646,7 @@ def convert_struct(obj: Any) -> Optional[Tuple]:
     VarcharType: (str,),
     BinaryType: (bytearray, bytes),
     DateType: (datetime.date, datetime.datetime),
+    TimeType: (datetime.time,),
     TimestampType: (datetime.datetime,),
     TimestampNTZType: (datetime.datetime,),
     DayTimeIntervalType: (datetime.timedelta,),
@@ -3209,6 +3253,17 @@ def convert(self, obj: datetime.date, gateway_client: "GatewayClient") -> "JavaG
         return Date.valueOf(obj.strftime("%Y-%m-%d"))


+class TimeConverter:
+    def can_convert(self, obj: Any) -> bool:
+        return isinstance(obj, datetime.time)
+
+    def convert(self, obj: datetime.time, gateway_client: "GatewayClient") -> "JavaGateway":
+        from py4j.java_gateway import JavaClass
+
+        LocalTime = JavaClass("java.time.LocalTime", gateway_client)
+        return LocalTime.of(obj.hour, obj.minute, obj.second, obj.microsecond * 1000)
+
+
 class DatetimeConverter:
     def can_convert(self, obj: Any) -> bool:
         return isinstance(obj, datetime.datetime)
@@ -3337,6 +3392,7 @@ def convert(self, obj: "np.ndarray", gateway_client: "GatewayClient") -> "JavaGa
 register_input_converter(DatetimeNTZConverter())
 register_input_converter(DatetimeConverter())
 register_input_converter(DateConverter())
+register_input_converter(TimeConverter())
 register_input_converter(DayTimeIntervalTypeConverter())
 register_input_converter(NumpyScalarConverter())
 # NumPy array satisfies py4j.java_collections.ListConverter,
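The `_TIME` regex added above only has to recognize the `jsonValue` spelling `time(precision)`, with optional whitespace inside the parentheses. A quick self-contained check of the same pattern:

```python
import re

# Same pattern as the _TIME regex added above: a digit group for the precision,
# tolerating spaces inside the parentheses.
_TIME = re.compile(r"time\(\s*(\d+)\s*\)")

assert _TIME.match("time(5)").group(1) == "5"
assert _TIME.match("time( 6 )").group(1) == "6"   # whitespace is tolerated
assert _TIME.match("timestamp") is None           # does not shadow other type names
```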
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
index fd7ccb2189bff..eb6fad8d1a3c6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
@@ -36,7 +36,8 @@ import org.apache.spark.unsafe.types.{UTF8String, VariantVal}
 object EvaluatePython {

   def needConversionInPython(dt: DataType): Boolean = dt match {
-    case DateType | TimestampType | TimestampNTZType | VariantType | _: DayTimeIntervalType => true
+    case DateType | TimestampType | TimestampNTZType | VariantType | _: DayTimeIntervalType
+        | _: TimeType => true
     case _: StructType => true
     case _: UserDefinedType[_] => true
     case ArrayType(elementType, _) => needConversionInPython(elementType)
@@ -138,7 +139,7 @@ object EvaluatePython {
         case c: Int => c
       }

-    case TimestampType | TimestampNTZType | _: DayTimeIntervalType => (obj: Any) =>
+    case TimestampType | TimestampNTZType | _: DayTimeIntervalType | _: TimeType => (obj: Any) =>
       nullSafeConvert(obj) {
         case c: Long => c
         // Py4J serializes values between MIN_INT and MAX_INT as Ints, not Longs
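The Py4J comment above is why TIME joins the Long-based branch rather than the Int-only one used for dates: nanoseconds since midnight overflow a signed 32-bit integer almost immediately, yet Py4J still delivers the smallest values as Ints, so both cases have to be handled. A quick standard-library check of where the cutoff sits:

```python
import datetime

MAX_INT = 2**31 - 1  # values above this arrive from Py4J as java.lang.Long

def nanos_of_day(t: datetime.time) -> int:
    # Same nanoseconds-since-midnight encoding as TimeType.toInternal.
    return (t.hour * 3600 + t.minute * 60 + t.second) * 1_000_000_000 + t.microsecond * 1_000

# Only times in roughly the first 2.147 seconds after midnight fit in 32 bits.
assert nanos_of_day(datetime.time(0, 0, 2, 147483)) <= MAX_INT
assert nanos_of_day(datetime.time(0, 0, 3)) > MAX_INT
```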