1
0
mirror of https://github.com/django/django.git synced 2024-12-22 17:16:24 +00:00

Refs #18247 -- Fixed filtering on CombinedExpression(output_field=DecimalField()) annotation on SQLite.

This commit is contained in:
Sergey Fedoseev 2017-08-11 02:42:30 +05:00 committed by Tim Graham
parent 660d50805b
commit c3c6c92d76
2 changed files with 19 additions and 9 deletions

View File

@ -10,6 +10,19 @@ from django.utils.deconstruct import deconstructible
from django.utils.functional import cached_property from django.utils.functional import cached_property
class SQLiteNumericMixin:
"""
Some expressions with output_field=DecimalField() must be cast to
numeric to be properly filtered.
"""
def as_sqlite(self, compiler, connection, **extra_context):
sql, params = self.as_sql(compiler, connection, **extra_context)
with suppress(FieldError):
if self.output_field.get_internal_type() == 'DecimalField':
sql = 'CAST(%s AS NUMERIC)' % sql
return sql, params
class Combinable: class Combinable:
""" """
Provide the ability to combine one or two objects with Provide the ability to combine one or two objects with
@ -352,7 +365,7 @@ class Expression(BaseExpression, Combinable):
pass pass
class CombinedExpression(Expression): class CombinedExpression(SQLiteNumericMixin, Expression):
def __init__(self, lhs, connector, rhs, output_field=None): def __init__(self, lhs, connector, rhs, output_field=None):
super().__init__(output_field=output_field) super().__init__(output_field=output_field)
@ -506,7 +519,7 @@ class OuterRef(F):
return self return self
class Func(Expression): class Func(SQLiteNumericMixin, Expression):
"""An SQL function call.""" """An SQL function call."""
function = None function = None
template = '%(function)s(%(expressions)s)' template = '%(function)s(%(expressions)s)'
@ -574,13 +587,6 @@ class Func(Expression):
data['expressions'] = data['field'] = arg_joiner.join(sql_parts) data['expressions'] = data['field'] = arg_joiner.join(sql_parts)
return template % data, params return template % data, params
def as_sqlite(self, compiler, connection, **extra_context):
sql, params = self.as_sql(compiler, connection, **extra_context)
with suppress(FieldError):
if self.output_field.get_internal_type() == 'DecimalField':
sql = 'CAST(%s AS NUMERIC)' % sql
return sql, params
def copy(self): def copy(self):
copy = super().copy() copy = super().copy()
copy.source_expressions = self.source_expressions[:] copy.source_expressions = self.source_expressions[:]

View File

@ -244,6 +244,10 @@ class NonAggregateAnnotationTestCase(TestCase):
sum_rating=Sum('rating') sum_rating=Sum('rating')
).filter(sum_rating=F('nope'))) ).filter(sum_rating=F('nope')))
def test_filter_decimal_annotation(self):
qs = Book.objects.annotate(new_price=F('price') + 1).filter(new_price=Decimal(31)).values_list('new_price')
self.assertEqual(qs.get(), (Decimal(31),))
def test_combined_annotation_commutative(self): def test_combined_annotation_commutative(self):
book1 = Book.objects.annotate(adjusted_rating=F('rating') + 2).get(pk=self.b1.pk) book1 = Book.objects.annotate(adjusted_rating=F('rating') + 2).get(pk=self.b1.pk)
book2 = Book.objects.annotate(adjusted_rating=2 + F('rating')).get(pk=self.b1.pk) book2 = Book.objects.annotate(adjusted_rating=2 + F('rating')).get(pk=self.b1.pk)