From a4cac1720034920351291359d0bb2177d8c8a4d5 Mon Sep 17 00:00:00 2001 From: Peter Inglesby Date: Sat, 10 Dec 2016 18:05:34 +0000 Subject: [PATCH] Fixed #27498 -- Fixed filtering on annotated DecimalField on SQLite. --- django/db/models/lookups.py | 39 ++++++++++++++++++++++++++++++- tests/lookup/models.py | 10 ++++++++ tests/lookup/test_decimalfield.py | 38 ++++++++++++++++++++++++++++++ 3 files changed, 86 insertions(+), 1 deletion(-) create mode 100644 tests/lookup/test_decimalfield.py diff --git a/django/db/models/lookups.py b/django/db/models/lookups.py index 78cea19037..5429204d63 100644 --- a/django/db/models/lookups.py +++ b/django/db/models/lookups.py @@ -2,10 +2,13 @@ import itertools import math import warnings from copy import copy +from decimal import Decimal from django.core.exceptions import EmptyResultSet from django.db.models.expressions import Func, Value -from django.db.models.fields import DateTimeField, Field, IntegerField +from django.db.models.fields import ( + DateTimeField, DecimalField, Field, IntegerField, +) from django.db.models.query_utils import RegisterLookupMixin from django.utils.deprecation import RemovedInDjango20Warning from django.utils.functional import cached_property @@ -306,6 +309,40 @@ class IntegerLessThan(IntegerFieldFloatRounding, LessThan): IntegerField.register_lookup(IntegerLessThan) +class DecimalComparisonLookup(object): + def as_sqlite(self, compiler, connection): + lhs_sql, params = self.process_lhs(compiler, connection) + rhs_sql, rhs_params = self.process_rhs(compiler, connection) + params.extend(rhs_params) + # For comparisons whose lhs is a DecimalField, cast rhs AS NUMERIC + # because the rhs will have been converted to a string by the + # rev_typecast_decimal() adapter. + if isinstance(self.rhs, Decimal): + rhs_sql = 'CAST(%s AS NUMERIC)' % rhs_sql + rhs_sql = self.get_rhs_op(connection, rhs_sql) + return '%s %s' % (lhs_sql, rhs_sql), params + + +@DecimalField.register_lookup +class DecimalGreaterThan(DecimalComparisonLookup, GreaterThan): + pass + + +@DecimalField.register_lookup +class DecimalGreaterThanOrEqual(DecimalComparisonLookup, GreaterThanOrEqual): + pass + + +@DecimalField.register_lookup +class DecimalLessThan(DecimalComparisonLookup, LessThan): + pass + + +@DecimalField.register_lookup +class DecimalLessThanOrEqual(DecimalComparisonLookup, LessThanOrEqual): + pass + + class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup): lookup_name = 'in' diff --git a/tests/lookup/models.py b/tests/lookup/models.py index 57389e90de..e8479f11d1 100644 --- a/tests/lookup/models.py +++ b/tests/lookup/models.py @@ -86,3 +86,13 @@ class MyISAMArticle(models.Model): class Meta: db_table = 'myisam_article' managed = False + + +class Product(models.Model): + name = models.CharField(max_length=80) + qty_target = models.DecimalField(max_digits=6, decimal_places=2) + + +class Stock(models.Model): + product = models.ForeignKey(Product, models.CASCADE) + qty_available = models.DecimalField(max_digits=6, decimal_places=2) diff --git a/tests/lookup/test_decimalfield.py b/tests/lookup/test_decimalfield.py new file mode 100644 index 0000000000..c6d17bce84 --- /dev/null +++ b/tests/lookup/test_decimalfield.py @@ -0,0 +1,38 @@ +from django.db.models.aggregates import Sum +from django.db.models.expressions import F +from django.test import TestCase + +from .models import Product, Stock + + +class DecimalFieldLookupTests(TestCase): + @classmethod + def setUpTestData(cls): + cls.p1 = Product.objects.create(name='Product1', qty_target=10) + Stock.objects.create(product=cls.p1, qty_available=5) + Stock.objects.create(product=cls.p1, qty_available=6) + cls.p2 = Product.objects.create(name='Product2', qty_target=10) + Stock.objects.create(product=cls.p2, qty_available=5) + Stock.objects.create(product=cls.p2, qty_available=5) + cls.p3 = Product.objects.create(name='Product3', qty_target=10) + Stock.objects.create(product=cls.p3, qty_available=5) + Stock.objects.create(product=cls.p3, qty_available=4) + cls.queryset = Product.objects.annotate( + qty_available_sum=Sum('stock__qty_available'), + ).annotate(qty_needed=F('qty_target') - F('qty_available_sum')) + + def test_gt(self): + qs = self.queryset.filter(qty_needed__gt=0) + self.assertCountEqual(qs, [self.p3]) + + def test_gte(self): + qs = self.queryset.filter(qty_needed__gte=0) + self.assertCountEqual(qs, [self.p2, self.p3]) + + def test_lt(self): + qs = self.queryset.filter(qty_needed__lt=0) + self.assertCountEqual(qs, [self.p1]) + + def test_lte(self): + qs = self.queryset.filter(qty_needed__lte=0) + self.assertCountEqual(qs, [self.p1, self.p2])