mirror of
				https://github.com/django/django.git
				synced 2025-10-25 06:36:07 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			650 lines
		
	
	
		
			25 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			650 lines
		
	
	
		
			25 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import time
 | |
| import unittest
 | |
| from datetime import date, datetime
 | |
| 
 | |
| from django.core.exceptions import FieldError
 | |
| from django.db import connection, models
 | |
| from django.test import SimpleTestCase, TestCase, override_settings
 | |
| from django.test.utils import register_lookup
 | |
| from django.utils import timezone
 | |
| 
 | |
| from .models import Article, Author, MySQLUnixTimestamp
 | |
| 
 | |
| 
 | |
class Div3Lookup(models.Lookup):
    """
    Lookup matching rows whose left-hand side modulo 3 equals the right-hand
    side, e.g. ``age__div3=1``.
    """

    lookup_name = "div3"

    def _compile_sides(self, compiler, connection):
        """Return (lhs_sql, rhs_sql, params) with params in lhs-then-rhs order."""
        lhs_sql, combined = self.process_lhs(compiler, connection)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        combined.extend(rhs_params)
        return lhs_sql, rhs_sql, combined

    def as_sql(self, compiler, connection):
        lhs_sql, rhs_sql, combined = self._compile_sides(compiler, connection)
        # "%%%%" becomes "%%" after this formatting step; the surviving escape
        # is left for the later parameter-substitution pass.
        return "(%s) %%%% 3 = %s" % (lhs_sql, rhs_sql), combined

    def as_oracle(self, compiler, connection):
        # Oracle variant: spell the modulo with mod() instead of "%".
        lhs_sql, rhs_sql, combined = self._compile_sides(compiler, connection)
        return "mod(%s, 3) = %s" % (lhs_sql, rhs_sql), combined
 | |
| 
 | |
| 
 | |
class Div3Transform(models.Transform):
    """Transform yielding the value of its input expression modulo 3."""

    lookup_name = "div3"

    def as_sql(self, compiler, connection):
        inner_sql, inner_params = compiler.compile(self.lhs)
        # Quadruple "%" so a literal modulo operator survives formatting.
        template = "(%s) %%%% 3"
        return template % inner_sql, inner_params

    def as_oracle(self, compiler, connection, **extra_context):
        # Oracle variant: use mod() instead of the "%" operator.
        inner_sql, inner_params = compiler.compile(self.lhs)
        template = "mod(%s, 3)"
        return template % inner_sql, inner_params
 | |
| 
 | |
| 
 | |
class Div3BilateralTransform(Div3Transform):
    """
    Div3Transform that is also applied to the right-hand side of a lookup,
    so ``age__div3__in=[2, 4]`` compares ``age % 3`` against ``[2 % 3, 4 % 3]``.
    """

    bilateral = True
 | |
| 
 | |
| 
 | |
class Mult3BilateralTransform(models.Transform):
    """Bilateral transform that multiplies both sides of the lookup by 3."""

    lookup_name = "mult3"
    bilateral = True

    def as_sql(self, compiler, connection):
        operand_sql, operand_params = compiler.compile(self.lhs)
        sql = "3 * (%s)" % operand_sql
        return sql, operand_params
 | |
| 
 | |
| 
 | |
class LastDigitTransform(models.Transform):
    """
    Transform returning the second character of the value cast to CHAR(2),
    i.e. the last digit of a two-digit number.
    """

    lookup_name = "lastdigit"

    def as_sql(self, compiler, connection):
        operand_sql, operand_params = compiler.compile(self.lhs)
        sql = "SUBSTR(CAST(%s AS CHAR(2)), 2, 1)" % operand_sql
        return sql, operand_params
 | |
| 
 | |
| 
 | |
class UpperBilateralTransform(models.Transform):
    """
    Bilateral transform that upper-cases both sides of the lookup, making the
    comparison case-insensitive.
    """

    lookup_name = "upper"
    bilateral = True

    def as_sql(self, compiler, connection):
        operand_sql, operand_params = compiler.compile(self.lhs)
        sql = "UPPER(%s)" % operand_sql
        return sql, operand_params
 | |
