From ee005328c8eec5c013c6bf9d6fbb2ae9d540df14 Mon Sep 17 00:00:00 2001 From: David-Wobrock Date: Sun, 4 Oct 2020 19:28:21 +0200 Subject: [PATCH] Fixed #31640 -- Made Trunc() truncate datetimes to Date/TimeField in a specific timezone. --- django/db/backends/base/operations.py | 18 +++-- django/db/backends/mysql/operations.py | 8 +- django/db/backends/oracle/operations.py | 8 +- django/db/backends/postgresql/operations.py | 8 +- django/db/backends/sqlite3/base.py | 22 ++--- django/db/backends/sqlite3/operations.py | 18 +++-- django/db/models/functions/datetime.py | 10 ++- docs/releases/3.2.txt | 4 + .../datetime/test_extract_trunc.py | 81 +++++++++++++++++++ 9 files changed, 145 insertions(+), 32 deletions(-) diff --git a/django/db/backends/base/operations.py b/django/db/backends/base/operations.py index 2319cb6306..0fcc607bcf 100644 --- a/django/db/backends/base/operations.py +++ b/django/db/backends/base/operations.py @@ -99,11 +99,14 @@ class BaseDatabaseOperations: """ raise NotImplementedError('subclasses of BaseDatabaseOperations may require a date_extract_sql() method') - def date_trunc_sql(self, lookup_type, field_name): + def date_trunc_sql(self, lookup_type, field_name, tzname=None): """ Given a lookup_type of 'year', 'month', or 'day', return the SQL that - truncates the given date field field_name to a date object with only - the given specificity. + truncates the given date or datetime field field_name to a date object + with only the given specificity. + + If `tzname` is provided, the given value is truncated in a specific + timezone. """ raise NotImplementedError('subclasses of BaseDatabaseOperations may require a date_trunc_sql() method.') @@ -138,11 +141,14 @@ class BaseDatabaseOperations: """ raise NotImplementedError('subclasses of BaseDatabaseOperations may require a datetime_trunc_sql() method') - def time_trunc_sql(self, lookup_type, field_name): + def time_trunc_sql(self, lookup_type, field_name, tzname=None): """ Given a lookup_type of 'hour', 'minute' or 'second', return the SQL - that truncates the given time field field_name to a time object with - only the given specificity. + that truncates the given time or datetime field field_name to a time + object with only the given specificity. + + If `tzname` is provided, the given value is truncated in a specific + timezone. """ raise NotImplementedError('subclasses of BaseDatabaseOperations may require a time_trunc_sql() method') diff --git a/django/db/backends/mysql/operations.py b/django/db/backends/mysql/operations.py index 871c533870..5d2a981226 100644 --- a/django/db/backends/mysql/operations.py +++ b/django/db/backends/mysql/operations.py @@ -55,7 +55,8 @@ class DatabaseOperations(BaseDatabaseOperations): # EXTRACT returns 1-53 based on ISO-8601 for the week number. return "EXTRACT(%s FROM %s)" % (lookup_type.upper(), field_name) - def date_trunc_sql(self, lookup_type, field_name): + def date_trunc_sql(self, lookup_type, field_name, tzname=None): + field_name = self._convert_field_to_tz(field_name, tzname) fields = { 'year': '%%Y-01-01', 'month': '%%Y-%%m-01', @@ -82,7 +83,7 @@ class DatabaseOperations(BaseDatabaseOperations): return tzname def _convert_field_to_tz(self, field_name, tzname): - if settings.USE_TZ and self.connection.timezone_name != tzname: + if tzname and settings.USE_TZ and self.connection.timezone_name != tzname: field_name = "CONVERT_TZ(%s, '%s', '%s')" % ( field_name, self.connection.timezone_name, @@ -128,7 +129,8 @@ class DatabaseOperations(BaseDatabaseOperations): sql = "CAST(DATE_FORMAT(%s, '%s') AS DATETIME)" % (field_name, format_str) return sql - def time_trunc_sql(self, lookup_type, field_name): + def time_trunc_sql(self, lookup_type, field_name, tzname=None): + field_name = self._convert_field_to_tz(field_name, tzname) fields = { 'hour': '%%H:00:00', 'minute': '%%H:%%i:00', diff --git a/django/db/backends/oracle/operations.py b/django/db/backends/oracle/operations.py index cfce38b189..b829d2bd9b 100644 --- a/django/db/backends/oracle/operations.py +++ b/django/db/backends/oracle/operations.py @@ -89,7 +89,8 @@ END; # https://docs.oracle.com/en/database/oracle/oracle-database/18/sqlrf/EXTRACT-datetime.html return "EXTRACT(%s FROM %s)" % (lookup_type.upper(), field_name) - def date_trunc_sql(self, lookup_type, field_name): + def date_trunc_sql(self, lookup_type, field_name, tzname=None): + field_name = self._convert_field_to_tz(field_name, tzname) # https://docs.oracle.com/en/database/oracle/oracle-database/18/sqlrf/ROUND-and-TRUNC-Date-Functions.html if lookup_type in ('year', 'month'): return "TRUNC(%s, '%s')" % (field_name, lookup_type.upper()) @@ -114,7 +115,7 @@ END; return tzname def _convert_field_to_tz(self, field_name, tzname): - if not settings.USE_TZ: + if not (settings.USE_TZ and tzname): return field_name if not self._tzname_re.match(tzname): raise ValueError("Invalid time zone name: %s" % tzname) @@ -161,10 +162,11 @@ END; sql = "CAST(%s AS DATE)" % field_name # Cast to DATE removes sub-second precision. return sql - def time_trunc_sql(self, lookup_type, field_name): + def time_trunc_sql(self, lookup_type, field_name, tzname=None): # The implementation is similar to `datetime_trunc_sql` as both # `DateTimeField` and `TimeField` are stored as TIMESTAMP where # the date part of the later is ignored. + field_name = self._convert_field_to_tz(field_name, tzname) if lookup_type == 'hour': sql = "TRUNC(%s, 'HH24')" % field_name elif lookup_type == 'minute': diff --git a/django/db/backends/postgresql/operations.py b/django/db/backends/postgresql/operations.py index 1ce5755bf5..8d19872bea 100644 --- a/django/db/backends/postgresql/operations.py +++ b/django/db/backends/postgresql/operations.py @@ -38,7 +38,8 @@ class DatabaseOperations(BaseDatabaseOperations): else: return "EXTRACT('%s' FROM %s)" % (lookup_type, field_name) - def date_trunc_sql(self, lookup_type, field_name): + def date_trunc_sql(self, lookup_type, field_name, tzname=None): + field_name = self._convert_field_to_tz(field_name, tzname) # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC return "DATE_TRUNC('%s', %s)" % (lookup_type, field_name) @@ -50,7 +51,7 @@ class DatabaseOperations(BaseDatabaseOperations): return tzname def _convert_field_to_tz(self, field_name, tzname): - if settings.USE_TZ: + if tzname and settings.USE_TZ: field_name = "%s AT TIME ZONE '%s'" % (field_name, self._prepare_tzname_delta(tzname)) return field_name @@ -71,7 +72,8 @@ class DatabaseOperations(BaseDatabaseOperations): # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC return "DATE_TRUNC('%s', %s)" % (lookup_type, field_name) - def time_trunc_sql(self, lookup_type, field_name): + def time_trunc_sql(self, lookup_type, field_name, tzname=None): + field_name = self._convert_field_to_tz(field_name, tzname) return "DATE_TRUNC('%s', %s)::time" % (lookup_type, field_name) def deferrable_sql(self): diff --git a/django/db/backends/sqlite3/base.py b/django/db/backends/sqlite3/base.py index 608a2b2a95..13b78edc52 100644 --- a/django/db/backends/sqlite3/base.py +++ b/django/db/backends/sqlite3/base.py @@ -213,13 +213,13 @@ class DatabaseWrapper(BaseDatabaseWrapper): else: create_deterministic_function = conn.create_function create_deterministic_function('django_date_extract', 2, _sqlite_datetime_extract) - create_deterministic_function('django_date_trunc', 2, _sqlite_date_trunc) + create_deterministic_function('django_date_trunc', 4, _sqlite_date_trunc) create_deterministic_function('django_datetime_cast_date', 3, _sqlite_datetime_cast_date) create_deterministic_function('django_datetime_cast_time', 3, _sqlite_datetime_cast_time) create_deterministic_function('django_datetime_extract', 4, _sqlite_datetime_extract) create_deterministic_function('django_datetime_trunc', 4, _sqlite_datetime_trunc) create_deterministic_function('django_time_extract', 2, _sqlite_time_extract) - create_deterministic_function('django_time_trunc', 2, _sqlite_time_trunc) + create_deterministic_function('django_time_trunc', 4, _sqlite_time_trunc) create_deterministic_function('django_time_diff', 2, _sqlite_time_diff) create_deterministic_function('django_timestamp_diff', 2, _sqlite_timestamp_diff) create_deterministic_function('django_format_dtdelta', 3, _sqlite_format_dtdelta) @@ -445,8 +445,8 @@ def _sqlite_datetime_parse(dt, tzname=None, conn_tzname=None): return dt -def _sqlite_date_trunc(lookup_type, dt): - dt = _sqlite_datetime_parse(dt) +def _sqlite_date_trunc(lookup_type, dt, tzname, conn_tzname): + dt = _sqlite_datetime_parse(dt, tzname, conn_tzname) if dt is None: return None if lookup_type == 'year': @@ -463,13 +463,17 @@ def _sqlite_date_trunc(lookup_type, dt): return "%i-%02i-%02i" % (dt.year, dt.month, dt.day) -def _sqlite_time_trunc(lookup_type, dt): +def _sqlite_time_trunc(lookup_type, dt, tzname, conn_tzname): if dt is None: return None - try: - dt = backend_utils.typecast_time(dt) - except (ValueError, TypeError): - return None + dt_parsed = _sqlite_datetime_parse(dt, tzname, conn_tzname) + if dt_parsed is None: + try: + dt = backend_utils.typecast_time(dt) + except (ValueError, TypeError): + return None + else: + dt = dt_parsed if lookup_type == 'hour': return "%02i:00:00" % dt.hour elif lookup_type == 'minute': diff --git a/django/db/backends/sqlite3/operations.py b/django/db/backends/sqlite3/operations.py index 1f77b3109f..71ef000c93 100644 --- a/django/db/backends/sqlite3/operations.py +++ b/django/db/backends/sqlite3/operations.py @@ -77,14 +77,22 @@ class DatabaseOperations(BaseDatabaseOperations): """Do nothing since formatting is handled in the custom function.""" return sql - def date_trunc_sql(self, lookup_type, field_name): - return "django_date_trunc('%s', %s)" % (lookup_type.lower(), field_name) + def date_trunc_sql(self, lookup_type, field_name, tzname=None): + return "django_date_trunc('%s', %s, %s, %s)" % ( + lookup_type.lower(), + field_name, + *self._convert_tznames_to_sql(tzname), + ) - def time_trunc_sql(self, lookup_type, field_name): - return "django_time_trunc('%s', %s)" % (lookup_type.lower(), field_name) + def time_trunc_sql(self, lookup_type, field_name, tzname=None): + return "django_time_trunc('%s', %s, %s, %s)" % ( + lookup_type.lower(), + field_name, + *self._convert_tznames_to_sql(tzname), + ) def _convert_tznames_to_sql(self, tzname): - if settings.USE_TZ: + if tzname and settings.USE_TZ: return "'%s'" % tzname, "'%s'" % self.connection.timezone_name return 'NULL', 'NULL' diff --git a/django/db/models/functions/datetime.py b/django/db/models/functions/datetime.py index 5075bd3d83..90e6f41be0 100644 --- a/django/db/models/functions/datetime.py +++ b/django/db/models/functions/datetime.py @@ -193,13 +193,17 @@ class TruncBase(TimezoneMixin, Transform): def as_sql(self, compiler, connection): inner_sql, inner_params = compiler.compile(self.lhs) - if isinstance(self.output_field, DateTimeField): + tzname = None + if isinstance(self.lhs.output_field, DateTimeField): tzname = self.get_tzname() + elif self.tzinfo is not None: + raise ValueError('tzinfo can only be used with DateTimeField.') + if isinstance(self.output_field, DateTimeField): sql = connection.ops.datetime_trunc_sql(self.kind, inner_sql, tzname) elif isinstance(self.output_field, DateField): - sql = connection.ops.date_trunc_sql(self.kind, inner_sql) + sql = connection.ops.date_trunc_sql(self.kind, inner_sql, tzname) elif isinstance(self.output_field, TimeField): - sql = connection.ops.time_trunc_sql(self.kind, inner_sql) + sql = connection.ops.time_trunc_sql(self.kind, inner_sql, tzname) else: raise ValueError('Trunc only valid on DateField, TimeField, or DateTimeField.') return sql, inner_params diff --git a/docs/releases/3.2.txt b/docs/releases/3.2.txt index f4ef77e435..0c233f796b 100644 --- a/docs/releases/3.2.txt +++ b/docs/releases/3.2.txt @@ -458,6 +458,10 @@ backends. * ``DatabaseOperations.random_function_sql()`` is removed in favor of the new :class:`~django.db.models.functions.Random` database function. +* ``DatabaseOperations.date_trunc_sql()`` and + ``DatabaseOperations.time_trunc_sql()`` now take the optional ``tzname`` + argument in order to truncate in a specific timezone. + :mod:`django.contrib.admin` --------------------------- diff --git a/tests/db_functions/datetime/test_extract_trunc.py b/tests/db_functions/datetime/test_extract_trunc.py index 58ca9b52b5..035900da93 100644 --- a/tests/db_functions/datetime/test_extract_trunc.py +++ b/tests/db_functions/datetime/test_extract_trunc.py @@ -672,6 +672,18 @@ class DateFunctionTests(TestCase): lambda m: (m.start_datetime, m.truncated) ) + def test_datetime_to_time_kind(kind): + self.assertQuerysetEqual( + DTModel.objects.annotate( + truncated=Trunc('start_datetime', kind, output_field=TimeField()), + ).order_by('start_datetime'), + [ + (start_datetime, truncate_to(start_datetime.time(), kind)), + (end_datetime, truncate_to(end_datetime.time(), kind)), + ], + lambda m: (m.start_datetime, m.truncated), + ) + test_date_kind('year') test_date_kind('quarter') test_date_kind('month') @@ -688,6 +700,9 @@ class DateFunctionTests(TestCase): test_datetime_kind('hour') test_datetime_kind('minute') test_datetime_kind('second') + test_datetime_to_time_kind('hour') + test_datetime_to_time_kind('minute') + test_datetime_to_time_kind('second') qs = DTModel.objects.filter(start_datetime__date=Trunc('start_datetime', 'day', output_field=DateField())) self.assertEqual(qs.count(), 2) @@ -1205,6 +1220,60 @@ class DateFunctionWithTimeZoneTests(DateFunctionTests): lambda m: (m.start_datetime, m.truncated) ) + def test_datetime_to_date_kind(kind): + self.assertQuerysetEqual( + DTModel.objects.annotate( + truncated=Trunc( + 'start_datetime', + kind, + output_field=DateField(), + tzinfo=melb, + ), + ).order_by('start_datetime'), + [ + ( + start_datetime, + truncate_to(start_datetime.astimezone(melb).date(), kind), + ), + ( + end_datetime, + truncate_to(end_datetime.astimezone(melb).date(), kind), + ), + ], + lambda m: (m.start_datetime, m.truncated), + ) + + def test_datetime_to_time_kind(kind): + self.assertQuerysetEqual( + DTModel.objects.annotate( + truncated=Trunc( + 'start_datetime', + kind, + output_field=TimeField(), + tzinfo=melb, + ) + ).order_by('start_datetime'), + [ + ( + start_datetime, + truncate_to(start_datetime.astimezone(melb).time(), kind), + ), + ( + end_datetime, + truncate_to(end_datetime.astimezone(melb).time(), kind), + ), + ], + lambda m: (m.start_datetime, m.truncated), + ) + + test_datetime_to_date_kind('year') + test_datetime_to_date_kind('quarter') + test_datetime_to_date_kind('month') + test_datetime_to_date_kind('week') + test_datetime_to_date_kind('day') + test_datetime_to_time_kind('hour') + test_datetime_to_time_kind('minute') + test_datetime_to_time_kind('second') test_datetime_kind('year') test_datetime_kind('quarter') test_datetime_kind('month') @@ -1216,3 +1285,15 @@ class DateFunctionWithTimeZoneTests(DateFunctionTests): qs = DTModel.objects.filter(start_datetime__date=Trunc('start_datetime', 'day', output_field=DateField())) self.assertEqual(qs.count(), 2) + + def test_trunc_invalid_field_with_timezone(self): + melb = pytz.timezone('Australia/Melbourne') + msg = 'tzinfo can only be used with DateTimeField.' + with self.assertRaisesMessage(ValueError, msg): + DTModel.objects.annotate( + day_melb=Trunc('start_date', 'day', tzinfo=melb), + ).get() + with self.assertRaisesMessage(ValueError, msg): + DTModel.objects.annotate( + hour_melb=Trunc('start_time', 'hour', tzinfo=melb), + ).get()