diff --git a/django/db/models/fields/tuple_lookups.py b/django/db/models/fields/tuple_lookups.py index cdb3b47209..eb2d80b20f 100644 --- a/django/db/models/fields/tuple_lookups.py +++ b/django/db/models/fields/tuple_lookups.py @@ -1,6 +1,7 @@ import itertools from django.core.exceptions import EmptyResultSet +from django.db.models import Field from django.db.models.expressions import Func, Value from django.db.models.lookups import ( Exact, @@ -16,15 +17,19 @@ from django.db.models.sql.where import AND, OR, WhereNode class Tuple(Func): function = "" + output_field = Field() + + def __len__(self): + return len(self.source_expressions) + + def __iter__(self): + return iter(self.source_expressions) class TupleLookupMixin: def get_prep_lookup(self): - self.check_tuple_lookup() - return super().get_prep_lookup() - - def check_tuple_lookup(self): self.check_rhs_length_equals_lhs_length() + return self.rhs def check_rhs_length_equals_lhs_length(self): len_lhs = len(self.lhs) @@ -34,24 +39,30 @@ class TupleLookupMixin: f"must have {len_lhs} elements" ) - def as_sql(self, compiler, connection): - # e.g.: (a, b, c) == (x, y, z) as SQL: - # WHERE (a, b, c) = (x, y, z) - vals = [ + def get_prep_lhs(self): + if isinstance(self.lhs, (tuple, list)): + return Tuple(*self.lhs) + return super().get_prep_lhs() + + def process_lhs(self, compiler, connection, lhs=None): + sql, params = super().process_lhs(compiler, connection, lhs) + if not isinstance(self.lhs, Tuple): + sql = f"({sql})" + return sql, params + + def process_rhs(self, compiler, connection): + values = [ Value(val, output_field=col.output_field) for col, val in zip(self.lhs, self.rhs) ] - lookup_class = self.__class__.__bases__[-1] - lookup = lookup_class(Tuple(self.lhs), Tuple(*vals)) - return lookup.as_sql(compiler, connection) + return Tuple(*values).as_sql(compiler, connection) class TupleExact(TupleLookupMixin, Exact): def as_oracle(self, compiler, connection): # e.g.: (a, b, c) == (x, y, z) as SQL: # WHERE a = x AND b = y AND c = z - cols = self.lhs.get_cols() - lookups = [Exact(col, val) for col, val in zip(cols, self.rhs)] + lookups = [Exact(col, val) for col, val in zip(self.lhs, self.rhs)] root = WhereNode(lookups, connector=AND) return root.as_sql(compiler, connection) @@ -83,10 +94,9 @@ class TupleGreaterThan(TupleLookupMixin, GreaterThan): def as_oracle(self, compiler, connection): # e.g.: (a, b, c) > (x, y, z) as SQL: # WHERE a > x OR (a = x AND (b > y OR (b = y AND c > z))) - cols = self.lhs.get_cols() lookups = itertools.cycle([GreaterThan, Exact]) connectors = itertools.cycle([OR, AND]) - cols_list = [col for col in cols for _ in range(2)] + cols_list = [col for col in self.lhs for _ in range(2)] vals_list = [val for val in self.rhs for _ in range(2)] cols_iter = iter(cols_list[:-1]) vals_iter = iter(vals_list[:-1]) @@ -110,10 +120,9 @@ class TupleGreaterThanOrEqual(TupleLookupMixin, GreaterThanOrEqual): def as_oracle(self, compiler, connection): # e.g.: (a, b, c) >= (x, y, z) as SQL: # WHERE a > x OR (a = x AND (b > y OR (b = y AND (c > z OR c = z)))) - cols = self.lhs.get_cols() lookups = itertools.cycle([GreaterThan, Exact]) connectors = itertools.cycle([OR, AND]) - cols_list = [col for col in cols for _ in range(2)] + cols_list = [col for col in self.lhs for _ in range(2)] vals_list = [val for val in self.rhs for _ in range(2)] cols_iter = iter(cols_list) vals_iter = iter(vals_list) @@ -137,10 +146,9 @@ class TupleLessThan(TupleLookupMixin, LessThan): def as_oracle(self, compiler, connection): # e.g.: (a, b, c) < (x, y, z) as SQL: # WHERE a < x OR (a = x AND (b < y OR (b = y AND c < z))) - cols = self.lhs.get_cols() lookups = itertools.cycle([LessThan, Exact]) connectors = itertools.cycle([OR, AND]) - cols_list = [col for col in cols for _ in range(2)] + cols_list = [col for col in self.lhs for _ in range(2)] vals_list = [val for val in self.rhs for _ in range(2)] cols_iter = iter(cols_list[:-1]) vals_iter = iter(vals_list[:-1]) @@ -164,10 +172,9 @@ class TupleLessThanOrEqual(TupleLookupMixin, LessThanOrEqual): def as_oracle(self, compiler, connection): # e.g.: (a, b, c) <= (x, y, z) as SQL: # WHERE a < x OR (a = x AND (b < y OR (b = y AND (c < z OR c = z)))) - cols = self.lhs.get_cols() lookups = itertools.cycle([LessThan, Exact]) connectors = itertools.cycle([OR, AND]) - cols_list = [col for col in cols for _ in range(2)] + cols_list = [col for col in self.lhs for _ in range(2)] vals_list = [val for val in self.rhs for _ in range(2)] cols_iter = iter(cols_list) vals_iter = iter(vals_list) @@ -188,8 +195,9 @@ class TupleLessThanOrEqual(TupleLookupMixin, LessThanOrEqual): class TupleIn(TupleLookupMixin, In): - def check_tuple_lookup(self): + def get_prep_lookup(self): self.check_rhs_elements_length_equals_lhs_length() + return super(TupleLookupMixin, self).get_prep_lookup() def check_rhs_elements_length_equals_lhs_length(self): len_lhs = len(self.lhs) @@ -199,37 +207,40 @@ class TupleIn(TupleLookupMixin, In): f"must have {len_lhs} elements each" ) - def as_sql(self, compiler, connection): - if not self.rhs: + def process_rhs(self, compiler, connection): + rhs = self.rhs + if not rhs: raise EmptyResultSet # e.g.: (a, b, c) in [(x1, y1, z1), (x2, y2, z2)] as SQL: # WHERE (a, b, c) IN ((x1, y1, z1), (x2, y2, z2)) - rhs = [] - for vals in self.rhs: - rhs.append( + result = [] + lhs = self.lhs + + for vals in rhs: + result.append( Tuple( *[ Value(val, output_field=col.output_field) - for col, val in zip(self.lhs, vals) + for col, val in zip(lhs, vals) ] ) ) - lookup = In(Tuple(self.lhs), Tuple(*rhs)) - return lookup.as_sql(compiler, connection) + return Tuple(*result).as_sql(compiler, connection) def as_sqlite(self, compiler, connection): - if not self.rhs: + rhs = self.rhs + if not rhs: raise EmptyResultSet # e.g.: (a, b, c) in [(x1, y1, z1), (x2, y2, z2)] as SQL: # WHERE (a = x1 AND b = y1 AND c = z1) OR (a = x2 AND b = y2 AND c = z2) root = WhereNode([], connector=OR) - cols = self.lhs.get_cols() + lhs = self.lhs - for vals in self.rhs: - lookups = [Exact(col, val) for col, val in zip(cols, vals)] + for vals in rhs: + lookups = [Exact(col, val) for col, val in zip(lhs, vals)] root.children.append(WhereNode(lookups, connector=AND)) return root.as_sql(compiler, connection) diff --git a/tests/foreign_object/test_tuple_lookups.py b/tests/foreign_object/test_tuple_lookups.py index 2742d6e93d..e2561676f3 100644 --- a/tests/foreign_object/test_tuple_lookups.py +++ b/tests/foreign_object/test_tuple_lookups.py @@ -1,6 +1,16 @@ import unittest from django.db import NotSupportedError, connection +from django.db.models import F +from django.db.models.fields.tuple_lookups import ( + TupleExact, + TupleGreaterThan, + TupleGreaterThanOrEqual, + TupleIn, + TupleIsNull, + TupleLessThan, + TupleLessThanOrEqual, +) from django.test import TestCase from .models import Contact, Customer @@ -32,10 +42,25 @@ class TupleLookupsTests(TestCase): ) for customer, contacts in test_cases: - with self.subTest(customer=customer, contacts=contacts): + with self.subTest( + "filter(customer=customer)", + customer=customer, + contacts=contacts, + ): self.assertSequenceEqual( Contact.objects.filter(customer=customer).order_by("id"), contacts ) + with self.subTest( + "filter(TupleExact)", + customer=customer, + contacts=contacts, + ): + lhs = (F("customer_code"), F("company_code")) + rhs = (customer.customer_id, customer.company) + lookup = TupleExact(lhs, rhs) + self.assertSequenceEqual( + Contact.objects.filter(lookup).order_by("id"), contacts + ) def test_exact_subquery(self): with self.assertRaisesMessage( @@ -71,11 +96,26 @@ class TupleLookupsTests(TestCase): ((cust_1, cust_2, cust_3, cust_4, cust_5), (c1, c2, c3, c4, c5, c6)), ) - for contacts, customers in test_cases: - with self.subTest(contacts=contacts, customers=customers): + for customers, contacts in test_cases: + with self.subTest( + "filter(customer__in=customers)", + customers=customers, + contacts=contacts, + ): self.assertSequenceEqual( - Contact.objects.filter(customer__in=contacts).order_by("id"), - customers, + Contact.objects.filter(customer__in=customers).order_by("id"), + contacts, + ) + with self.subTest( + "filter(TupleIn)", + customers=customers, + contacts=contacts, + ): + lhs = (F("customer_code"), F("company_code")) + rhs = [(c.customer_id, c.company) for c in customers] + lookup = TupleIn(lhs, rhs) + self.assertSequenceEqual( + Contact.objects.filter(lookup).order_by("id"), contacts ) @unittest.skipIf( @@ -107,11 +147,26 @@ class TupleLookupsTests(TestCase): ) for customer, contacts in test_cases: - with self.subTest(customer=customer, contacts=contacts): + with self.subTest( + "filter(customer__lt=customer)", + customer=customer, + contacts=contacts, + ): self.assertSequenceEqual( Contact.objects.filter(customer__lt=customer).order_by("id"), contacts, ) + with self.subTest( + "filter(TupleLessThan)", + customer=customer, + contacts=contacts, + ): + lhs = (F("customer_code"), F("company_code")) + rhs = (customer.customer_id, customer.company) + lookup = TupleLessThan(lhs, rhs) + self.assertSequenceEqual( + Contact.objects.filter(lookup).order_by("id"), contacts + ) def test_lt_subquery(self): with self.assertRaisesMessage( @@ -140,11 +195,26 @@ class TupleLookupsTests(TestCase): ) for customer, contacts in test_cases: - with self.subTest(customer=customer, contacts=contacts): + with self.subTest( + "filter(customer__lte=customer)", + customer=customer, + contacts=contacts, + ): self.assertSequenceEqual( Contact.objects.filter(customer__lte=customer).order_by("id"), contacts, ) + with self.subTest( + "filter(TupleLessThanOrEqual)", + customer=customer, + contacts=contacts, + ): + lhs = (F("customer_code"), F("company_code")) + rhs = (customer.customer_id, customer.company) + lookup = TupleLessThanOrEqual(lhs, rhs) + self.assertSequenceEqual( + Contact.objects.filter(lookup).order_by("id"), contacts + ) def test_lte_subquery(self): with self.assertRaisesMessage( @@ -165,11 +235,26 @@ class TupleLookupsTests(TestCase): ) for customer, contacts in test_cases: - with self.subTest(customer=customer, contacts=contacts): + with self.subTest( + "filter(customer__gt=customer)", + customer=customer, + contacts=contacts, + ): self.assertSequenceEqual( Contact.objects.filter(customer__gt=customer).order_by("id"), contacts, ) + with self.subTest( + "filter(TupleGreaterThan)", + customer=customer, + contacts=contacts, + ): + lhs = (F("customer_code"), F("company_code")) + rhs = (customer.customer_id, customer.company) + lookup = TupleGreaterThan(lhs, rhs) + self.assertSequenceEqual( + Contact.objects.filter(lookup).order_by("id"), contacts + ) def test_gt_subquery(self): with self.assertRaisesMessage( @@ -198,11 +283,26 @@ class TupleLookupsTests(TestCase): ) for customer, contacts in test_cases: - with self.subTest(customer=customer, contacts=contacts): + with self.subTest( + "filter(customer__gte=customer)", + customer=customer, + contacts=contacts, + ): self.assertSequenceEqual( Contact.objects.filter(customer__gte=customer).order_by("pk"), contacts, ) + with self.subTest( + "filter(TupleGreaterThanOrEqual)", + customer=customer, + contacts=contacts, + ): + lhs = (F("customer_code"), F("company_code")) + rhs = (customer.customer_id, customer.company) + lookup = TupleGreaterThanOrEqual(lhs, rhs) + self.assertSequenceEqual( + Contact.objects.filter(lookup).order_by("id"), contacts + ) def test_gte_subquery(self): with self.assertRaisesMessage( @@ -214,22 +314,38 @@ class TupleLookupsTests(TestCase): ) def test_isnull(self): - with self.subTest("customer__isnull=True"): + contacts = ( + self.contact_1, + self.contact_2, + self.contact_3, + self.contact_4, + self.contact_5, + self.contact_6, + ) + + with self.subTest("filter(customer__isnull=True)"): self.assertSequenceEqual( Contact.objects.filter(customer__isnull=True).order_by("id"), (), ) - with self.subTest("customer__isnull=False"): + with self.subTest("filter(TupleIsNull(True))"): + lhs = (F("customer_code"), F("company_code")) + lookup = TupleIsNull(lhs, True) + self.assertSequenceEqual( + Contact.objects.filter(lookup).order_by("id"), + (), + ) + with self.subTest("filter(customer__isnull=False)"): self.assertSequenceEqual( Contact.objects.filter(customer__isnull=False).order_by("id"), - ( - self.contact_1, - self.contact_2, - self.contact_3, - self.contact_4, - self.contact_5, - self.contact_6, - ), + contacts, + ) + with self.subTest("filter(TupleIsNull(False))"): + lhs = (F("customer_code"), F("company_code")) + lookup = TupleIsNull(lhs, False) + self.assertSequenceEqual( + Contact.objects.filter(lookup).order_by("id"), + contacts, ) def test_isnull_subquery(self):