| 
 | |
| 
 | |
class YearTransform(models.Transform):
    """Extract the year from a date expression as an integer."""

    # Use a name that avoids collision with the built-in year lookup.
    lookup_name = "testyear"

    @property
    def output_field(self):
        return models.IntegerField()

    def as_sql(self, compiler, connection):
        date_sql, date_params = compiler.compile(self.lhs)
        # The backend builds the extraction and returns the (sql, params) pair.
        return connection.ops.date_extract_sql("year", date_sql, date_params)
 | |
| 
 | |
| 
 | |
@YearTransform.register_lookup
class YearExact(models.lookups.Lookup):
    """
    ``testyear__exact`` lookup that compares the underlying date column with
    the first and last day of the requested year, bypassing the extraction.
    Uses PostgreSQL-specific SQL.
    """

    lookup_name = "exact"

    def as_sql(self, compiler, connection):
        # Compile self.lhs.lhs (the expression beneath YearTransform) instead
        # of the transform itself, so no year extraction is emitted.
        date_sql, date_params = self.process_lhs(compiler, connection, self.lhs.lhs)
        year_sql, year_params = self.process_rhs(compiler, connection)
        # The date literals are assembled in SQL rather than in Python so that
        # F() references work as the right-hand side.
        template = (
            "%(lhs)s >= (%(rhs)s || '-01-01')::date "
            "AND %(lhs)s <= (%(rhs)s || '-12-31')::date"
        )
        # Placeholders appear in lhs, rhs, lhs, rhs order, so the params must
        # repeat in exactly that order.
        return (
            template % {"lhs": date_sql, "rhs": year_sql},
            (date_params + year_params) * 2,
        )
 | |
| 
 | |
| 
 | |
@YearTransform.register_lookup
class YearLte(models.lookups.LessThanOrEqual):
    """
    Efficient ``testyear__lte`` lookup: compare the date column directly with
    December 31st of the given year instead of extracting the year.
    """

    def as_sql(self, compiler, connection):
        # Skip the YearTransform above us and work on the expression beneath
        # it; there is no efficient way to do the lookup otherwise.
        date_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)
        year_sql, year_params = self.process_rhs(compiler, connection)
        params.extend(year_params)
        # Aim for SQL like "WHERE somecol <= '2013-12-31'", but build the date
        # in SQL so the right-hand side may also be a field reference.
        sql = "%s <= (%s || '-12-31')::date" % (date_sql, year_sql)
        return sql, params
 | |
| 
 | |
| 
 | |
class Exactly(models.lookups.Exact):
    """
    Exact-match lookup registered under a custom name; used to test lookup
    registration.
    """

    lookup_name = "exactly"

    def get_rhs_op(self, connection, rhs):
        # Reuse the backend's operator template for the built-in "exact".
        exact_operator = connection.operators["exact"]
        return exact_operator % rhs
 | |
| 
 | |
| 
 | |
class SQLFuncMixin:
    """
    Mixin rendering the expression as a zero-argument SQL function call named
    after ``self.name``, with a CustomField as its output field.
    """

    @property
    def output_field(self):
        return CustomField()

    def as_sql(self, compiler, connection):
        call_sql = "%s()" % self.name
        return call_sql, []
 | |
| 
 | |
| 
 | |
class SQLFuncLookup(SQLFuncMixin, models.Lookup):
    """Lookup rendered as ``<name>()``; *name* is taken from the lookup suffix."""

    def __init__(self, name, *args, **kwargs):
        # Remaining arguments are the regular Lookup constructor arguments.
        super().__init__(*args, **kwargs)
        self.name = name
 | |
| 
 | |
| 
 | |
class SQLFuncTransform(SQLFuncMixin, models.Transform):
    """Transform rendered as ``<name>()``; *name* is taken from the suffix."""

    def __init__(self, name, *args, **kwargs):
        # Remaining arguments are the regular Transform constructor arguments.
        super().__init__(*args, **kwargs)
        self.name = name
 | |
