mirror of
https://github.com/django/django.git
synced 2025-10-31 09:41:08 +00:00
Added documentation, polished implementation
This commit is contained in:
@@ -20,16 +20,13 @@ class Div3Lookup(models.lookups.Lookup):
|
||||
|
||||
|
||||
class Div3Extract(models.lookups.Extract):
|
||||
lookup_name = 'div3'
|
||||
|
||||
def as_sql(self, qn, connection):
|
||||
lhs, lhs_params = qn.compile(self.lhs)
|
||||
return '%s %%%% 3' % (lhs,), lhs_params
|
||||
|
||||
|
||||
class Div3LookupWithExtract(Div3Lookup):
|
||||
lookup_name = 'div3'
|
||||
extract_class = Div3Extract
|
||||
|
||||
|
||||
class YearLte(models.lookups.LessThanOrEqual):
|
||||
"""
|
||||
The purpose of this lookup is to efficiently compare the year of the field.
|
||||
@@ -50,6 +47,8 @@ class YearLte(models.lookups.LessThanOrEqual):
|
||||
|
||||
|
||||
class YearExtract(models.lookups.Extract):
|
||||
lookup_name = 'year'
|
||||
|
||||
def as_sql(self, qn, connection):
|
||||
lhs_sql, params = qn.compile(self.lhs)
|
||||
return connection.ops.date_extract_sql('year', lhs_sql), params
|
||||
@@ -61,12 +60,44 @@ class YearExtract(models.lookups.Extract):
|
||||
def get_lookup(self, lookup):
|
||||
if lookup == 'lte':
|
||||
return YearLte
|
||||
elif lookup == 'exact':
|
||||
return YearExact
|
||||
else:
|
||||
return super(YearExtract, self).get_lookup(lookup)
|
||||
|
||||
|
||||
class YearWithExtract(models.lookups.Year):
|
||||
extract_class = YearExtract
|
||||
class YearExact(models.lookups.Lookup):
|
||||
def as_sql(self, qn, connection):
|
||||
# We will need to skip the extract part, and instead go
|
||||
# directly with the originating field, that is self.lhs.lhs
|
||||
lhs_sql, lhs_params = self.process_lhs(qn, connection, self.lhs.lhs)
|
||||
rhs_sql, rhs_params = self.process_rhs(qn, connection)
|
||||
# Note that we must be careful so that we have params in the
|
||||
# same order as we have the parts in the SQL.
|
||||
params = []
|
||||
params.extend(lhs_params)
|
||||
params.extend(rhs_params)
|
||||
params.extend(lhs_params)
|
||||
params.extend(rhs_params)
|
||||
# We use PostgreSQL specific SQL here. Note that we must do the
|
||||
# conversions in SQL instead of in Python to support F() references.
|
||||
return ("%(lhs)s >= (%(rhs)s || '-01-01')::date "
|
||||
"AND %(lhs)s <= (%(rhs)s || '-12-31')::date" %
|
||||
{'lhs': lhs_sql, 'rhs': rhs_sql}, params)
|
||||
|
||||
|
||||
@add_implementation(YearExact, 'mysql')
|
||||
def mysql_year_exact(node, qn, connection):
|
||||
lhs_sql, lhs_params = node.process_lhs(qn, connection, node.lhs.lhs)
|
||||
rhs_sql, rhs_params = node.process_rhs(qn, connection)
|
||||
params = []
|
||||
params.extend(lhs_params)
|
||||
params.extend(rhs_params)
|
||||
params.extend(lhs_params)
|
||||
params.extend(rhs_params)
|
||||
return ("%(lhs)s >= str_to_date(concat(%(rhs)s, '-01-01'), '%%%%Y-%%%%m-%%%%d') "
|
||||
"AND %(lhs)s <= str_to_date(concat(%(rhs)s, '-12-31'), '%%%%Y-%%%%m-%%%%d')" %
|
||||
{'lhs': lhs_sql, 'rhs': rhs_sql}, params)
|
||||
|
||||
|
||||
class InMonth(models.lookups.Lookup):
|
||||
@@ -158,7 +189,7 @@ class LookupTests(TestCase):
|
||||
models.Field.register_lookup(AnotherEqual)
|
||||
try:
|
||||
@add_implementation(AnotherEqual, connection.vendor)
|
||||
def custom_eq_sql(node, compiler):
|
||||
def custom_eq_sql(node, qn, connection):
|
||||
return '1 = 1', []
|
||||
|
||||
self.assertIn('1 = 1', str(Author.objects.filter(name__anotherequal='asdf').query))
|
||||
@@ -167,7 +198,7 @@ class LookupTests(TestCase):
|
||||
[a1, a2, a3, a4], lambda x: x)
|
||||
|
||||
@add_implementation(AnotherEqual, connection.vendor)
|
||||
def another_custom_eq_sql(node, compiler):
|
||||
def another_custom_eq_sql(node, qn, connection):
|
||||
# If you need to override one method, it seems this is the best
|
||||
# option.
|
||||
node = copy(node)
|
||||
@@ -176,7 +207,7 @@ class LookupTests(TestCase):
|
||||
def get_rhs_op(self, connection, rhs):
|
||||
return ' <> %s'
|
||||
node.__class__ = OverriddenAnotherEqual
|
||||
return node.as_sql(compiler, compiler.connection)
|
||||
return node.as_sql(qn, connection)
|
||||
self.assertIn(' <> ', str(Author.objects.filter(name__anotherequal='a1').query))
|
||||
self.assertQuerysetEqual(
|
||||
Author.objects.filter(name__anotherequal='a1').order_by('name'),
|
||||
@@ -186,13 +217,16 @@ class LookupTests(TestCase):
|
||||
models.Field._unregister_lookup(AnotherEqual)
|
||||
|
||||
def test_div3_extract(self):
|
||||
models.IntegerField.register_lookup(Div3LookupWithExtract)
|
||||
models.IntegerField.register_lookup(Div3Extract)
|
||||
try:
|
||||
a1 = Author.objects.create(name='a1', age=1)
|
||||
a2 = Author.objects.create(name='a2', age=2)
|
||||
a3 = Author.objects.create(name='a3', age=3)
|
||||
a4 = Author.objects.create(name='a4', age=4)
|
||||
baseqs = Author.objects.order_by('name')
|
||||
self.assertQuerysetEqual(
|
||||
baseqs.filter(age__div3=2),
|
||||
[a2], lambda x: x)
|
||||
self.assertQuerysetEqual(
|
||||
baseqs.filter(age__div3__lte=3),
|
||||
[a1, a2, a3, a4], lambda x: x)
|
||||
@@ -200,19 +234,19 @@ class LookupTests(TestCase):
|
||||
baseqs.filter(age__div3__in=[0, 2]),
|
||||
[a2, a3], lambda x: x)
|
||||
finally:
|
||||
models.IntegerField._unregister_lookup(Div3LookupWithExtract)
|
||||
models.IntegerField._unregister_lookup(Div3Extract)
|
||||
|
||||
|
||||
class YearLteTests(TestCase):
|
||||
def setUp(self):
|
||||
models.DateField.register_lookup(YearWithExtract)
|
||||
models.DateField.register_lookup(YearExtract)
|
||||
self.a1 = Author.objects.create(name='a1', birthdate=date(1981, 2, 16))
|
||||
self.a2 = Author.objects.create(name='a2', birthdate=date(2012, 2, 29))
|
||||
self.a3 = Author.objects.create(name='a3', birthdate=date(2012, 1, 31))
|
||||
self.a4 = Author.objects.create(name='a4', birthdate=date(2012, 3, 1))
|
||||
|
||||
def tearDown(self):
|
||||
models.DateField._unregister_lookup(YearWithExtract)
|
||||
models.DateField._unregister_lookup(YearExtract)
|
||||
|
||||
@unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific SQL used")
|
||||
def test_year_lte(self):
|
||||
@@ -220,6 +254,11 @@ class YearLteTests(TestCase):
|
||||
self.assertQuerysetEqual(
|
||||
baseqs.filter(birthdate__year__lte=2012),
|
||||
[self.a1, self.a2, self.a3, self.a4], lambda x: x)
|
||||
self.assertQuerysetEqual(
|
||||
baseqs.filter(birthdate__year=2012),
|
||||
[self.a2, self.a3, self.a4], lambda x: x)
|
||||
|
||||
self.assertNotIn('BETWEEN', str(baseqs.filter(birthdate__year=2012).query))
|
||||
self.assertQuerysetEqual(
|
||||
baseqs.filter(birthdate__year__lte=2011),
|
||||
[self.a1], lambda x: x)
|
||||
@@ -253,3 +292,12 @@ class YearLteTests(TestCase):
|
||||
'<= (2011 || ', str(baseqs.filter(birthdate__year__lte=2011).query))
|
||||
self.assertIn(
|
||||
'-12-31', str(baseqs.filter(birthdate__year__lte=2011).query))
|
||||
|
||||
@unittest.skipUnless(connection.vendor == 'mysql', 'MySQL specific SQL used')
|
||||
def test_mysql_year_exact(self):
|
||||
self.assertQuerysetEqual(
|
||||
Author.objects.filter(birthdate__year=2012).order_by('name'),
|
||||
[self.a2, self.a3, self.a4], lambda x: x)
|
||||
self.assertIn(
|
||||
'concat(',
|
||||
str(Author.objects.filter(birthdate__year=2012).query))
|
||||
|
||||
Reference in New Issue
Block a user