1
0
mirror of https://github.com/django/django.git synced 2024-12-22 17:16:24 +00:00

Refs #25287 -- Added support for multiplying and dividing DurationField by scalar values on SQLite.

This commit is contained in:
Tobias Bengfort 2021-04-06 18:14:16 +02:00 committed by Mariusz Felisiak
parent 9e1ccd7283
commit 54e94640ac
7 changed files with 64 additions and 3 deletions

View File

@ -68,6 +68,10 @@ class DatabaseFeatures(BaseDatabaseFeatures):
"returns JSON": { "returns JSON": {
'schema.tests.SchemaTests.test_func_index_json_key_transform', 'schema.tests.SchemaTests.test_func_index_json_key_transform',
}, },
"MySQL supports multiplying and dividing DurationFields by a "
"scalar value but it's not implemented (#25287).": {
'expressions.tests.FTimeDeltaTests.test_durationfield_multiply_divide',
},
} }
if 'ONLY_FULL_GROUP_BY' in self.connection.sql_mode: if 'ONLY_FULL_GROUP_BY' in self.connection.sql_mode:
skips.update({ skips.update({

View File

@ -563,6 +563,7 @@ def _sqlite_format_dtdelta(conn, lhs, rhs):
LHS and RHS can be either: LHS and RHS can be either:
- An integer number of microseconds - An integer number of microseconds
- A string representing a datetime - A string representing a datetime
- A scalar value, e.g. float
""" """
conn = conn.strip() conn = conn.strip()
try: try:
@ -574,8 +575,12 @@ def _sqlite_format_dtdelta(conn, lhs, rhs):
# typecast_timestamp returns a date or a datetime without timezone. # typecast_timestamp returns a date or a datetime without timezone.
# It will be formatted as "%Y-%m-%d" or "%Y-%m-%d %H:%M:%S[.%f]" # It will be formatted as "%Y-%m-%d" or "%Y-%m-%d %H:%M:%S[.%f]"
out = str(real_lhs + real_rhs) out = str(real_lhs + real_rhs)
else: elif conn == '-':
out = str(real_lhs - real_rhs) out = str(real_lhs - real_rhs)
elif conn == '*':
out = real_lhs * real_rhs
else:
out = real_lhs / real_rhs
return out return out

View File

@ -351,7 +351,7 @@ class DatabaseOperations(BaseDatabaseOperations):
return super().combine_expression(connector, sub_expressions) return super().combine_expression(connector, sub_expressions)
def combine_duration_expression(self, connector, sub_expressions): def combine_duration_expression(self, connector, sub_expressions):
if connector not in ['+', '-']: if connector not in ['+', '-', '*', '/']:
raise DatabaseError('Invalid connector for timedelta: %s.' % connector) raise DatabaseError('Invalid connector for timedelta: %s.' % connector)
fn_params = ["'%s'" % connector] + sub_expressions fn_params = ["'%s'" % connector] + sub_expressions
if len(fn_params) > 3: if len(fn_params) > 3:

View File

@ -6,7 +6,7 @@ from decimal import Decimal
from uuid import UUID from uuid import UUID
from django.core.exceptions import EmptyResultSet, FieldError from django.core.exceptions import EmptyResultSet, FieldError
from django.db import NotSupportedError, connection from django.db import DatabaseError, NotSupportedError, connection
from django.db.models import fields from django.db.models import fields
from django.db.models.constants import LOOKUP_SEP from django.db.models.constants import LOOKUP_SEP
from django.db.models.query_utils import Q from django.db.models.query_utils import Q
@ -546,6 +546,24 @@ class DurationExpression(CombinedExpression):
sql = connection.ops.combine_duration_expression(self.connector, expressions) sql = connection.ops.combine_duration_expression(self.connector, expressions)
return expression_wrapper % sql, expression_params return expression_wrapper % sql, expression_params
def as_sqlite(self, compiler, connection, **extra_context):
sql, params = self.as_sql(compiler, connection, **extra_context)
if self.connector in {Combinable.MUL, Combinable.DIV}:
try:
lhs_type = self.lhs.output_field.get_internal_type()
rhs_type = self.rhs.output_field.get_internal_type()
except (AttributeError, FieldError):
pass
else:
allowed_fields = {
'DecimalField', 'DurationField', 'FloatField', 'IntegerField',
}
if lhs_type not in allowed_fields or rhs_type not in allowed_fields:
raise DatabaseError(
f'Invalid arguments for operator {self.connector}.'
)
return sql, params
class TemporalSubtraction(CombinedExpression): class TemporalSubtraction(CombinedExpression):
output_field = fields.DurationField() output_field = fields.DurationField()

View File

@ -233,6 +233,9 @@ Models
* :meth:`.QuerySet.bulk_create` now sets the primary key on objects when using * :meth:`.QuerySet.bulk_create` now sets the primary key on objects when using
SQLite 3.35+. SQLite 3.35+.
* :class:`~django.db.models.DurationField` now supports multiplying and
dividing by scalar values on SQLite.
Requests and Responses Requests and Responses
~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~

View File

@ -60,6 +60,7 @@ class Experiment(models.Model):
estimated_time = models.DurationField() estimated_time = models.DurationField()
start = models.DateTimeField() start = models.DateTimeField()
end = models.DateTimeField() end = models.DateTimeField()
scalar = models.IntegerField(null=True)
class Meta: class Meta:
db_table = 'expressions_ExPeRiMeNt' db_table = 'expressions_ExPeRiMeNt'

View File

@ -1530,6 +1530,36 @@ class FTimeDeltaTests(TestCase):
)) ))
self.assertIsNone(queryset.first().shifted) self.assertIsNone(queryset.first().shifted)
def test_durationfield_multiply_divide(self):
Experiment.objects.update(scalar=2)
tests = [
(Decimal('2'), 2),
(F('scalar'), 2),
(2, 2),
(3.2, 3.2),
]
for expr, scalar in tests:
with self.subTest(expr=expr):
qs = Experiment.objects.annotate(
multiplied=ExpressionWrapper(
expr * F('estimated_time'),
output_field=DurationField(),
),
divided=ExpressionWrapper(
F('estimated_time') / expr,
output_field=DurationField(),
),
)
for experiment in qs:
self.assertEqual(
experiment.multiplied,
experiment.estimated_time * scalar,
)
self.assertEqual(
experiment.divided,
experiment.estimated_time / scalar,
)
def test_duration_expressions(self): def test_duration_expressions(self):
for delta in self.deltas: for delta in self.deltas:
qs = Experiment.objects.annotate(duration=F('estimated_time') + delta) qs = Experiment.objects.annotate(duration=F('estimated_time') + delta)