| 
 | |
| 
 | |
class SQLFuncFactory:
    """
    Callable standing in for a lookup/transform class. When invoked with the
    usual constructor arguments it builds an SQLFuncLookup (for key
    "lookupfunc") or an SQLFuncTransform (for any other key) named *name*.
    """

    def __init__(self, key, name):
        self.key, self.name = key, name

    def __call__(self, *args, **kwargs):
        func_class = SQLFuncLookup if self.key == "lookupfunc" else SQLFuncTransform
        return func_class(self.name, *args, **kwargs)
 | |
| 
 | |
| 
 | |
class CustomField(models.TextField):
    """
    TextField that synthesizes lookups named ``lookupfunc_<name>`` and
    transforms named ``transformfunc_<name>`` on demand; all other names fall
    back to the normal registry.
    """

    def get_lookup(self, lookup_name):
        if not lookup_name.startswith("lookupfunc_"):
            return super().get_lookup(lookup_name)
        # The prefix check guarantees an underscore to split on.
        key, _, func_name = lookup_name.partition("_")
        return SQLFuncFactory(key, func_name)

    def get_transform(self, lookup_name):
        if not lookup_name.startswith("transformfunc_"):
            return super().get_transform(lookup_name)
        key, _, func_name = lookup_name.partition("_")
        return SQLFuncFactory(key, func_name)
 | |
| 
 | |
| 
 | |
class CustomModel(models.Model):
    """Model whose only column uses CustomField's dynamic lookups/transforms."""

    field = CustomField()
 | |
| 
 | |
| 
 | |
# InMonth (below) is registered only temporarily, inside the test method that uses it.
 | |
| 
 | |
| 
 | |
class InMonth(models.lookups.Lookup):
    """
    InMonth matches if the column's month is the same as value's month.
    Uses PostgreSQL date_trunc()/interval syntax.
    """

    lookup_name = "inmonth"

    def as_sql(self, compiler, connection):
        column_sql, column_params = self.process_lhs(compiler, connection)
        value_sql, value_params = self.process_rhs(compiler, connection)
        sql = (
            "%s >= date_trunc('month', %s) and "
            "%s < date_trunc('month', %s) + interval '1 months'"
            % (column_sql, value_sql, column_sql, value_sql)
        )
        # Placeholders occur as column, value, column, value; the params must
        # follow that same order.
        return sql, (column_params + value_params) * 2
 | |
| 
 | |
| 
 | |
class DateTimeTransform(models.Transform):
    """
    Convert a Unix-timestamp column into a datetime with from_unixtime()
    (MySQL-specific SQL), exposing a DateTimeField output.
    """

    lookup_name = "as_datetime"

    def as_sql(self, compiler, connection):
        timestamp_sql, timestamp_params = compiler.compile(self.lhs)
        return "from_unixtime({})".format(timestamp_sql), timestamp_params

    @property
    def output_field(self):
        return models.DateTimeField()
 | |
| 
 | |
| 
 | |
