Skip to content

[SPARK-52837][CONNECT][PYTHON] Support TimeType literal in Connect #51515

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions python/pyspark/sql/connect/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
DecimalType,
StringType,
DataType,
TimeType,
TimestampType,
TimestampNTZType,
DayTimeIntervalType,
Expand Down Expand Up @@ -248,6 +249,7 @@ def __init__(self, value: Any, dataType: DataType) -> None:
DecimalType,
StringType,
DateType,
TimeType,
TimestampType,
TimestampNTZType,
DayTimeIntervalType,
Expand Down Expand Up @@ -298,6 +300,9 @@ def __init__(self, value: Any, dataType: DataType) -> None:
value = DateType().toInternal(value)
else:
value = DateType().toInternal(value.date())
elif isinstance(dataType, TimeType):
assert isinstance(value, datetime.time)
value = TimeType().toInternal(value)
elif isinstance(dataType, TimestampType):
assert isinstance(value, datetime.datetime)
value = TimestampType().toInternal(value)
Expand Down Expand Up @@ -352,6 +357,8 @@ def _infer_type(cls, value: Any) -> DataType:
return TimestampNTZType() if is_timestamp_ntz_preferred() else TimestampType()
elif isinstance(value, datetime.date):
return DateType()
elif isinstance(value, datetime.time):
return TimeType()
elif isinstance(value, datetime.timedelta):
return DayTimeIntervalType()
elif isinstance(value, np.generic):
Expand Down Expand Up @@ -416,6 +423,9 @@ def _to_value(
elif literal.HasField("date"):
assert dataType is None or isinstance(dataType, DataType)
return DateType().fromInternal(literal.date)
elif literal.HasField("time"):
assert dataType is None or isinstance(dataType, TimeType)
return TimeType().fromInternal(literal.time.nano)
elif literal.HasField("timestamp"):
assert dataType is None or isinstance(dataType, TimestampType)
return TimestampType().fromInternal(literal.timestamp)
Expand Down Expand Up @@ -468,6 +478,9 @@ def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
expr.literal.string = str(self._value)
elif isinstance(self._dataType, DateType):
expr.literal.date = int(self._value)
elif isinstance(self._dataType, TimeType):
expr.literal.time.precision = self._dataType.precision
expr.literal.time.nano = int(self._value)
elif isinstance(self._dataType, TimestampType):
expr.literal.timestamp = int(self._value)
elif isinstance(self._dataType, TimestampNTZType):
Expand Down Expand Up @@ -496,6 +509,10 @@ def __repr__(self) -> str:
dt = DateType().fromInternal(self._value)
if dt is not None and isinstance(dt, datetime.date):
return dt.strftime("%Y-%m-%d")
elif isinstance(self._dataType, TimeType):
t = TimeType().fromInternal(self._value)
if t is not None and isinstance(t, datetime.time):
return t.strftime("%H:%M:%S.%f")
elif isinstance(self._dataType, TimestampType):
ts = TimestampType().fromInternal(self._value)
if ts is not None and isinstance(ts, datetime.datetime):
Expand Down
5 changes: 5 additions & 0 deletions python/pyspark/sql/connect/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
IntegerType,
FloatType,
DateType,
TimeType,
TimestampType,
TimestampNTZType,
DayTimeIntervalType,
Expand Down Expand Up @@ -151,6 +152,8 @@ def pyspark_types_to_proto_types(data_type: DataType) -> pb2.DataType:
ret.decimal.precision = data_type.precision
elif isinstance(data_type, DateType):
ret.date.CopyFrom(pb2.DataType.Date())
elif isinstance(data_type, TimeType):
ret.time.precision = data_type.precision
elif isinstance(data_type, TimestampType):
ret.timestamp.CopyFrom(pb2.DataType.Timestamp())
elif isinstance(data_type, TimestampNTZType):
Expand Down Expand Up @@ -237,6 +240,8 @@ def proto_schema_to_pyspark_data_type(schema: pb2.DataType) -> DataType:
return VarcharType(schema.var_char.length)
elif schema.HasField("date"):
return DateType()
elif schema.HasField("time"):
return TimeType()
elif schema.HasField("timestamp"):
return TimestampType()
elif schema.HasField("timestamp_ntz"):
Expand Down
3 changes: 3 additions & 0 deletions python/pyspark/sql/pandas/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
StringType,
BinaryType,
DateType,
TimeType,
TimestampType,
TimestampNTZType,
DayTimeIntervalType,
Expand Down Expand Up @@ -302,6 +303,8 @@ def from_arrow_type(at: "pa.DataType", prefer_timestamp_ntz: bool = False) -> Da
spark_type = BinaryType()
elif types.is_date32(at):
spark_type = DateType()
elif types.is_time(at):
spark_type = TimeType()
elif types.is_timestamp(at) and prefer_timestamp_ntz and at.tz is None:
spark_type = TimestampNTZType()
elif types.is_timestamp(at):
Expand Down
10 changes: 10 additions & 0 deletions python/pyspark/sql/tests/connect/test_connect_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1508,6 +1508,16 @@ def condition():

eventually(catch_assertions=True)(condition)()

def test_time_lit(self) -> None:
    # SPARK-52837: a datetime.time literal created via Connect's lit()
    # must round-trip identically to a server-side SQL TIME literal.
    # (Previous comment cited SPARK-52779 — corrected to this PR's ticket;
    # confirm against JIRA if 52779 was the intended umbrella issue.)
    ndf = self.connect.range(1).select(CF.lit(datetime.time(12, 13, 14)))
    df = self.spark.sql("SELECT TIME '12:13:14'")

    self.assert_eq(
        ndf.toPandas(),
        df.toPandas(),
    )


if __name__ == "__main__":
from pyspark.sql.tests.connect.test_connect_basic import * # noqa: F401
Expand Down
3 changes: 3 additions & 0 deletions python/pyspark/sql/tests/connect/test_connect_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
MapType,
NullType,
DateType,
TimeType,
TimestampType,
TimestampNTZType,
ByteType,
Expand Down Expand Up @@ -396,6 +397,7 @@ def test_literal_with_acceptable_type(self):
("sss", StringType()),
(datetime.date(2022, 12, 13), DateType()),
(datetime.datetime.now(), DateType()),
(datetime.time(1, 0, 0), TimeType()),
(datetime.datetime.now(), TimestampType()),
(datetime.datetime.now(), TimestampNTZType()),
(datetime.timedelta(1, 2, 3), DayTimeIntervalType()),
Expand Down Expand Up @@ -441,6 +443,7 @@ def test_literal_null(self):
DoubleType(),
DecimalType(),
DateType(),
TimeType(),
TimestampType(),
TimestampNTZType(),
DayTimeIntervalType(),
Expand Down
7 changes: 6 additions & 1 deletion python/pyspark/sql/tests/connect/test_connect_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,7 +893,7 @@ def test_float_nan_inf(self):
self.assertIsNotNone(inf_lit.to_plan(None))

def test_datetime_literal_types(self):
"""Test the different timestamp, date, and timedelta types."""
"""Test the different timestamp, date, time, and timedelta types."""
datetime_lit = lit(datetime.datetime.now())

p = datetime_lit.to_plan(None)
Expand All @@ -908,6 +908,10 @@ def test_datetime_literal_types(self):
# (24 * 3600 + 2) * 1000000 + 3
self.assertEqual(86402000003, time_delta.to_plan(None).literal.day_time_interval)

time_lit = lit(datetime.time(23, 59, 59, 999999))
self.assertIsNotNone(time_lit.to_plan(None))
self.assertEqual(time_lit.to_plan(None).literal.time.nano, 86399999999000)

def test_list_to_literal(self):
"""Test conversion of lists to literals"""
empty_list = []
Expand Down Expand Up @@ -1024,6 +1028,7 @@ def test_literal_to_any_conversion(self):
decimal.Decimal(1.234567),
"sss",
datetime.date(2022, 12, 13),
datetime.time(12, 13, 14),
datetime.datetime.now(),
datetime.timedelta(1, 2, 3),
[1, 2, 3, 4, 5, 6],
Expand Down
5 changes: 5 additions & 0 deletions python/pyspark/sql/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1344,6 +1344,11 @@ def test_shiftrightunsigned(self):
)
).collect()

