1
0
mirror of https://github.com/django/django.git synced 2025-10-31 09:41:08 +00:00

Added documentation, polished implementation

This commit is contained in:
Anssi Kääriäinen
2013-12-01 02:14:19 +02:00
parent 32c04357a8
commit 2adf50428d
5 changed files with 322 additions and 23 deletions

View File

@@ -20,16 +20,13 @@ class Div3Lookup(models.lookups.Lookup):
class Div3Extract(models.lookups.Extract):
lookup_name = 'div3'
def as_sql(self, qn, connection):
lhs, lhs_params = qn.compile(self.lhs)
return '%s %%%% 3' % (lhs,), lhs_params
class Div3LookupWithExtract(Div3Lookup):
lookup_name = 'div3'
extract_class = Div3Extract
class YearLte(models.lookups.LessThanOrEqual):
"""
The purpose of this lookup is to efficiently compare the year of the field.
@@ -50,6 +47,8 @@ class YearLte(models.lookups.LessThanOrEqual):
class YearExtract(models.lookups.Extract):
lookup_name = 'year'
def as_sql(self, qn, connection):
lhs_sql, params = qn.compile(self.lhs)
return connection.ops.date_extract_sql('year', lhs_sql), params
@@ -61,12 +60,44 @@ class YearExtract(models.lookups.Extract):
def get_lookup(self, lookup):
if lookup == 'lte':
return YearLte
elif lookup == 'exact':
return YearExact
else:
return super(YearExtract, self).get_lookup(lookup)
class YearWithExtract(models.lookups.Year):
extract_class = YearExtract
class YearExact(models.lookups.Lookup):
def as_sql(self, qn, connection):
# We will need to skip the extract part, and instead go
# directly with the originating field, that is self.lhs.lhs
lhs_sql, lhs_params = self.process_lhs(qn, connection, self.lhs.lhs)
rhs_sql, rhs_params = self.process_rhs(qn, connection)
# Note that we must be careful so that we have params in the
# same order as we have the parts in the SQL.
params = []
params.extend(lhs_params)
params.extend(rhs_params)
params.extend(lhs_params)
params.extend(rhs_params)
# We use PostgreSQL specific SQL here. Note that we must do the
# conversions in SQL instead of in Python to support F() references.
return ("%(lhs)s >= (%(rhs)s || '-01-01')::date "
"AND %(lhs)s <= (%(rhs)s || '-12-31')::date" %
{'lhs': lhs_sql, 'rhs': rhs_sql}, params)
@add_implementation(YearExact, 'mysql')
def mysql_year_exact(node, qn, connection):
lhs_sql, lhs_params = node.process_lhs(qn, connection, node.lhs.lhs)
rhs_sql, rhs_params = node.process_rhs(qn, connection)
params = []
params.extend(lhs_params)
params.extend(rhs_params)
params.extend(lhs_params)
params.extend(rhs_params)
return ("%(lhs)s >= str_to_date(concat(%(rhs)s, '-01-01'), '%%%%Y-%%%%m-%%%%d') "
"AND %(lhs)s <= str_to_date(concat(%(rhs)s, '-12-31'), '%%%%Y-%%%%m-%%%%d')" %
{'lhs': lhs_sql, 'rhs': rhs_sql}, params)
class InMonth(models.lookups.Lookup):
@@ -158,7 +189,7 @@ class LookupTests(TestCase):
models.Field.register_lookup(AnotherEqual)
try:
@add_implementation(AnotherEqual, connection.vendor)
def custom_eq_sql(node, compiler):
def custom_eq_sql(node, qn, connection):
return '1 = 1', []
self.assertIn('1 = 1', str(Author.objects.filter(name__anotherequal='asdf').query))
@@ -167,7 +198,7 @@ class LookupTests(TestCase):
[a1, a2, a3, a4], lambda x: x)
@add_implementation(AnotherEqual, connection.vendor)
def another_custom_eq_sql(node, compiler):
def another_custom_eq_sql(node, qn, connection):
# If you need to override one method, it seems this is the best
# option.
node = copy(node)
@@ -176,7 +207,7 @@ class LookupTests(TestCase):
def get_rhs_op(self, connection, rhs):
return ' <> %s'
node.__class__ = OverriddenAnotherEqual
return node.as_sql(compiler, compiler.connection)
return node.as_sql(qn, connection)
self.assertIn(' <> ', str(Author.objects.filter(name__anotherequal='a1').query))
self.assertQuerysetEqual(
Author.objects.filter(name__anotherequal='a1').order_by('name'),
@@ -186,13 +217,16 @@ class LookupTests(TestCase):
models.Field._unregister_lookup(AnotherEqual)
def test_div3_extract(self):
models.IntegerField.register_lookup(Div3LookupWithExtract)
models.IntegerField.register_lookup(Div3Extract)
try:
a1 = Author.objects.create(name='a1', age=1)
a2 = Author.objects.create(name='a2', age=2)
a3 = Author.objects.create(name='a3', age=3)
a4 = Author.objects.create(name='a4', age=4)
baseqs = Author.objects.order_by('name')
self.assertQuerysetEqual(
baseqs.filter(age__div3=2),
[a2], lambda x: x)
self.assertQuerysetEqual(
baseqs.filter(age__div3__lte=3),
[a1, a2, a3, a4], lambda x: x)
@@ -200,19 +234,19 @@ class LookupTests(TestCase):
baseqs.filter(age__div3__in=[0, 2]),
[a2, a3], lambda x: x)
finally:
models.IntegerField._unregister_lookup(Div3LookupWithExtract)
models.IntegerField._unregister_lookup(Div3Extract)
class YearLteTests(TestCase):
def setUp(self):
models.DateField.register_lookup(YearWithExtract)
models.DateField.register_lookup(YearExtract)
self.a1 = Author.objects.create(name='a1', birthdate=date(1981, 2, 16))
self.a2 = Author.objects.create(name='a2', birthdate=date(2012, 2, 29))
self.a3 = Author.objects.create(name='a3', birthdate=date(2012, 1, 31))
self.a4 = Author.objects.create(name='a4', birthdate=date(2012, 3, 1))
def tearDown(self):
models.DateField._unregister_lookup(YearWithExtract)
models.DateField._unregister_lookup(YearExtract)
@unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific SQL used")
def test_year_lte(self):
@@ -220,6 +254,11 @@ class YearLteTests(TestCase):
self.assertQuerysetEqual(
baseqs.filter(birthdate__year__lte=2012),
[self.a1, self.a2, self.a3, self.a4], lambda x: x)
self.assertQuerysetEqual(
baseqs.filter(birthdate__year=2012),
[self.a2, self.a3, self.a4], lambda x: x)
self.assertNotIn('BETWEEN', str(baseqs.filter(birthdate__year=2012).query))
self.assertQuerysetEqual(
baseqs.filter(birthdate__year__lte=2011),
[self.a1], lambda x: x)
@@ -253,3 +292,12 @@ class YearLteTests(TestCase):
'<= (2011 || ', str(baseqs.filter(birthdate__year__lte=2011).query))
self.assertIn(
'-12-31', str(baseqs.filter(birthdate__year__lte=2011).query))
@unittest.skipUnless(connection.vendor == 'mysql', 'MySQL specific SQL used')
def test_mysql_year_exact(self):
self.assertQuerysetEqual(
Author.objects.filter(birthdate__year=2012).order_by('name'),
[self.a2, self.a3, self.a4], lambda x: x)
self.assertIn(
'concat(',
str(Author.objects.filter(birthdate__year=2012).query))