class LookupTests(TestCase):
    """Registration and use of custom lookups and transforms on model fields."""

    def test_custom_name_lookup(self):
        born_1981 = Author.objects.create(name="a1", birthdate=date(1981, 2, 16))
        Author.objects.create(name="a2", birthdate=date(2012, 2, 29))
        # Register the transform and the lookup both under their own
        # lookup_name and under a custom alias.
        with register_lookup(models.DateField, YearTransform), register_lookup(
            models.DateField, YearTransform, lookup_name="justtheyear"
        ), register_lookup(YearTransform, Exactly), register_lookup(
            YearTransform, Exactly, lookup_name="isactually"
        ):
            via_default_names = Author.objects.filter(birthdate__testyear__exactly=1981)
            via_aliases = Author.objects.filter(birthdate__justtheyear__isactually=1981)
            for queryset in (via_default_names, via_aliases):
                self.assertSequenceEqual(queryset, [born_1981])

    def test_custom_exact_lookup_none_rhs(self):
        """
        __exact=None is transformed to __isnull=True if a custom lookup class
        with lookup_name != 'exact' is registered as the `exact` lookup.
        """
        birthdate_field = Author._meta.get_field("birthdate")
        previous_exact = birthdate_field.get_lookup("exact")
        undated_author = Author.objects.create(name="author", birthdate=None)
        try:
            birthdate_field.register_lookup(Exactly, "exact")
            self.assertEqual(Author.objects.get(birthdate__exact=None), undated_author)
        finally:
            # Put back whatever "exact" resolved to before the test.
            birthdate_field.register_lookup(previous_exact, "exact")

    def test_basic_lookup(self):
        a1, a2, a3, a4 = [
            Author.objects.create(name="a%d" % age, age=age) for age in range(1, 5)
        ]
        with register_lookup(models.IntegerField, Div3Lookup):
            self.assertSequenceEqual(Author.objects.filter(age__div3=0), [a3])
            remainder_one = Author.objects.filter(age__div3=1).order_by("age")
            self.assertSequenceEqual(remainder_one, [a1, a4])
            self.assertSequenceEqual(Author.objects.filter(age__div3=2), [a2])
            # No value modulo 3 can equal 3.
            self.assertSequenceEqual(Author.objects.filter(age__div3=3), [])

    @unittest.skipUnless(
        connection.vendor == "postgresql", "PostgreSQL specific SQL used"
    )
    def test_birthdate_month(self):
        a1 = Author.objects.create(name="a1", birthdate=date(1981, 2, 16))
        a2 = Author.objects.create(name="a2", birthdate=date(2012, 2, 29))
        a3 = Author.objects.create(name="a3", birthdate=date(2012, 1, 31))
        a4 = Author.objects.create(name="a4", birthdate=date(2012, 3, 1))
        # Each probe date must match exactly the authors born in its month.
        expectations = [
            (date(2012, 1, 15), [a3]),
            (date(2012, 2, 1), [a2]),
            (date(1981, 2, 28), [a1]),
            (date(2012, 3, 12), [a4]),
            (date(2012, 4, 1), []),
        ]
        with register_lookup(models.DateField, InMonth):
            for probe, expected in expectations:
                self.assertSequenceEqual(
                    Author.objects.filter(birthdate__inmonth=probe), expected
                )

    def test_div3_extract(self):
        with register_lookup(models.IntegerField, Div3Transform):
            a1, a2, a3, a4 = [
                Author.objects.create(name="a%d" % age, age=age) for age in range(1, 5)
            ]
            by_name = Author.objects.order_by("name")
            # Ages modulo 3 are 1, 2, 0, 1; the right-hand side is untouched.
            self.assertSequenceEqual(by_name.filter(age__div3=2), [a2])
            self.assertSequenceEqual(by_name.filter(age__div3__lte=3), [a1, a2, a3, a4])
            self.assertSequenceEqual(by_name.filter(age__div3__in=[0, 2]), [a2, a3])
            self.assertSequenceEqual(by_name.filter(age__div3__in=[2, 4]), [a2])
            self.assertSequenceEqual(by_name.filter(age__div3__gte=3), [])
            self.assertSequenceEqual(
                by_name.filter(age__div3__range=(1, 2)), [a1, a2, a4]
            )

    def test_foreignobject_lookup_registration(self):
        author_fk = Article._meta.get_field("author")

        with register_lookup(models.ForeignObject, Exactly):
            self.assertIs(author_fk.get_lookup("exactly"), Exactly)

        # ForeignObject should ignore regular Field lookups
        with register_lookup(models.Field, Exactly):
            self.assertIsNone(author_fk.get_lookup("exactly"))

    def test_lookups_caching(self):
        author_fk = Article._meta.get_field("author")

        # Empty the cache; the following call repopulates it.
        author_fk.get_lookups.cache_clear()
        self.assertNotIn("exactly", author_fk.get_lookups())

        # Registering must invalidate the cached lookups...
        with register_lookup(models.ForeignObject, Exactly):
            self.assertIn("exactly", author_fk.get_lookups())
        # ...and so must unregistering.
        self.assertNotIn("exactly", author_fk.get_lookups())
 | |
