mirror of
				https://github.com/django/django.git
				synced 2025-10-26 07:06:08 +00:00 
			
		
		
		
	Refs #373 -- Added TupleIn subqueries.
This commit is contained in:
		
				
					committed by
					
						 Sarah Boyce
						Sarah Boyce
					
				
			
			
				
	
			
			
			
						parent
						
							611bf6c2e2
						
					
				
				
					commit
					f7601aed51
				
			| @@ -12,6 +12,7 @@ from django.db.models.lookups import ( | |||||||
|     LessThan, |     LessThan, | ||||||
|     LessThanOrEqual, |     LessThanOrEqual, | ||||||
| ) | ) | ||||||
|  | from django.db.models.sql import Query | ||||||
| from django.db.models.sql.where import AND, OR, WhereNode | from django.db.models.sql.where import AND, OR, WhereNode | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -211,9 +212,14 @@ class TupleLessThanOrEqual(TupleLookupMixin, LessThanOrEqual): | |||||||
|  |  | ||||||
| class TupleIn(TupleLookupMixin, In): | class TupleIn(TupleLookupMixin, In): | ||||||
|     def get_prep_lookup(self): |     def get_prep_lookup(self): | ||||||
|  |         if self.rhs_is_direct_value(): | ||||||
|             self.check_rhs_is_tuple_or_list() |             self.check_rhs_is_tuple_or_list() | ||||||
|             self.check_rhs_is_collection_of_tuples_or_lists() |             self.check_rhs_is_collection_of_tuples_or_lists() | ||||||
|             self.check_rhs_elements_length_equals_lhs_length() |             self.check_rhs_elements_length_equals_lhs_length() | ||||||
|  |         else: | ||||||
|  |             self.check_rhs_is_query() | ||||||
|  |             self.check_rhs_select_length_equals_lhs_length() | ||||||
|  |  | ||||||
|         return self.rhs  # skip checks from mixin |         return self.rhs  # skip checks from mixin | ||||||
|  |  | ||||||
|     def check_rhs_is_collection_of_tuples_or_lists(self): |     def check_rhs_is_collection_of_tuples_or_lists(self): | ||||||
| @@ -233,6 +239,25 @@ class TupleIn(TupleLookupMixin, In): | |||||||
|                 f"must have {len_lhs} elements each" |                 f"must have {len_lhs} elements each" | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|  |     def check_rhs_is_query(self): | ||||||
|  |         if not isinstance(self.rhs, Query): | ||||||
|  |             lhs_str = self.get_lhs_str() | ||||||
|  |             rhs_cls = self.rhs.__class__.__name__ | ||||||
|  |             raise ValueError( | ||||||
|  |                 f"{self.lookup_name!r} subquery lookup of {lhs_str} " | ||||||
|  |                 f"must be a Query object (received {rhs_cls!r})" | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |     def check_rhs_select_length_equals_lhs_length(self): | ||||||
|  |         len_rhs = len(self.rhs.select) | ||||||
|  |         len_lhs = len(self.lhs) | ||||||
|  |         if len_rhs != len_lhs: | ||||||
|  |             lhs_str = self.get_lhs_str() | ||||||
|  |             raise ValueError( | ||||||
|  |                 f"{self.lookup_name!r} subquery lookup of {lhs_str} " | ||||||
|  |                 f"must have {len_lhs} fields (received {len_rhs})" | ||||||
|  |             ) | ||||||
|  |  | ||||||
|     def process_rhs(self, compiler, connection): |     def process_rhs(self, compiler, connection): | ||||||
|         rhs = self.rhs |         rhs = self.rhs | ||||||
|         if not rhs: |         if not rhs: | ||||||
| @@ -255,10 +280,17 @@ class TupleIn(TupleLookupMixin, In): | |||||||
|  |  | ||||||
|         return Tuple(*result).as_sql(compiler, connection) |         return Tuple(*result).as_sql(compiler, connection) | ||||||
|  |  | ||||||
|  |     def as_sql(self, compiler, connection): | ||||||
|  |         if not self.rhs_is_direct_value(): | ||||||
|  |             return self.as_subquery(compiler, connection) | ||||||
|  |         return super().as_sql(compiler, connection) | ||||||
|  |  | ||||||
|     def as_sqlite(self, compiler, connection): |     def as_sqlite(self, compiler, connection): | ||||||
|         rhs = self.rhs |         rhs = self.rhs | ||||||
|         if not rhs: |         if not rhs: | ||||||
|             raise EmptyResultSet |             raise EmptyResultSet | ||||||
|  |         if not self.rhs_is_direct_value(): | ||||||
|  |             return self.as_subquery(compiler, connection) | ||||||
|  |  | ||||||
|         # e.g.: (a, b, c) in [(x1, y1, z1), (x2, y2, z2)] as SQL: |         # e.g.: (a, b, c) in [(x1, y1, z1), (x2, y2, z2)] as SQL: | ||||||
|         # WHERE (a = x1 AND b = y1 AND c = z1) OR (a = x2 AND b = y2 AND c = z2) |         # WHERE (a = x1 AND b = y1 AND c = z1) OR (a = x2 AND b = y2 AND c = z2) | ||||||
| @@ -271,6 +303,9 @@ class TupleIn(TupleLookupMixin, In): | |||||||
|  |  | ||||||
|         return root.as_sql(compiler, connection) |         return root.as_sql(compiler, connection) | ||||||
|  |  | ||||||
|  |     def as_subquery(self, compiler, connection): | ||||||
|  |         return compiler.compile(In(self.lhs, self.rhs)) | ||||||
|  |  | ||||||
|  |  | ||||||
| tuple_lookups = { | tuple_lookups = { | ||||||
|     "exact": TupleExact, |     "exact": TupleExact, | ||||||
|   | |||||||
| @@ -11,6 +11,7 @@ from django.db.models.fields.tuple_lookups import ( | |||||||
|     TupleLessThan, |     TupleLessThan, | ||||||
|     TupleLessThanOrEqual, |     TupleLessThanOrEqual, | ||||||
| ) | ) | ||||||
|  | from django.db.models.lookups import In | ||||||
| from django.test import TestCase, skipUnlessDBFeature | from django.test import TestCase, skipUnlessDBFeature | ||||||
|  |  | ||||||
| from .models import Contact, Customer | from .models import Contact, Customer | ||||||
| @@ -126,6 +127,46 @@ class TupleLookupsTests(TestCase): | |||||||
|             (self.contact_1, self.contact_2, self.contact_5), |             (self.contact_1, self.contact_2, self.contact_5), | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |     def test_tuple_in_subquery_must_be_query(self): | ||||||
|  |         lhs = (F("customer_code"), F("company_code")) | ||||||
|  |         # If rhs is any non-Query object with an as_sql() function. | ||||||
|  |         rhs = In(F("customer_code"), [1, 2, 3]) | ||||||
|  |         with self.assertRaisesMessage( | ||||||
|  |             ValueError, | ||||||
|  |             "'in' subquery lookup of ('customer_code', 'company_code') " | ||||||
|  |             "must be a Query object (received 'In')", | ||||||
|  |         ): | ||||||
|  |             TupleIn(lhs, rhs) | ||||||
|  |  | ||||||
|  |     def test_tuple_in_subquery_must_have_2_fields(self): | ||||||
|  |         lhs = (F("customer_code"), F("company_code")) | ||||||
|  |         rhs = Customer.objects.values_list("customer_id").query | ||||||
|  |         with self.assertRaisesMessage( | ||||||
|  |             ValueError, | ||||||
|  |             "'in' subquery lookup of ('customer_code', 'company_code') " | ||||||
|  |             "must have 2 fields (received 1)", | ||||||
|  |         ): | ||||||
|  |             TupleIn(lhs, rhs) | ||||||
|  |  | ||||||
|  |     def test_tuple_in_subquery(self): | ||||||
|  |         customers = Customer.objects.values_list("customer_id", "company") | ||||||
|  |         test_cases = ( | ||||||
|  |             (self.customer_1, (self.contact_1, self.contact_2, self.contact_5)), | ||||||
|  |             (self.customer_2, (self.contact_3,)), | ||||||
|  |             (self.customer_3, (self.contact_4,)), | ||||||
|  |             (self.customer_4, ()), | ||||||
|  |             (self.customer_5, (self.contact_6,)), | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         for customer, contacts in test_cases: | ||||||
|  |             lhs = (F("customer_code"), F("company_code")) | ||||||
|  |             rhs = customers.filter(id=customer.id).query | ||||||
|  |             lookup = TupleIn(lhs, rhs) | ||||||
|  |             qs = Contact.objects.filter(lookup).order_by("id") | ||||||
|  |  | ||||||
|  |             with self.subTest(customer=customer.id, query=str(qs.query)): | ||||||
|  |                 self.assertSequenceEqual(qs, contacts) | ||||||
|  |  | ||||||
|     def test_tuple_in_rhs_must_be_collection_of_tuples_or_lists(self): |     def test_tuple_in_rhs_must_be_collection_of_tuples_or_lists(self): | ||||||
|         test_cases = ( |         test_cases = ( | ||||||
|             (1, 2, 3), |             (1, 2, 3), | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user