mirror of
https://github.com/django/django.git
synced 2024-12-22 17:16:24 +00:00
Refs #25287 -- Added support for multiplying and dividing DurationField by scalar values on SQLite.
This commit is contained in:
parent
9e1ccd7283
commit
54e94640ac
@ -68,6 +68,10 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
|||||||
"returns JSON": {
|
"returns JSON": {
|
||||||
'schema.tests.SchemaTests.test_func_index_json_key_transform',
|
'schema.tests.SchemaTests.test_func_index_json_key_transform',
|
||||||
},
|
},
|
||||||
|
"MySQL supports multiplying and dividing DurationFields by a "
|
||||||
|
"scalar value but it's not implemented (#25287).": {
|
||||||
|
'expressions.tests.FTimeDeltaTests.test_durationfield_multiply_divide',
|
||||||
|
},
|
||||||
}
|
}
|
||||||
if 'ONLY_FULL_GROUP_BY' in self.connection.sql_mode:
|
if 'ONLY_FULL_GROUP_BY' in self.connection.sql_mode:
|
||||||
skips.update({
|
skips.update({
|
||||||
|
@ -563,6 +563,7 @@ def _sqlite_format_dtdelta(conn, lhs, rhs):
|
|||||||
LHS and RHS can be either:
|
LHS and RHS can be either:
|
||||||
- An integer number of microseconds
|
- An integer number of microseconds
|
||||||
- A string representing a datetime
|
- A string representing a datetime
|
||||||
|
- A scalar value, e.g. float
|
||||||
"""
|
"""
|
||||||
conn = conn.strip()
|
conn = conn.strip()
|
||||||
try:
|
try:
|
||||||
@ -574,8 +575,12 @@ def _sqlite_format_dtdelta(conn, lhs, rhs):
|
|||||||
# typecast_timestamp returns a date or a datetime without timezone.
|
# typecast_timestamp returns a date or a datetime without timezone.
|
||||||
# It will be formatted as "%Y-%m-%d" or "%Y-%m-%d %H:%M:%S[.%f]"
|
# It will be formatted as "%Y-%m-%d" or "%Y-%m-%d %H:%M:%S[.%f]"
|
||||||
out = str(real_lhs + real_rhs)
|
out = str(real_lhs + real_rhs)
|
||||||
else:
|
elif conn == '-':
|
||||||
out = str(real_lhs - real_rhs)
|
out = str(real_lhs - real_rhs)
|
||||||
|
elif conn == '*':
|
||||||
|
out = real_lhs * real_rhs
|
||||||
|
else:
|
||||||
|
out = real_lhs / real_rhs
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
@ -351,7 +351,7 @@ class DatabaseOperations(BaseDatabaseOperations):
|
|||||||
return super().combine_expression(connector, sub_expressions)
|
return super().combine_expression(connector, sub_expressions)
|
||||||
|
|
||||||
def combine_duration_expression(self, connector, sub_expressions):
|
def combine_duration_expression(self, connector, sub_expressions):
|
||||||
if connector not in ['+', '-']:
|
if connector not in ['+', '-', '*', '/']:
|
||||||
raise DatabaseError('Invalid connector for timedelta: %s.' % connector)
|
raise DatabaseError('Invalid connector for timedelta: %s.' % connector)
|
||||||
fn_params = ["'%s'" % connector] + sub_expressions
|
fn_params = ["'%s'" % connector] + sub_expressions
|
||||||
if len(fn_params) > 3:
|
if len(fn_params) > 3:
|
||||||
|
@ -6,7 +6,7 @@ from decimal import Decimal
|
|||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from django.core.exceptions import EmptyResultSet, FieldError
|
from django.core.exceptions import EmptyResultSet, FieldError
|
||||||
from django.db import NotSupportedError, connection
|
from django.db import DatabaseError, NotSupportedError, connection
|
||||||
from django.db.models import fields
|
from django.db.models import fields
|
||||||
from django.db.models.constants import LOOKUP_SEP
|
from django.db.models.constants import LOOKUP_SEP
|
||||||
from django.db.models.query_utils import Q
|
from django.db.models.query_utils import Q
|
||||||
@ -546,6 +546,24 @@ class DurationExpression(CombinedExpression):
|
|||||||
sql = connection.ops.combine_duration_expression(self.connector, expressions)
|
sql = connection.ops.combine_duration_expression(self.connector, expressions)
|
||||||
return expression_wrapper % sql, expression_params
|
return expression_wrapper % sql, expression_params
|
||||||
|
|
||||||
|
def as_sqlite(self, compiler, connection, **extra_context):
|
||||||
|
sql, params = self.as_sql(compiler, connection, **extra_context)
|
||||||
|
if self.connector in {Combinable.MUL, Combinable.DIV}:
|
||||||
|
try:
|
||||||
|
lhs_type = self.lhs.output_field.get_internal_type()
|
||||||
|
rhs_type = self.rhs.output_field.get_internal_type()
|
||||||
|
except (AttributeError, FieldError):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
allowed_fields = {
|
||||||
|
'DecimalField', 'DurationField', 'FloatField', 'IntegerField',
|
||||||
|
}
|
||||||
|
if lhs_type not in allowed_fields or rhs_type not in allowed_fields:
|
||||||
|
raise DatabaseError(
|
||||||
|
f'Invalid arguments for operator {self.connector}.'
|
||||||
|
)
|
||||||
|
return sql, params
|
||||||
|
|
||||||
|
|
||||||
class TemporalSubtraction(CombinedExpression):
|
class TemporalSubtraction(CombinedExpression):
|
||||||
output_field = fields.DurationField()
|
output_field = fields.DurationField()
|
||||||
|
@ -233,6 +233,9 @@ Models
|
|||||||
* :meth:`.QuerySet.bulk_create` now sets the primary key on objects when using
|
* :meth:`.QuerySet.bulk_create` now sets the primary key on objects when using
|
||||||
SQLite 3.35+.
|
SQLite 3.35+.
|
||||||
|
|
||||||
|
* :class:`~django.db.models.DurationField` now supports multiplying and
|
||||||
|
dividing by scalar values on SQLite.
|
||||||
|
|
||||||
Requests and Responses
|
Requests and Responses
|
||||||
~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
@ -60,6 +60,7 @@ class Experiment(models.Model):
|
|||||||
estimated_time = models.DurationField()
|
estimated_time = models.DurationField()
|
||||||
start = models.DateTimeField()
|
start = models.DateTimeField()
|
||||||
end = models.DateTimeField()
|
end = models.DateTimeField()
|
||||||
|
scalar = models.IntegerField(null=True)
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
db_table = 'expressions_ExPeRiMeNt'
|
db_table = 'expressions_ExPeRiMeNt'
|
||||||
|
@ -1530,6 +1530,36 @@ class FTimeDeltaTests(TestCase):
|
|||||||
))
|
))
|
||||||
self.assertIsNone(queryset.first().shifted)
|
self.assertIsNone(queryset.first().shifted)
|
||||||
|
|
||||||
|
def test_durationfield_multiply_divide(self):
|
||||||
|
Experiment.objects.update(scalar=2)
|
||||||
|
tests = [
|
||||||
|
(Decimal('2'), 2),
|
||||||
|
(F('scalar'), 2),
|
||||||
|
(2, 2),
|
||||||
|
(3.2, 3.2),
|
||||||
|
]
|
||||||
|
for expr, scalar in tests:
|
||||||
|
with self.subTest(expr=expr):
|
||||||
|
qs = Experiment.objects.annotate(
|
||||||
|
multiplied=ExpressionWrapper(
|
||||||
|
expr * F('estimated_time'),
|
||||||
|
output_field=DurationField(),
|
||||||
|
),
|
||||||
|
divided=ExpressionWrapper(
|
||||||
|
F('estimated_time') / expr,
|
||||||
|
output_field=DurationField(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for experiment in qs:
|
||||||
|
self.assertEqual(
|
||||||
|
experiment.multiplied,
|
||||||
|
experiment.estimated_time * scalar,
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
experiment.divided,
|
||||||
|
experiment.estimated_time / scalar,
|
||||||
|
)
|
||||||
|
|
||||||
def test_duration_expressions(self):
|
def test_duration_expressions(self):
|
||||||
for delta in self.deltas:
|
for delta in self.deltas:
|
||||||
qs = Experiment.objects.annotate(duration=F('estimated_time') + delta)
|
qs = Experiment.objects.annotate(duration=F('estimated_time') + delta)
|
||||||
|
Loading…
Reference in New Issue
Block a user