| 
 | |
| 
 | |
class BilateralTransformTests(TestCase):
    """Bilateral transforms are applied to both sides of a lookup."""

    def test_bilateral_upper(self):
        with register_lookup(models.CharField, UpperBilateralTransform):
            doe_capitalized = Author.objects.create(name="Doe")
            doe_lowercase = Author.objects.create(name="doe")
            foo = Author.objects.create(name="Foo")
            # "doe" is upper-cased as well, so both spellings match.
            self.assertCountEqual(
                Author.objects.filter(name__upper="doe"),
                [doe_capitalized, doe_lowercase],
            )
            self.assertSequenceEqual(
                Author.objects.filter(name__upper__contains="f"),
                [foo],
            )

    def test_bilateral_inner_qs(self):
        msg = "Bilateral transformations on nested querysets are not implemented."
        with register_lookup(models.CharField, UpperBilateralTransform):
            inner_names = Author.objects.values_list("name")
            with self.assertRaisesMessage(NotImplementedError, msg):
                Author.objects.filter(name__upper__in=inner_names)

    def test_bilateral_multi_value(self):
        with register_lookup(models.CharField, UpperBilateralTransform):
            Author.objects.bulk_create(
                [Author(name=name) for name in ("Foo", "Bar", "Ray")]
            )
            # Every element of the list is upper-cased before comparing.
            matches = Author.objects.filter(
                name__upper__in=["foo", "bar", "doe"]
            ).order_by("name")
            self.assertQuerysetEqual(
                matches, ["Bar", "Foo"], lambda author: author.name
            )

    def test_div3_bilateral_extract(self):
        with register_lookup(models.IntegerField, Div3BilateralTransform):
            a1, a2, a3, a4 = [
                Author.objects.create(name="a%d" % age, age=age) for age in range(1, 5)
            ]
            by_name = Author.objects.order_by("name")
            # Both sides are reduced modulo 3 (e.g. 3 -> 0, 4 -> 1).
            self.assertSequenceEqual(by_name.filter(age__div3=2), [a2])
            self.assertSequenceEqual(by_name.filter(age__div3__lte=3), [a3])
            self.assertSequenceEqual(by_name.filter(age__div3__in=[0, 2]), [a2, a3])
            self.assertSequenceEqual(by_name.filter(age__div3__in=[2, 4]), [a1, a2, a4])
            self.assertSequenceEqual(by_name.filter(age__div3__gte=3), [a1, a2, a3, a4])
            self.assertSequenceEqual(
                by_name.filter(age__div3__range=(1, 2)), [a1, a2, a4]
            )

    def test_bilateral_order(self):
        with register_lookup(
            models.IntegerField, Mult3BilateralTransform, Div3BilateralTransform
        ):
            a1, a2, a3, a4 = [
                Author.objects.create(name="a%d" % age, age=age) for age in range(1, 5)
            ]
            by_name = Author.objects.order_by("name")

            # mult3__div3 always leads to 0
            self.assertSequenceEqual(
                by_name.filter(age__mult3__div3=42), [a1, a2, a3, a4]
            )
            # div3__mult3 gives 3 * (x % 3); 42 maps to 0, so only age 3 matches.
            self.assertSequenceEqual(by_name.filter(age__div3__mult3=42), [a3])

    def test_transform_order_by(self):
        with register_lookup(models.IntegerField, LastDigitTransform):
            a1, a2, a3, a4 = [
                Author.objects.create(name="a%d" % position, age=age)
                for position, age in enumerate((11, 23, 32, 40), start=1)
            ]
            # Last digits are 1, 3, 2, 0.
            by_last_digit = Author.objects.order_by("age__lastdigit")
            self.assertSequenceEqual(by_last_digit, [a4, a1, a3, a2])

    def test_bilateral_fexpr(self):
        with register_lookup(models.IntegerField, Mult3BilateralTransform):
            a1 = Author.objects.create(name="a1", age=1, average_rating=3.2)
            a2 = Author.objects.create(name="a2", age=2, average_rating=0.5)
            a3 = Author.objects.create(name="a3", age=3, average_rating=1.5)
            a4 = Author.objects.create(name="a4", age=4)
            by_name = Author.objects.order_by("name")
            self.assertSequenceEqual(
                by_name.filter(age__mult3=models.F("age")), [a1, a2, a3, a4]
            )
            # Same as age >= average_rating
            self.assertSequenceEqual(
                by_name.filter(age__mult3__gte=models.F("average_rating")), [a2, a3]
            )
 | |
