From 54e94640ace261b14cf8cdb1fae3dc6f068a5f87 Mon Sep 17 00:00:00 2001 From: Tobias Bengfort Date: Tue, 6 Apr 2021 18:14:16 +0200 Subject: [PATCH] Refs #25287 -- Added support for multiplying and dividing DurationField by scalar values on SQLite. --- django/db/backends/mysql/features.py | 4 ++++ django/db/backends/sqlite3/base.py | 7 +++++- django/db/backends/sqlite3/operations.py | 2 +- django/db/models/expressions.py | 20 +++++++++++++++- docs/releases/4.0.txt | 3 +++ tests/expressions/models.py | 1 + tests/expressions/tests.py | 30 ++++++++++++++++++++++++ 7 files changed, 64 insertions(+), 3 deletions(-) diff --git a/django/db/backends/mysql/features.py b/django/db/backends/mysql/features.py index 419b2ba6f0..eb80b2e543 100644 --- a/django/db/backends/mysql/features.py +++ b/django/db/backends/mysql/features.py @@ -68,6 +68,10 @@ class DatabaseFeatures(BaseDatabaseFeatures): "returns JSON": { 'schema.tests.SchemaTests.test_func_index_json_key_transform', }, + "MySQL supports multiplying and dividing DurationFields by a " + "scalar value but it's not implemented (#25287).": { + 'expressions.tests.FTimeDeltaTests.test_durationfield_multiply_divide', + }, } if 'ONLY_FULL_GROUP_BY' in self.connection.sql_mode: skips.update({ diff --git a/django/db/backends/sqlite3/base.py b/django/db/backends/sqlite3/base.py index 35466189e6..a64190f0d0 100644 --- a/django/db/backends/sqlite3/base.py +++ b/django/db/backends/sqlite3/base.py @@ -563,6 +563,7 @@ def _sqlite_format_dtdelta(conn, lhs, rhs): LHS and RHS can be either: - An integer number of microseconds - A string representing a datetime + - A scalar value, e.g. float """ conn = conn.strip() try: @@ -574,8 +575,12 @@ def _sqlite_format_dtdelta(conn, lhs, rhs): # typecast_timestamp returns a date or a datetime without timezone. # It will be formatted as "%Y-%m-%d" or "%Y-%m-%d %H:%M:%S[.%f]" out = str(real_lhs + real_rhs) - else: + elif conn == '-': out = str(real_lhs - real_rhs) + elif conn == '*': + out = real_lhs * real_rhs + else: + out = real_lhs / real_rhs return out diff --git a/django/db/backends/sqlite3/operations.py b/django/db/backends/sqlite3/operations.py index 95484931cf..90a4241803 100644 --- a/django/db/backends/sqlite3/operations.py +++ b/django/db/backends/sqlite3/operations.py @@ -351,7 +351,7 @@ class DatabaseOperations(BaseDatabaseOperations): return super().combine_expression(connector, sub_expressions) def combine_duration_expression(self, connector, sub_expressions): - if connector not in ['+', '-']: + if connector not in ['+', '-', '*', '/']: raise DatabaseError('Invalid connector for timedelta: %s.' % connector) fn_params = ["'%s'" % connector] + sub_expressions if len(fn_params) > 3: diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index 4ecae5f02d..528d988e85 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -6,7 +6,7 @@ from decimal import Decimal from uuid import UUID from django.core.exceptions import EmptyResultSet, FieldError -from django.db import NotSupportedError, connection +from django.db import DatabaseError, NotSupportedError, connection from django.db.models import fields from django.db.models.constants import LOOKUP_SEP from django.db.models.query_utils import Q @@ -546,6 +546,24 @@ class DurationExpression(CombinedExpression): sql = connection.ops.combine_duration_expression(self.connector, expressions) return expression_wrapper % sql, expression_params + def as_sqlite(self, compiler, connection, **extra_context): + sql, params = self.as_sql(compiler, connection, **extra_context) + if self.connector in {Combinable.MUL, Combinable.DIV}: + try: + lhs_type = self.lhs.output_field.get_internal_type() + rhs_type = self.rhs.output_field.get_internal_type() + except (AttributeError, FieldError): + pass + else: + allowed_fields = { + 'DecimalField', 'DurationField', 'FloatField', 'IntegerField', + } + if lhs_type not in allowed_fields or rhs_type not in allowed_fields: + raise DatabaseError( + f'Invalid arguments for operator {self.connector}.' + ) + return sql, params + class TemporalSubtraction(CombinedExpression): output_field = fields.DurationField() diff --git a/docs/releases/4.0.txt b/docs/releases/4.0.txt index 3792c4b716..faf3d0fa1e 100644 --- a/docs/releases/4.0.txt +++ b/docs/releases/4.0.txt @@ -233,6 +233,9 @@ Models * :meth:`.QuerySet.bulk_create` now sets the primary key on objects when using SQLite 3.35+. +* :class:`~django.db.models.DurationField` now supports multiplying and + dividing by scalar values on SQLite. + Requests and Responses ~~~~~~~~~~~~~~~~~~~~~~ diff --git a/tests/expressions/models.py b/tests/expressions/models.py index 02836e653e..938e623d60 100644 --- a/tests/expressions/models.py +++ b/tests/expressions/models.py @@ -60,6 +60,7 @@ class Experiment(models.Model): estimated_time = models.DurationField() start = models.DateTimeField() end = models.DateTimeField() + scalar = models.IntegerField(null=True) class Meta: db_table = 'expressions_ExPeRiMeNt' diff --git a/tests/expressions/tests.py b/tests/expressions/tests.py index 9b88c94d7b..0585805a8b 100644 --- a/tests/expressions/tests.py +++ b/tests/expressions/tests.py @@ -1530,6 +1530,36 @@ class FTimeDeltaTests(TestCase): )) self.assertIsNone(queryset.first().shifted) + def test_durationfield_multiply_divide(self): + Experiment.objects.update(scalar=2) + tests = [ + (Decimal('2'), 2), + (F('scalar'), 2), + (2, 2), + (3.2, 3.2), + ] + for expr, scalar in tests: + with self.subTest(expr=expr): + qs = Experiment.objects.annotate( + multiplied=ExpressionWrapper( + expr * F('estimated_time'), + output_field=DurationField(), + ), + divided=ExpressionWrapper( + F('estimated_time') / expr, + output_field=DurationField(), + ), + ) + for experiment in qs: + self.assertEqual( + experiment.multiplied, + experiment.estimated_time * scalar, + ) + self.assertEqual( + experiment.divided, + experiment.estimated_time / scalar, + ) + def test_duration_expressions(self): for delta in self.deltas: qs = Experiment.objects.annotate(duration=F('estimated_time') + delta)