mirror of
				https://github.com/django/django.git
				synced 2025-10-25 14:46:09 +00:00 
			
		
		
		
	Refs #373 -- Improved test coverage of tuple lookup checks.
This also removed unreachable checks.
This commit is contained in:
		
				
					committed by
					
						 Sarah Boyce
						Sarah Boyce
					
				
			
			
				
	
			
			
			
						parent
						
							2a4321ba23
						
					
				
				
					commit
					347ab72c02
				
			| @@ -1,7 +1,7 @@ | |||||||
| import itertools | import itertools | ||||||
|  |  | ||||||
| from django.core.exceptions import EmptyResultSet | from django.core.exceptions import EmptyResultSet | ||||||
| from django.db.models.expressions import ColPairs, Func, Value | from django.db.models.expressions import Func, Value | ||||||
| from django.db.models.lookups import ( | from django.db.models.lookups import ( | ||||||
|     Exact, |     Exact, | ||||||
|     GreaterThan, |     GreaterThan, | ||||||
| @@ -24,36 +24,14 @@ class TupleLookupMixin: | |||||||
|         return super().get_prep_lookup() |         return super().get_prep_lookup() | ||||||
|  |  | ||||||
|     def check_tuple_lookup(self): |     def check_tuple_lookup(self): | ||||||
|         assert isinstance(self.lhs, ColPairs) |  | ||||||
|         self.check_rhs_is_tuple_or_list() |  | ||||||
|         self.check_rhs_length_equals_lhs_length() |         self.check_rhs_length_equals_lhs_length() | ||||||
|  |  | ||||||
|     def check_rhs_is_tuple_or_list(self): |  | ||||||
|         if not isinstance(self.rhs, (tuple, list)): |  | ||||||
|             raise ValueError( |  | ||||||
|                 f"'{self.lookup_name}' lookup of '{self.lhs.field.name}' field " |  | ||||||
|                 "must be a tuple or a list" |  | ||||||
|             ) |  | ||||||
|  |  | ||||||
|     def check_rhs_length_equals_lhs_length(self): |     def check_rhs_length_equals_lhs_length(self): | ||||||
|         if len(self.lhs) != len(self.rhs): |         len_lhs = len(self.lhs) | ||||||
|  |         if len_lhs != len(self.rhs): | ||||||
|             raise ValueError( |             raise ValueError( | ||||||
|                 f"'{self.lookup_name}' lookup of '{self.lhs.field.name}' field " |                 f"'{self.lookup_name}' lookup of '{self.lhs.field.name}' field " | ||||||
|                 f"must have {len(self.lhs)} elements" |                 f"must have {len_lhs} elements" | ||||||
|             ) |  | ||||||
|  |  | ||||||
|     def check_rhs_is_collection_of_tuples_or_lists(self): |  | ||||||
|         if not all(isinstance(vals, (tuple, list)) for vals in self.rhs): |  | ||||||
|             raise ValueError( |  | ||||||
|                 f"'{self.lookup_name}' lookup of '{self.lhs.field.name}' field " |  | ||||||
|                 f"must be a collection of tuples or lists" |  | ||||||
|             ) |  | ||||||
|  |  | ||||||
|     def check_rhs_elements_length_equals_lhs_length(self): |  | ||||||
|         if not all(len(self.lhs) == len(vals) for vals in self.rhs): |  | ||||||
|             raise ValueError( |  | ||||||
|                 f"'{self.lookup_name}' lookup of '{self.lhs.field.name}' field " |  | ||||||
|                 f"must have {len(self.lhs)} elements each" |  | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|     def as_sql(self, compiler, connection): |     def as_sql(self, compiler, connection): | ||||||
| @@ -192,11 +170,16 @@ class TupleLessThanOrEqual(TupleLookupMixin, LessThanOrEqual): | |||||||
|  |  | ||||||
| class TupleIn(TupleLookupMixin, In): | class TupleIn(TupleLookupMixin, In): | ||||||
|     def check_tuple_lookup(self): |     def check_tuple_lookup(self): | ||||||
|         assert isinstance(self.lhs, ColPairs) |  | ||||||
|         self.check_rhs_is_tuple_or_list() |  | ||||||
|         self.check_rhs_is_collection_of_tuples_or_lists() |  | ||||||
|         self.check_rhs_elements_length_equals_lhs_length() |         self.check_rhs_elements_length_equals_lhs_length() | ||||||
|  |  | ||||||
|  |     def check_rhs_elements_length_equals_lhs_length(self): | ||||||
|  |         len_lhs = len(self.lhs) | ||||||
|  |         if not all(len_lhs == len(vals) for vals in self.rhs): | ||||||
|  |             raise ValueError( | ||||||
|  |                 f"'{self.lookup_name}' lookup of '{self.lhs.field.name}' field " | ||||||
|  |                 f"must have {len_lhs} elements each" | ||||||
|  |             ) | ||||||
|  |  | ||||||
|     def as_sql(self, compiler, connection): |     def as_sql(self, compiler, connection): | ||||||
|         if not self.rhs: |         if not self.rhs: | ||||||
|             raise EmptyResultSet |             raise EmptyResultSet | ||||||
|   | |||||||
| @@ -240,3 +240,28 @@ class TupleLookupsTests(TestCase): | |||||||
|             self.assertSequenceEqual( |             self.assertSequenceEqual( | ||||||
|                 Contact.objects.filter(customer__isnull=subquery).order_by("id"), () |                 Contact.objects.filter(customer__isnull=subquery).order_by("id"), () | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|  |     def test_lookup_errors(self): | ||||||
|  |         m_2_elements = "'%s' lookup of 'customer' field must have 2 elements" | ||||||
|  |         m_2_elements_each = "'in' lookup of 'customer' field must have 2 elements each" | ||||||
|  |         test_cases = ( | ||||||
|  |             ({"customer": 1}, m_2_elements % "exact"), | ||||||
|  |             ({"customer": (1, 2, 3)}, m_2_elements % "exact"), | ||||||
|  |             ({"customer__in": (1, 2, 3)}, m_2_elements_each), | ||||||
|  |             ({"customer__in": ("foo", "bar")}, m_2_elements_each), | ||||||
|  |             ({"customer__gt": 1}, m_2_elements % "gt"), | ||||||
|  |             ({"customer__gt": (1, 2, 3)}, m_2_elements % "gt"), | ||||||
|  |             ({"customer__gte": 1}, m_2_elements % "gte"), | ||||||
|  |             ({"customer__gte": (1, 2, 3)}, m_2_elements % "gte"), | ||||||
|  |             ({"customer__lt": 1}, m_2_elements % "lt"), | ||||||
|  |             ({"customer__lt": (1, 2, 3)}, m_2_elements % "lt"), | ||||||
|  |             ({"customer__lte": 1}, m_2_elements % "lte"), | ||||||
|  |             ({"customer__lte": (1, 2, 3)}, m_2_elements % "lte"), | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         for kwargs, message in test_cases: | ||||||
|  |             with ( | ||||||
|  |                 self.subTest(kwargs=kwargs), | ||||||
|  |                 self.assertRaisesMessage(ValueError, message), | ||||||
|  |             ): | ||||||
|  |                 Contact.objects.get(**kwargs) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user