| 
 | |
| 
 | |
@override_settings(USE_TZ=True)
class DateTimeLookupTests(TestCase):
    """A transform declaring a DateTimeField output compares with aware datetimes."""

    @unittest.skipUnless(connection.vendor == "mysql", "MySQL specific SQL used")
    def test_datetime_output_field(self):
        with register_lookup(models.PositiveIntegerField, DateTimeTransform):
            stamped = MySQLUnixTimestamp.objects.create(timestamp=time.time())
            start_of_2000 = timezone.make_aware(datetime(2000, 1, 1))
            after_2000 = MySQLUnixTimestamp.objects.filter(
                timestamp__as_datetime__gt=start_of_2000
            )
            self.assertSequenceEqual(after_2000, [stamped])
 | |
| 
 | |
| 
 | |
class YearLteTests(TestCase):
    """
    Tests for YearExact/YearLte, which bypass YearTransform and compare the
    underlying date column directly.
    """

    @classmethod
    def setUpTestData(cls):
        cls.a1 = Author.objects.create(name="a1", birthdate=date(1981, 2, 16))
        cls.a2 = Author.objects.create(name="a2", birthdate=date(2012, 2, 29))
        cls.a3 = Author.objects.create(name="a3", birthdate=date(2012, 1, 31))
        cls.a4 = Author.objects.create(name="a4", birthdate=date(2012, 3, 1))

    def setUp(self):
        models.DateField.register_lookup(YearTransform)

    def tearDown(self):
        models.DateField._unregister_lookup(YearTransform)

    @unittest.skipUnless(
        connection.vendor == "postgresql", "PostgreSQL specific SQL used"
    )
    def test_year_lte(self):
        baseqs = Author.objects.order_by("name")
        self.assertSequenceEqual(
            baseqs.filter(birthdate__testyear__lte=2012),
            [self.a1, self.a2, self.a3, self.a4],
        )
        self.assertSequenceEqual(
            baseqs.filter(birthdate__testyear=2012), [self.a2, self.a3, self.a4]
        )

        # YearExact builds explicit >=/<= bounds rather than a BETWEEN.
        self.assertNotIn("BETWEEN", str(baseqs.filter(birthdate__testyear=2012).query))
        self.assertSequenceEqual(
            baseqs.filter(birthdate__testyear__lte=2011), [self.a1]
        )
        # The non-optimized version works, too.
        self.assertSequenceEqual(baseqs.filter(birthdate__testyear__lt=2012), [self.a1])

    @unittest.skipUnless(
        connection.vendor == "postgresql", "PostgreSQL specific SQL used"
    )
    def test_year_lte_fexpr(self):
        self.a2.age = 2011
        self.a2.save()
        self.a3.age = 2012
        self.a3.save()
        self.a4.age = 2013
        self.a4.save()
        baseqs = Author.objects.order_by("name")
        self.assertSequenceEqual(
            baseqs.filter(birthdate__testyear__lte=models.F("age")), [self.a3, self.a4]
        )
        self.assertSequenceEqual(
            baseqs.filter(birthdate__testyear__lt=models.F("age")), [self.a4]
        )

    def test_year_lte_sql(self):
        # This test will just check the generated SQL for __lte. This
        # doesn't require running on PostgreSQL and spots the most likely
        # error - not running YearLte SQL at all.
        baseqs = Author.objects.order_by("name")
        self.assertIn(
            "<= (2011 || ", str(baseqs.filter(birthdate__testyear__lte=2011).query)
        )
        self.assertIn("-12-31", str(baseqs.filter(birthdate__testyear__lte=2011).query))

    def test_postgres_year_exact(self):
        baseqs = Author.objects.order_by("name")
        self.assertIn("= (2011 || ", str(baseqs.filter(birthdate__testyear=2011).query))
        self.assertIn("-12-31", str(baseqs.filter(birthdate__testyear=2011).query))

    def test_custom_implementation_year_exact(self):
        # Backends dispatch to an "as_<vendor>" method, whose name depends on
        # the database the suite runs against.
        vendor_method_name = "as_" + connection.vendor

        # Two ways to add a customized implementation for different backends:
        # First is MonkeyPatch of the class.
        def as_custom_sql(self, compiler, connection):
            lhs_sql, lhs_params = self.process_lhs(
                compiler, connection, self.lhs.lhs
            )
            rhs_sql, rhs_params = self.process_rhs(compiler, connection)
            params = lhs_params + rhs_params + lhs_params + rhs_params
            return (
                "%(lhs)s >= "
                "str_to_date(concat(%(rhs)s, '-01-01'), '%%%%Y-%%%%m-%%%%d') "
                "AND %(lhs)s <= "
                "str_to_date(concat(%(rhs)s, '-12-31'), '%%%%Y-%%%%m-%%%%d')"
                % {"lhs": lhs_sql, "rhs": rhs_sql},
                params,
            )

        # Install the patch before entering try/finally so the cleanup only
        # runs once there is something to undo; otherwise a setup failure
        # would be masked by an AttributeError from delattr().
        setattr(YearExact, vendor_method_name, as_custom_sql)
        try:
            self.assertIn(
                "concat(", str(Author.objects.filter(birthdate__testyear=2012).query)
            )
        finally:
            delattr(YearExact, vendor_method_name)

        # The other way is to subclass the original lookup and register the
        # subclassed lookup instead of the original.
        class CustomYearExact(YearExact):
            # This method should be named "as_mysql" for MySQL,
            # "as_postgresql" for postgres and so on, but as we don't know
            # which DB we are running on, we need to use setattr.
            def as_custom_sql(self, compiler, connection):
                lhs_sql, lhs_params = self.process_lhs(
                    compiler, connection, self.lhs.lhs
                )
                rhs_sql, rhs_params = self.process_rhs(compiler, connection)
                params = lhs_params + rhs_params + lhs_params + rhs_params
                return (
                    "%(lhs)s >= "
                    "str_to_date(CONCAT(%(rhs)s, '-01-01'), '%%%%Y-%%%%m-%%%%d') "
                    "AND %(lhs)s <= "
                    "str_to_date(CONCAT(%(rhs)s, '-12-31'), '%%%%Y-%%%%m-%%%%d')"
                    % {"lhs": lhs_sql, "rhs": rhs_sql},
                    params,
                )

        setattr(
            CustomYearExact,
            vendor_method_name,
            CustomYearExact.as_custom_sql,
        )
        # Defining and registering the subclass happens outside the try, so a
        # failure here cannot surface as a NameError from the finally clause.
        YearTransform.register_lookup(CustomYearExact)
        try:
            self.assertIn(
                "CONCAT(", str(Author.objects.filter(birthdate__testyear=2012).query)
            )
        finally:
            # CustomYearExact inherits lookup_name "exact" and so displaced
            # YearExact; remove it and restore the original lookup.
            YearTransform._unregister_lookup(CustomYearExact)
            YearTransform.register_lookup(YearExact)
 | |
