From c3c6c92d769d44a98299c462c48a9599c0172e91 Mon Sep 17 00:00:00 2001 From: Sergey Fedoseev Date: Fri, 11 Aug 2017 02:42:30 +0500 Subject: [PATCH] Refs #18247 -- Fixed filtering on CombinedExpression(output_field=DecimalField()) annotation on SQLite. --- django/db/models/expressions.py | 24 +++++++++++++++--------- tests/annotations/tests.py | 4 ++++ 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index 15ea2e1e17..9a996f66b3 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -10,6 +10,19 @@ from django.utils.deconstruct import deconstructible from django.utils.functional import cached_property +class SQLiteNumericMixin: + """ + Some expressions with output_field=DecimalField() must be cast to + numeric to be properly filtered. + """ + def as_sqlite(self, compiler, connection, **extra_context): + sql, params = self.as_sql(compiler, connection, **extra_context) + with suppress(FieldError): + if self.output_field.get_internal_type() == 'DecimalField': + sql = 'CAST(%s AS NUMERIC)' % sql + return sql, params + + class Combinable: """ Provide the ability to combine one or two objects with @@ -352,7 +365,7 @@ class Expression(BaseExpression, Combinable): pass -class CombinedExpression(Expression): +class CombinedExpression(SQLiteNumericMixin, Expression): def __init__(self, lhs, connector, rhs, output_field=None): super().__init__(output_field=output_field) @@ -506,7 +519,7 @@ class OuterRef(F): return self -class Func(Expression): +class Func(SQLiteNumericMixin, Expression): """An SQL function call.""" function = None template = '%(function)s(%(expressions)s)' @@ -574,13 +587,6 @@ class Func(Expression): data['expressions'] = data['field'] = arg_joiner.join(sql_parts) return template % data, params - def as_sqlite(self, compiler, connection, **extra_context): - sql, params = self.as_sql(compiler, connection, **extra_context) - with suppress(FieldError): - if self.output_field.get_internal_type() == 'DecimalField': - sql = 'CAST(%s AS NUMERIC)' % sql - return sql, params - def copy(self): copy = super().copy() copy.source_expressions = self.source_expressions[:] diff --git a/tests/annotations/tests.py b/tests/annotations/tests.py index 1714076e1d..aaf56e79ef 100644 --- a/tests/annotations/tests.py +++ b/tests/annotations/tests.py @@ -244,6 +244,10 @@ class NonAggregateAnnotationTestCase(TestCase): sum_rating=Sum('rating') ).filter(sum_rating=F('nope'))) + def test_filter_decimal_annotation(self): + qs = Book.objects.annotate(new_price=F('price') + 1).filter(new_price=Decimal(31)).values_list('new_price') + self.assertEqual(qs.get(), (Decimal(31),)) + def test_combined_annotation_commutative(self): book1 = Book.objects.annotate(adjusted_rating=F('rating') + 2).get(pk=self.b1.pk) book2 = Book.objects.annotate(adjusted_rating=2 + F('rating')).get(pk=self.b1.pk)