def test_lit_time(self):
    # Round-trip a datetime.time through F.lit and collect it back unchanged.
    expected = datetime.time(12, 34, 56)
    row = self.spark.range(1).select(F.lit(expected)).first()
    self.assertEqual(row[0], expected)

def test_lit_day_time_interval(self):
td = datetime.timedelta(days=1, hours=12, milliseconds=123)
actual = self.spark.range(1).select(F.lit(td)).first()[0]
Expand Down
4 changes: 4 additions & 0 deletions python/pyspark/sql/tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,10 @@ def test_nested_dataframe(self):
self.assertEqual(df3.take(1), [Row(id=4)])
self.assertEqual(df3.tail(1), [Row(id=9)])

def test_lit_time(self):
    # A SQL TIME literal should be returned to Python as a datetime.time.
    from datetime import time

    result = self.spark.sql("select TIME '12:34:56'").first()[0]
    self.assertEqual(result, time(12, 34, 56))

class SQLTests(SQLTestsMixin, ReusedSQLTestCase):
pass
Expand Down
26 changes: 21 additions & 5 deletions python/pyspark/sql/tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
IntegerType,
FloatType,
DateType,
TimeType,
TimestampType,
TimestampNTZType,
DayTimeIntervalType,
Expand Down Expand Up @@ -525,7 +526,7 @@ def test_create_dataframe_from_objects(self):
self.assertEqual(df.first(), Row(key=1, value="1"))

def test_apply_schema(self):
from datetime import date, datetime, timedelta
from datetime import date, time, datetime, timedelta