| 
 | |
| 
 | |
class TrackCallsYearTransform(YearTransform):
    """
    YearTransform that records, in order, each call to get_lookup() and
    get_transform().

    Calls are appended to ``call_order``, a class-level list shared by all
    instances; callers are responsible for resetting it between checks.
    """

    # Use a name that avoids collision with the built-in year lookup.
    lookup_name = "testyear"
    call_order = []

    def as_sql(self, compiler, connection):
        lhs_sql, params = compiler.compile(self.lhs)
        # date_extract_sql() takes the compiled params and itself returns the
        # final (sql, params) pair, exactly as YearTransform.as_sql() uses it.
        # The previous two-argument call omitted params and wrapped the result
        # in a second tuple.
        return connection.ops.date_extract_sql("year", lhs_sql, params)

    @property
    def output_field(self):
        return models.IntegerField()

    def get_lookup(self, lookup_name):
        self.call_order.append("lookup")
        return super().get_lookup(lookup_name)

    def get_transform(self, lookup_name):
        self.call_order.append("transform")
        return super().get_transform(lookup_name)
 | |
| 
 | |
| 
 | |
class LookupTransformCallOrderTests(SimpleTestCase):
    """Check the order in which lookups and transforms are resolved."""

    def test_call_order(self):
        # TrackCallsYearTransform.call_order is a class-level list shared by
        # every instance. Clear it before each check (including the first) and
        # once the test finishes, so results neither depend on earlier use of
        # the transform nor leak into later tests.
        self.addCleanup(setattr, TrackCallsYearTransform, "call_order", [])
        msg = (
            "Unsupported lookup 'junk' for IntegerField or join on the field not "
            "permitted."
        )
        with register_lookup(models.DateField, TrackCallsYearTransform):
            # junk lookup - tries lookup, then transform, then fails
            TrackCallsYearTransform.call_order = []
            with self.assertRaisesMessage(FieldError, msg):
                Author.objects.filter(birthdate__testyear__junk=2012)
            self.assertEqual(
                TrackCallsYearTransform.call_order, ["lookup", "transform"]
            )
            # junk transform - tries transform only, then fails
            TrackCallsYearTransform.call_order = []
            with self.assertRaisesMessage(FieldError, msg):
                Author.objects.filter(birthdate__testyear__junk__more_junk=2012)
            self.assertEqual(TrackCallsYearTransform.call_order, ["transform"])
            # Just getting the year (implied __exact) - lookup only
            TrackCallsYearTransform.call_order = []
            Author.objects.filter(birthdate__testyear=2012)
            self.assertEqual(TrackCallsYearTransform.call_order, ["lookup"])
            # Just getting the year (explicit __exact) - lookup only
            TrackCallsYearTransform.call_order = []
            Author.objects.filter(birthdate__testyear__exact=2012)
            self.assertEqual(TrackCallsYearTransform.call_order, ["lookup"])
 | |
