mirror of
https://github.com/django/django.git
synced 2024-12-22 17:16:24 +00:00
Fixed #27498 -- Fixed filtering on annotated DecimalField on SQLite.
This commit is contained in:
parent
96181080ba
commit
a4cac17200
@ -2,10 +2,13 @@ import itertools
|
||||
import math
|
||||
import warnings
|
||||
from copy import copy
|
||||
from decimal import Decimal
|
||||
|
||||
from django.core.exceptions import EmptyResultSet
|
||||
from django.db.models.expressions import Func, Value
|
||||
from django.db.models.fields import DateTimeField, Field, IntegerField
|
||||
from django.db.models.fields import (
|
||||
DateTimeField, DecimalField, Field, IntegerField,
|
||||
)
|
||||
from django.db.models.query_utils import RegisterLookupMixin
|
||||
from django.utils.deprecation import RemovedInDjango20Warning
|
||||
from django.utils.functional import cached_property
|
||||
@ -306,6 +309,40 @@ class IntegerLessThan(IntegerFieldFloatRounding, LessThan):
|
||||
IntegerField.register_lookup(IntegerLessThan)
|
||||
|
||||
|
||||
class DecimalComparisonLookup(object):
|
||||
def as_sqlite(self, compiler, connection):
|
||||
lhs_sql, params = self.process_lhs(compiler, connection)
|
||||
rhs_sql, rhs_params = self.process_rhs(compiler, connection)
|
||||
params.extend(rhs_params)
|
||||
# For comparisons whose lhs is a DecimalField, cast rhs AS NUMERIC
|
||||
# because the rhs will have been converted to a string by the
|
||||
# rev_typecast_decimal() adapter.
|
||||
if isinstance(self.rhs, Decimal):
|
||||
rhs_sql = 'CAST(%s AS NUMERIC)' % rhs_sql
|
||||
rhs_sql = self.get_rhs_op(connection, rhs_sql)
|
||||
return '%s %s' % (lhs_sql, rhs_sql), params
|
||||
|
||||
|
||||
@DecimalField.register_lookup
|
||||
class DecimalGreaterThan(DecimalComparisonLookup, GreaterThan):
|
||||
pass
|
||||
|
||||
|
||||
@DecimalField.register_lookup
|
||||
class DecimalGreaterThanOrEqual(DecimalComparisonLookup, GreaterThanOrEqual):
|
||||
pass
|
||||
|
||||
|
||||
@DecimalField.register_lookup
|
||||
class DecimalLessThan(DecimalComparisonLookup, LessThan):
|
||||
pass
|
||||
|
||||
|
||||
@DecimalField.register_lookup
|
||||
class DecimalLessThanOrEqual(DecimalComparisonLookup, LessThanOrEqual):
|
||||
pass
|
||||
|
||||
|
||||
class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
|
||||
lookup_name = 'in'
|
||||
|
||||
|
@ -86,3 +86,13 @@ class MyISAMArticle(models.Model):
|
||||
class Meta:
|
||||
db_table = 'myisam_article'
|
||||
managed = False
|
||||
|
||||
|
||||
class Product(models.Model):
|
||||
name = models.CharField(max_length=80)
|
||||
qty_target = models.DecimalField(max_digits=6, decimal_places=2)
|
||||
|
||||
|
||||
class Stock(models.Model):
|
||||
product = models.ForeignKey(Product, models.CASCADE)
|
||||
qty_available = models.DecimalField(max_digits=6, decimal_places=2)
|
||||
|
38
tests/lookup/test_decimalfield.py
Normal file
38
tests/lookup/test_decimalfield.py
Normal file
@ -0,0 +1,38 @@
|
||||
from django.db.models.aggregates import Sum
|
||||
from django.db.models.expressions import F
|
||||
from django.test import TestCase
|
||||
|
||||
from .models import Product, Stock
|
||||
|
||||
|
||||
class DecimalFieldLookupTests(TestCase):
|
||||
@classmethod
|
||||
def setUpTestData(cls):
|
||||
cls.p1 = Product.objects.create(name='Product1', qty_target=10)
|
||||
Stock.objects.create(product=cls.p1, qty_available=5)
|
||||
Stock.objects.create(product=cls.p1, qty_available=6)
|
||||
cls.p2 = Product.objects.create(name='Product2', qty_target=10)
|
||||
Stock.objects.create(product=cls.p2, qty_available=5)
|
||||
Stock.objects.create(product=cls.p2, qty_available=5)
|
||||
cls.p3 = Product.objects.create(name='Product3', qty_target=10)
|
||||
Stock.objects.create(product=cls.p3, qty_available=5)
|
||||
Stock.objects.create(product=cls.p3, qty_available=4)
|
||||
cls.queryset = Product.objects.annotate(
|
||||
qty_available_sum=Sum('stock__qty_available'),
|
||||
).annotate(qty_needed=F('qty_target') - F('qty_available_sum'))
|
||||
|
||||
def test_gt(self):
|
||||
qs = self.queryset.filter(qty_needed__gt=0)
|
||||
self.assertCountEqual(qs, [self.p3])
|
||||
|
||||
def test_gte(self):
|
||||
qs = self.queryset.filter(qty_needed__gte=0)
|
||||
self.assertCountEqual(qs, [self.p2, self.p3])
|
||||
|
||||
def test_lt(self):
|
||||
qs = self.queryset.filter(qty_needed__lt=0)
|
||||
self.assertCountEqual(qs, [self.p1])
|
||||
|
||||
def test_lte(self):
|
||||
qs = self.queryset.filter(qty_needed__lte=0)
|
||||
self.assertCountEqual(qs, [self.p1, self.p2])
|
Loading…
Reference in New Issue
Block a user