diff --git a/piccolo/apps/migrations/auto/serialisation.py b/piccolo/apps/migrations/auto/serialisation.py index b3644b853..872864080 100644 --- a/piccolo/apps/migrations/auto/serialisation.py +++ b/piccolo/apps/migrations/auto/serialisation.py @@ -12,12 +12,22 @@ from dataclasses import dataclass, field from enum import Enum -from piccolo.columns import Column +from piccolo.columns import Column, Timestamptz from piccolo.columns.defaults.base import Default +from piccolo.columns.defaults.timestamptz import ( + TimestamptzCustom, + TimestamptzNow, + TimestamptzOffset, +) from piccolo.columns.reference import LazyTableReference from piccolo.table import Table from piccolo.utils.repr import repr_class_instance +try: + from zoneinfo import ZoneInfo # type: ignore +except ImportError: # pragma: no cover + from backports.zoneinfo import ZoneInfo # type: ignore # noqa: F401 + from .serialisation_legacy import deserialise_legacy_params ############################################################################### @@ -546,6 +556,30 @@ def serialise_params(params: t.Dict[str, t.Any]) -> SerialisedParams: expect_conflict_with_global_name=UniqueGlobalNames.DEFAULT, ) ) + # ZoneInfo for Timestamptz* instances + in_group = ( + Timestamptz, + TimestamptzNow, + TimestamptzCustom, + TimestamptzOffset, + ) + if isinstance(value, in_group): + extra_imports.append( + Import( + module=ZoneInfo.__module__, + target=None, + ) + ) + continue + + # ZoneInfo instances + if isinstance(value, ZoneInfo): + extra_imports.append( + Import( + module=value.__class__.__module__, + target=None, + ) + ) continue # Dates and times @@ -633,6 +667,7 @@ def serialise_params(params: t.Dict[str, t.Any]) -> SerialisedParams: extra_imports.append( Import(module=module_name, target=type_.__name__) ) + continue # Functions if inspect.isfunction(value): diff --git a/piccolo/columns/column_types.py b/piccolo/columns/column_types.py index 2afcfb741..770b88c5a 100644 --- a/piccolo/columns/column_types.py +++ b/piccolo/columns/column_types.py @@ -64,6 +64,11 @@ class Band(Table): from piccolo.utils.encoding import dump_json from piccolo.utils.warnings import colored_warning +try: + from zoneinfo import ZoneInfo # type: ignore +except ImportError: # pragma: no cover + from backports.zoneinfo import ZoneInfo # type: ignore # noqa: F401 + if t.TYPE_CHECKING: # pragma: no cover from piccolo.columns.base import ColumnMeta from piccolo.table import Table @@ -955,36 +960,40 @@ def __set__(self, obj, value: t.Union[datetime, None]): class Timestamptz(Column): """ Used for storing timezone aware datetimes. Uses the ``datetime`` type for - values. The values are converted to UTC in the database, and are also - returned as UTC. + values. The values are converted to UTC when saved into the database and + are converted back into the timezone of the column on select queries. **Example** .. code-block:: python import datetime + from zoneinfo import ZoneInfo - class Concert(Table): - starts = Timestamptz() + class TallinnConcerts(Table): + event_start = Timestamptz(tz=ZoneInfo("Europe/Tallinn")) # Create - >>> await Concert( - ... starts=datetime.datetime( - ... year=2050, month=1, day=1, tzinfo=datetime.timezone.tz + >>> await TallinnConcerts( + ... event_start=datetime.datetime( + ... year=2050, month=1, day=1, hour=20 ... ) ... ).save() # Query - >>> await Concert.select(Concert.starts) + >>> await TallinnConcerts.select(TallinnConcerts.event_start) { - 'starts': datetime.datetime( - 2050, 1, 1, 0, 0, tzinfo=datetime.timezone.utc + 'event_start': datetime.datetime( + 2050, 1, 1, 20, 0, tzinfo=zoneinfo.ZoneInfo( + key='Europe/Tallinn' + ) ) } """ value_type = datetime + tz_type = ZoneInfo # Currently just used by ModelBuilder, to know that we want a timezone # aware datetime. @@ -993,20 +1002,24 @@ class Concert(Table): timedelta_delegate = TimedeltaDelegate() def __init__( - self, default: TimestamptzArg = TimestamptzNow(), **kwargs + self, + tz: ZoneInfo = ZoneInfo("UTC"), + default: TimestamptzArg = TimestamptzNow(), + **kwargs, ) -> None: self._validate_default( default, TimestamptzArg.__args__ # type: ignore ) if isinstance(default, datetime): - default = TimestamptzCustom.from_datetime(default) + default = TimestamptzCustom.from_datetime(default, tz) if default == datetime.now: - default = TimestamptzNow() + default = TimestamptzNow(tz) + self.tz = tz self.default = default - kwargs.update({"default": default}) + kwargs.update({"tz": tz, "default": default}) super().__init__(**kwargs) ########################################################################### diff --git a/piccolo/columns/defaults/timestamptz.py b/piccolo/columns/defaults/timestamptz.py index 5db6ebd54..90ba0fa9f 100644 --- a/piccolo/columns/defaults/timestamptz.py +++ b/piccolo/columns/defaults/timestamptz.py @@ -1,13 +1,31 @@ from __future__ import annotations -import datetime +import datetime as pydatetime import typing as t from enum import Enum +try: + from zoneinfo import ZoneInfo # type: ignore +except ImportError: # pragma: no cover + from backports.zoneinfo import ZoneInfo # type: ignore # noqa: F401 + from .timestamp import TimestampCustom, TimestampNow, TimestampOffset class TimestamptzOffset(TimestampOffset): + def __init__( + self, + days: int = 0, + hours: int = 0, + minutes: int = 0, + seconds: int = 0, + tz: ZoneInfo = ZoneInfo("UTC"), + ): + self.tz = tz + super().__init__( + days=days, hours=hours, minutes=minutes, seconds=seconds + ) + @property def cockroach(self): interval_string = self.get_postgres_interval_string( @@ -16,9 +34,7 @@ def cockroach(self): return f"CURRENT_TIMESTAMP + INTERVAL '{interval_string}'" def python(self): - return datetime.datetime.now( - tz=datetime.timezone.utc - ) + datetime.timedelta( + return pydatetime.datetime.now(tz=self.tz) + pydatetime.timedelta( days=self.days, hours=self.hours, minutes=self.minutes, @@ -27,35 +43,60 @@ def python(self): class TimestamptzNow(TimestampNow): + def __init__(self, tz: ZoneInfo = ZoneInfo("UTC")): + self.tz = tz + @property def cockroach(self): return "current_timestamp" def python(self): - return datetime.datetime.now(tz=datetime.timezone.utc) + return pydatetime.datetime.now(tz=self.tz) class TimestamptzCustom(TimestampCustom): + def __init__( + self, + year: int = 2000, + month: int = 1, + day: int = 1, + hour: int = 0, + second: int = 0, + microsecond: int = 0, + tz: ZoneInfo = ZoneInfo("UTC"), + ): + self.tz = tz + super().__init__( + year=year, + month=month, + day=day, + hour=hour, + second=second, + microsecond=microsecond, + ) + @property def cockroach(self): return "'{}'".format(self.datetime.isoformat().replace("T", " ")) @property def datetime(self): - return datetime.datetime( + return pydatetime.datetime( year=self.year, month=self.month, day=self.day, hour=self.hour, second=self.second, microsecond=self.microsecond, - tzinfo=datetime.timezone.utc, + tzinfo=self.tz, ) @classmethod - def from_datetime(cls, instance: datetime.datetime): # type: ignore + def from_datetime( + cls, instance: pydatetime.datetime, tz: ZoneInfo = ZoneInfo("UTC") + ): # type: ignore if instance.tzinfo is not None: - instance = instance.astimezone(datetime.timezone.utc) + instance = instance.astimezone(tz) return cls( year=instance.year, month=instance.month, @@ -72,7 +113,7 @@ def from_datetime(cls, instance: datetime.datetime): # type: ignore TimestamptzOffset, Enum, None, - datetime.datetime, + pydatetime.datetime, ] diff --git a/piccolo/table.py b/piccolo/table.py index b4fcbf942..2dc3e1cae 100644 --- a/piccolo/table.py +++ b/piccolo/table.py @@ -6,6 +6,7 @@ import typing as t import warnings from dataclasses import dataclass, field +from datetime import datetime from piccolo.columns import Column from piccolo.columns.column_types import ( @@ -17,6 +18,7 @@ ReferencedTable, Secret, Serial, + Timestamptz, ) from piccolo.columns.defaults.base import Default from piccolo.columns.indexes import IndexMethod @@ -436,6 +438,9 @@ def __init__( ): raise ValueError(f"{column._meta.name} wasn't provided") + if isinstance(column, Timestamptz) and isinstance(value, datetime): + value = value.astimezone(column.tz) + self[column._meta.name] = value unrecognized = kwargs.keys() diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 0a5ee6244..9199ab8e9 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -5,3 +5,5 @@ targ>=0.3.7 inflection>=0.5.1 typing-extensions>=4.3.0 pydantic[email]==2.* +tzdata>=2024.1 +backports.zoneinfo>=0.2.1; python_version <= '3.8' diff --git a/tests/columns/test_timestamptz.py b/tests/columns/test_timestamptz.py index 8e239900b..7d8ee3da3 100644 --- a/tests/columns/test_timestamptz.py +++ b/tests/columns/test_timestamptz.py @@ -1,8 +1,8 @@ import datetime +import time +from operator import eq from unittest import TestCase -from dateutil import tz - from piccolo.columns.column_types import Timestamptz from piccolo.columns.defaults.timestamptz import ( TimestamptzCustom, @@ -11,9 +11,19 @@ ) from piccolo.table import Table +try: + from zoneinfo import ZoneInfo # type: ignore +except ImportError: # pragma: no cover + from backports.zoneinfo import ZoneInfo # type: ignore # noqa: F401 + + +UTC_TZ = ZoneInfo("UTC") +LOCAL_TZ = ZoneInfo("Europe/Tallinn") + class MyTable(Table): - created_on = Timestamptz() + created_on_utc = Timestamptz(tz=UTC_TZ) + created_on_local = Timestamptz(tz=LOCAL_TZ) class MyTableDefault(Table): @@ -22,18 +32,19 @@ class MyTableDefault(Table): `Timestamptz`. """ - created_on = Timestamptz(default=TimestamptzNow()) - created_on_offset = Timestamptz(default=TimestamptzOffset(days=1)) - created_on_custom = Timestamptz(default=TimestamptzCustom(year=2021)) + created_on = Timestamptz(default=TimestamptzNow(tz=LOCAL_TZ), tz=LOCAL_TZ) + created_on_offset = Timestamptz( + default=TimestamptzOffset(days=1, tz=LOCAL_TZ), tz=LOCAL_TZ + ) + created_on_custom = Timestamptz( + default=TimestamptzCustom(year=2021, tz=LOCAL_TZ), tz=LOCAL_TZ + ) created_on_datetime = Timestamptz( - default=datetime.datetime(year=2020, month=1, day=1) + default=datetime.datetime(year=2020, month=1, day=1, tzinfo=LOCAL_TZ), + tz=LOCAL_TZ, ) -class CustomTimezone(datetime.tzinfo): - pass - - class TestTimestamptz(TestCase): def setUp(self): MyTable.create_table().run_sync() @@ -45,37 +56,32 @@ def test_timestamptz_timezone_aware(self): """ Test storing a timezone aware timestamp. """ - for tzinfo in ( - datetime.timezone.utc, - tz.gettz("America/New_York"), - ): - created_on = datetime.datetime( - year=2020, - month=1, - day=1, - hour=12, - minute=0, - second=0, - tzinfo=tzinfo, - ) - row = MyTable(created_on=created_on) - row.save().run_sync() - - # Fetch it back from the database - result = ( - MyTable.objects() - .where( - MyTable._meta.primary_key - == getattr(row, MyTable._meta.primary_key._meta.name) - ) - .first() - .run_sync() - ) - assert result is not None - self.assertEqual(result.created_on, created_on) - - # The database converts it to UTC - self.assertEqual(result.created_on.tzinfo, datetime.timezone.utc) + dt_args = dict(year=2020, month=1, day=1, hour=12, minute=0, second=0) + created_on_utc = datetime.datetime(**dt_args, tzinfo=ZoneInfo("UTC")) + created_on_local = datetime.datetime( + **dt_args, tzinfo=ZoneInfo("Europe/Tallinn") + ) + row = MyTable( + created_on_utc=created_on_utc, created_on_local=created_on_local + ) + row.save().run_sync() + + # Fetch it back from the database + p_key = MyTable._meta.primary_key + p_key_name = getattr(row, p_key._meta.name) + result = ( + MyTable.objects().where(eq(p_key, p_key_name)).first().run_sync() + ) + assert result is not None + self.assertEqual(result.created_on_utc, created_on_utc) + self.assertEqual(result.created_on_local, created_on_local) + + # The database stores the datetime of the column in UTC timezone, but + # the column converts it back to the timezone that is defined for it. + self.assertEqual(result.created_on_utc.tzinfo, created_on_utc.tzinfo) + self.assertEqual( + result.created_on_local.tzinfo, created_on_local.tzinfo + ) class TestTimestamptzDefault(TestCase): @@ -89,12 +95,27 @@ def test_timestamptz_default(self): """ Make sure the default value gets created, and can be retrieved. """ - created_on = datetime.datetime.now(tz=datetime.timezone.utc) + created_on = datetime.datetime.now(tz=LOCAL_TZ) + time.sleep(1e-5) + row = MyTableDefault() row.save().run_sync() result = MyTableDefault.objects().first().run_sync() assert result is not None + delta = result.created_on - created_on self.assertLess(delta, datetime.timedelta(seconds=1)) - self.assertEqual(result.created_on.tzinfo, datetime.timezone.utc) + self.assertEqual(result.created_on.tzinfo, created_on.tzinfo) + + delta = result.created_on_offset - created_on + self.assertGreaterEqual(delta, datetime.timedelta(days=1)) + self.assertEqual(result.created_on_offset.tzinfo, created_on.tzinfo) + + delta = created_on - result.created_on_custom + self.assertGreaterEqual(delta, datetime.timedelta(days=delta.days)) + self.assertEqual(result.created_on_custom.tzinfo, created_on.tzinfo) + + delta = created_on - result.created_on_datetime + self.assertGreaterEqual(delta, datetime.timedelta(days=delta.days)) + self.assertEqual(result.created_on_datetime.tzinfo, created_on.tzinfo)