| 
 | |
| 
 | |
class CustomisedMethodsTests(SimpleTestCase):
    """A field may override get_lookup()/get_transform() to build them dynamically."""

    def _assert_sql_contains(self, fragment, **filters):
        """Assert that filtering CustomModel by *filters* emits *fragment*."""
        generated_sql = str(CustomModel.objects.filter(**filters).query)
        self.assertIn(fragment, generated_sql)

    def test_overridden_get_lookup(self):
        self._assert_sql_contains("monkeys()", field__lookupfunc_monkeys=3)

    def test_overridden_get_transform(self):
        self._assert_sql_contains("banana()", field__transformfunc_banana=3)

    def test_overridden_get_lookup_chain(self):
        self._assert_sql_contains(
            "elephants()", field__transformfunc_banana__lookupfunc_elephants=3
        )

    def test_overridden_get_transform_chain(self):
        self._assert_sql_contains(
            "pear()", field__transformfunc_banana__transformfunc_pear=3
        )
 | |
| 
 | |
| 
 | |
class SubqueryTransformTests(TestCase):
    """A lookup using a registered transform also works inside a subquery."""

    def test_subquery_usage(self):
        with register_lookup(models.IntegerField, Div3Transform):
            created = [
                Author.objects.create(name="a%d" % age, age=age) for age in range(1, 5)
            ]
            # Only age 2 leaves a remainder of 2 when divided by 3.
            remainder_two = Author.objects.filter(age__div3=2)
            outer = Author.objects.order_by("name").filter(id__in=remainder_two)
            self.assertSequenceEqual(outer, [created[1]])
 |