rdd = self.sc.parallelize(
[
Expand All @@ -537,6 +538,7 @@ def test_apply_schema(self):
2147483647,
1.0,
date(2010, 1, 1),
time(23, 23, 59, 999999),
datetime(2010, 1, 1, 1, 1, 1),
timedelta(days=1),
{"a": 1},
Expand All @@ -555,6 +557,7 @@ def test_apply_schema(self):
StructField("int1", IntegerType(), False),
StructField("float1", FloatType(), False),
StructField("date1", DateType(), False),
StructField("time", TimeType(), False),
StructField("time1", TimestampType(), False),
StructField("daytime1", DayTimeIntervalType(), False),
StructField("map1", MapType(StringType(), IntegerType(), False), False),
Expand All @@ -573,6 +576,7 @@ def test_apply_schema(self):
x.int1,
x.float1,
x.date1,
x.time,
x.time1,
x.daytime1,
x.map1["a"],
Expand All @@ -589,6 +593,7 @@ def test_apply_schema(self):
2147483647,
1.0,
date(2010, 1, 1),
time(23, 23, 59, 999999),
datetime(2010, 1, 1, 1, 1, 1),
timedelta(days=1),
1,
Expand Down Expand Up @@ -1241,6 +1246,7 @@ def test_parse_datatype_json_string(self):
IntegerType(),
LongType(),
DateType(),
TimeType(),
TimestampType(),
TimestampNTZType(),
NullType(),
Expand Down Expand Up @@ -1291,6 +1297,7 @@ def test_parse_datatype_string(self):
_parse_datatype_string("a INT, c DOUBLE"),
)
self.assertEqual(VariantType(), _parse_datatype_string("variant"))
self.assertEqual(TimeType(5), _parse_datatype_string("time(5)"))

def test_tree_string(self):
schema1 = DataType.fromDDL("c1 INT, c2 STRUCT<c3: INT, c4: STRUCT<c5: INT, c6: INT>>")
Expand Down Expand Up @@ -1543,6 +1550,7 @@ def test_tree_string_for_builtin_types(self):
.add("bin", BinaryType())
.add("bool", BooleanType())
.add("date", DateType())
.add("time", TimeType())
.add("ts", TimestampType())
.add("ts_ntz", TimestampNTZType())
.add("dec", DecimalType(10, 2))
Expand Down Expand Up @@ -1578,6 +1586,7 @@ def test_tree_string_for_builtin_types(self):
" |-- bin: binary (nullable = true)",
" |-- bool: boolean (nullable = true)",
" |-- date: date (nullable = true)",
" |-- time: time(6) (nullable = true)",
" |-- ts: timestamp (nullable = true)",
" |-- ts_ntz: timestamp_ntz (nullable = true)",
" |-- dec: decimal(10,2) (nullable = true)",
Expand Down Expand Up @@ -1925,6 +1934,7 @@ def test_repr(self):
BinaryType(),
BooleanType(),
DateType(),
TimeType(),
TimestampType(),
DecimalType(),
DoubleType(),
Expand Down Expand Up @@ -2332,8 +2342,8 @@ def test_to_ddl(self):
schema = StructType().add("a", ArrayType(DoubleType()), False).add("b", DateType())
self.assertEqual(schema.toDDL(), "a ARRAY<DOUBLE> NOT NULL,b DATE")

schema = StructType().add("a", TimestampType()).add("b", TimestampNTZType())
self.assertEqual(schema.toDDL(), "a TIMESTAMP,b TIMESTAMP_NTZ")
schema = StructType().add("a", TimestampType()).add("b", TimestampNTZType()).add("c", TimeType())
self.assertEqual(schema.toDDL(), "a TIMESTAMP,b TIMESTAMP_NTZ,c TIME(6)")

def test_from_ddl(self):
self.assertEqual(DataType.fromDDL("long"), LongType())
Expand All @@ -2349,6 +2359,10 @@ def test_from_ddl(self):
DataType.fromDDL("a int, v variant"),
StructType([StructField("a", IntegerType()), StructField("v", VariantType())]),
)
self.assertEqual(
DataType.fromDDL("a time(6)"),
StructType([StructField("a", TimeType(6))]),
)

# Ensures that changing the implementation of `DataType.fromDDL` in PR #47253 does not change
# `fromDDL`'s behavior.
Expand Down Expand Up @@ -2602,8 +2616,9 @@ def __init__(self, **kwargs):
(decimal.Decimal("1.0"), DecimalType()),
# Binary
(bytearray([1, 2]), BinaryType()),
# Date/Timestamp
# Date/Time/Timestamp
(datetime.date(2000, 1, 2), DateType()),
(datetime.time(1, 0, 0), TimeType()),
(datetime.datetime(2000, 1, 2, 3, 4), DateType()),
(datetime.datetime(2000, 1, 2, 3, 4), TimestampType()),
# Array
Expand Down Expand Up @@ -2666,8 +2681,9 @@ def __init__(self, **kwargs):
("1.0", DecimalType(), TypeError),
# Binary
(1, BinaryType(), TypeError),
# Date/Timestamp
# Date/Time/Timestamp
("2000-01-02", DateType(), TypeError),
("23:59:59", TimeType(), TypeError),
(946811040, TimestampType(), TypeError),
# Array
(["1", None], ArrayType(StringType(), containsNull=False), ValueError),
Expand Down
Loading